From 16d51b71a5a35914f3b7372ad5c6fb884c4394b5 Mon Sep 17 00:00:00 2001 From: Oliver Date: Sun, 2 Jul 2023 17:53:56 -0400 Subject: [PATCH] Initial stab, borrowing stuff from the JITD compiler --- .gitignore | 3 + astral/src/com/astraldb/Astral.scala | 9 + astral/src/com/astraldb/ast/AST.scala | 32 ++ astral/src/com/astraldb/spec/Definition.scala | 10 + astral/src/com/astraldb/spec/Expression.scala | 178 +++++++++++ astral/src/com/astraldb/spec/Field.scala | 12 + .../astraldb/spec/FunctionDefinition.scala | 36 +++ astral/src/com/astraldb/spec/Node.scala | 12 + astral/src/com/astraldb/spec/Pattern.scala | 223 +++++++++++++ astral/src/com/astraldb/spec/Statement.scala | 128 ++++++++ astral/src/com/astraldb/spec/Type.scala | 37 +++ .../typecheck/FunctionSignature.scala | 33 ++ .../com/astraldb/typecheck/Typechecker.scala | 295 ++++++++++++++++++ build.sc | 30 ++ 14 files changed, 1038 insertions(+) create mode 100644 .gitignore create mode 100644 astral/src/com/astraldb/Astral.scala create mode 100644 astral/src/com/astraldb/ast/AST.scala create mode 100755 astral/src/com/astraldb/spec/Definition.scala create mode 100755 astral/src/com/astraldb/spec/Expression.scala create mode 100755 astral/src/com/astraldb/spec/Field.scala create mode 100755 astral/src/com/astraldb/spec/FunctionDefinition.scala create mode 100755 astral/src/com/astraldb/spec/Node.scala create mode 100644 astral/src/com/astraldb/spec/Pattern.scala create mode 100755 astral/src/com/astraldb/spec/Statement.scala create mode 100755 astral/src/com/astraldb/spec/Type.scala create mode 100755 astral/src/com/astraldb/typecheck/FunctionSignature.scala create mode 100755 astral/src/com/astraldb/typecheck/Typechecker.scala create mode 100644 build.sc diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..db529cd --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +.bloop +.metals +/out diff --git a/astral/src/com/astraldb/Astral.scala b/astral/src/com/astraldb/Astral.scala new file mode 100644 index 0000000..eed1b87 --- /dev/null +++ b/astral/src/com/astraldb/Astral.scala @@ -0,0 +1,9 @@ +package com.astraldb + +object Astral +{ + def main(args: Array[String]): Unit = + { + println("Hello!") + } +} \ No newline at end of file diff --git a/astral/src/com/astraldb/ast/AST.scala b/astral/src/com/astraldb/ast/AST.scala new file mode 100644 index 0000000..4f408b2 --- /dev/null +++ b/astral/src/com/astraldb/ast/AST.scala @@ -0,0 +1,32 @@ +package com.astraldb.ast + +import com.astraldb.spec.Constant + +sealed trait AST +{ + def children: Seq[AST] + + def find[T](f: AST => Option[T]): Option[T] = + { + f(this).orElse { + children.foldLeft(None:Option[T]){ (ret, child) => + if(ret.isEmpty) { child.find(f) } + else { ret } + } + } + } +} + +case class AList(elements: Seq[AST]) extends AST +{ + def children: Seq[AST] = elements +} +case class ANode(label: String, elements: Seq[AST]) extends AST +{ + def children: Seq[AST] = elements +} +case class Leaf(value: Constant) extends AST +{ + def children: Seq[AST] = Seq.empty +} + diff --git a/astral/src/com/astraldb/spec/Definition.scala b/astral/src/com/astraldb/spec/Definition.scala new file mode 100755 index 0000000..efb3482 --- /dev/null +++ b/astral/src/com/astraldb/spec/Definition.scala @@ -0,0 +1,10 @@ +package com.astraldb.spec; + +import com.astraldb.typecheck._ + +case class Definition( + nodes:Seq[Node], + // rules:Seq[Rule] +) { +} + diff --git a/astral/src/com/astraldb/spec/Expression.scala b/astral/src/com/astraldb/spec/Expression.scala new file mode 100755 index 0000000..beeed3b --- /dev/null +++ b/astral/src/com/astraldb/spec/Expression.scala @@ -0,0 +1,178 @@ +package com.astraldb.spec + + +object CmpTypes extends Enumeration { + type T = Value + val Eq, Neq, Lt, Lte, Gt, Gte = Value + + def opString(op:T):String = + { + op match { + case Eq => "==" + case Neq => "!=" + case Lt => "<" + case Lte => "<=" + case Gt => ">" + case Gte => ">=" + } + } +} + +object ArithTypes extends Enumeration { + type T = Value + val Add, Sub, Mul, Div, And, Or = Value + + def opString(op:T):String = + { + op match { + case Add => "+" + case Sub => "-" + case Mul => "*" + case Div => "/" + case And => "&&" + case Or => "||" + } + } +} + +sealed abstract class Expression +{ + def eq(other:Expression) = Cmp(CmpTypes.Eq, this, other) + def neq(other:Expression) = Cmp(CmpTypes.Neq, this, other) + def lt(other:Expression) = Cmp(CmpTypes.Lt, this, other) + def lte(other:Expression) = Cmp(CmpTypes.Lte, this, other) + def gt(other:Expression) = Cmp(CmpTypes.Gt, this, other) + def gte(other:Expression) = Cmp(CmpTypes.Gte, this, other) + + def and(other:Expression) = + { + (this, other) match { + case (BoolConstant(true), _) => other + case (BoolConstant(false), _) => this + case (_, BoolConstant(true)) => this + case (_, BoolConstant(false)) => other + case _ => Arith(ArithTypes.And, this, other) + } + } + def or(other:Expression) = + { + (this, other) match { + case (BoolConstant(false), _) => other + case (BoolConstant(true), _) => this + case (_, BoolConstant(false)) => this + case (_, BoolConstant(true)) => other + case _ => Arith(ArithTypes.Or, this, other) + } + } + + def plus(other:Expression) = Arith(ArithTypes.Add, this, other) + def minus(other:Expression) = Arith(ArithTypes.Sub, this, other) + def times(other:Expression) = Arith(ArithTypes.Mul, this, other) + def dividedBy(other:Expression) = Arith(ArithTypes.Div, this, other) + + def get(field:String) = StructSubscript(this, field) + def get(index:Int) = ArraySubscript(this, index) + + def disassemble: Seq[Expression] + def reassemble(in: Seq[Expression]): Expression + def rebuild(fn:(Expression => Expression)): Expression = + reassemble(disassemble.map { fn(_) }) +} + +sealed abstract class Constant(val t:Type) extends Expression +{ + def disassemble = Seq[Expression]() + def reassemble(in: Seq[Expression]): Expression = this +} +case class IntConstant(i:Integer) extends Constant(TInt()) +{ + override def toString:String = i.toString +} +case class FloatConstant(f:Double) extends Constant(TFloat()) +{ + override def toString:String = f.toString +} +case class BoolConstant(b:Boolean) extends Constant(TBool()) +{ + override def toString:String = if(b){ "true" } else { "false" } +} + +case class Var(v:String) extends Expression +{ + def disassemble = Seq[Expression]() + def reassemble(in: Seq[Expression]): Expression = this + override def toString = v.toString +} +case class Cmp(t: CmpTypes.T, lhs:Expression, rhs:Expression) extends Expression +{ + def disassemble = Seq[Expression](lhs, rhs) + def reassemble(in: Seq[Expression]): Expression = Cmp(t, in(0), in(1)) + override def toString = s"($lhs) ${CmpTypes.opString(t)} ($rhs)" +} +case class Arith(t: ArithTypes.T, lhs:Expression, rhs:Expression) extends Expression +{ + def disassemble = Seq[Expression](lhs, rhs) + def reassemble(in: Seq[Expression]): Expression = Arith(t, in(0), in(1)) + override def toString = s"($lhs) ${ArithTypes.opString(t)} ($rhs)" +} +case class FunctionalIfThenElse(condition:Expression, ifTrue:Expression, ifFalse:Expression) extends Expression +{ + def disassemble = Seq[Expression](condition, ifTrue, ifFalse) + def reassemble(in: Seq[Expression]): Expression = FunctionalIfThenElse(in(0), in(1), in(2)) + override def toString = s"($condition) ? ($ifTrue) : ($ifFalse)" +} +case class ArraySubscript(target:Expression, index:Integer) extends Expression +{ + def disassemble = Seq[Expression](target) + def reassemble(in: Seq[Expression]): Expression = ArraySubscript(in(0), index) + override def toString = s"$target[$index]" +} +case class StructSubscript(target:Expression, field:String) extends Expression +{ + def disassemble = Seq[Expression](target) + def reassemble(in: Seq[Expression]): Expression = StructSubscript(in(0), field) + override def toString = s"$target.$field" +} +case class NodeSubscript(target:Expression, field:String) extends Expression +{ + def disassemble = Seq[Expression](target) + def reassemble(in: Seq[Expression]): Expression = NodeSubscript(in(0), field) + override def toString = s"$target->$field" +} +case class NodeCast(nodeType: String,target:Expression, field:String) extends Expression +{ + def disassemble = Seq[Expression](target) + def reassemble(in: Seq[Expression]): Expression = NodeSubscript(in(0), field) + override def toString = s"($nodeType *)$target->$field" +} +case class FunctionCall(name:String, args:Seq[Expression]) extends Expression +{ + def disassemble = args + def reassemble(in: Seq[Expression]): Expression = FunctionCall(name, in) + override def toString = s"$name(${args.mkString(", ")})" +} +case class WrapNode(target: Expression) extends Expression +{ + def disassemble = Seq(target) + def reassemble(in: Seq[Expression]): Expression = WrapNode(in(0)) + override def toString = s"wrap ${target.toString}" +} +case class UnWrapHandle(target: Expression) extends Expression +{ + def disassemble = Seq(target) + def reassemble(in: Seq[Expression]): Expression = UnWrapHandle(in(0)) + override def toString = s"unwraphandleref ${target.toString}" +} +case class WrapNodeRef(target: Expression) extends Expression +{ + def disassemble = Seq(target) + def reassemble(in: Seq[Expression]): Expression = WrapNodeRef(in(0)) + override def toString = s"wrapnoderef ${target.toString}" +} + +case class MakeNode(nodeType: String, fields: Seq[Expression]) extends Expression +{ + def disassemble = fields + def reassemble(in: Seq[Expression]): Expression = MakeNode(nodeType, in) + override def toString = s"allocate ${nodeType}(${fields.mkString(",")})" +} \ No newline at end of file diff --git a/astral/src/com/astraldb/spec/Field.scala b/astral/src/com/astraldb/spec/Field.scala new file mode 100755 index 0000000..40252bd --- /dev/null +++ b/astral/src/com/astraldb/spec/Field.scala @@ -0,0 +1,12 @@ +package com.astraldb.spec + +import scala.language.implicitConversions + +case class Field(name:String, t:Type) +{ + override def toString = s"$name:${Type.toString(t)}" +} + +object FieldConversions { + implicit def tuple2Field(t:(String, Type)): Field = Field(t._1, t._2) +} diff --git a/astral/src/com/astraldb/spec/FunctionDefinition.scala b/astral/src/com/astraldb/spec/FunctionDefinition.scala new file mode 100755 index 0000000..2127e24 --- /dev/null +++ b/astral/src/com/astraldb/spec/FunctionDefinition.scala @@ -0,0 +1,36 @@ +package com.astraldb.spec + +import com.astraldb.typecheck.FunctionSignature + +object FunctionArgType extends Enumeration +{ + type T = Value + val Input, OutputRef, ConstInputRef = Value + + def isConst(t: T): Boolean = + t match { + case ConstInputRef => true + case Input | OutputRef => false + } + def isByRef(t: T): Boolean = + t match { + case ConstInputRef | OutputRef => true + case Input => false + } +} + +case class FunctionDefinition( + name: String, + ret: Option[Type], + args: Seq[(String, Type, FunctionArgType.T)], + body: Statement +) +{ + def signature = + ret match { + case Some(t) => + FunctionSignature(name, args.map { _._2 }, t) + case None => + FunctionSignature(name, args.map { _._2 }) + } +} \ No newline at end of file diff --git a/astral/src/com/astraldb/spec/Node.scala b/astral/src/com/astraldb/spec/Node.scala new file mode 100755 index 0000000..fc32da3 --- /dev/null +++ b/astral/src/com/astraldb/spec/Node.scala @@ -0,0 +1,12 @@ +package com.astraldb.spec; + +case class Node(val name:String, val fields:Seq[Field]) +{ + def renderName = name+"Node" + def enumName = "JITD_NODE_"+name + def scope = fields.map { f => f.name -> f.t }.toMap + + override def toString = + name + "(" + fields.map { _.toString }.mkString(", ") + ")" + +} \ No newline at end of file diff --git a/astral/src/com/astraldb/spec/Pattern.scala b/astral/src/com/astraldb/spec/Pattern.scala new file mode 100644 index 0000000..d9a2d49 --- /dev/null +++ b/astral/src/com/astraldb/spec/Pattern.scala @@ -0,0 +1,223 @@ +package com.astraldb.spec + +import com.astraldb.ast._ +import com.astraldb.spec.MatchPattern.Scope + +/** + * A rule for matching patterns + */ +sealed trait MatchPattern +{ + /** + * Apply the match pattern to a specified abstract syntax tree + * + * Return the updated scope, or None if the pattern does not match + */ + def apply(node: AST, scope: Scope): Option[Scope] +} +object MatchPattern +{ + type Scope = Map[String, AST] +} + +/** + * Boolean logic over match patterns + */ +sealed trait BooleanMatchPattern extends MatchPattern + + +/** + * Boolean not. + */ +case class MatchNot(child: MatchPattern) extends BooleanMatchPattern +{ + def apply(node: AST, scope: Scope): Option[Scope] = + child(node, scope) match { + case None => Some(Map.empty) + case Some(_) => None + } +} + +/** + * Boolean and + */ +case class MatchAnd(children: Seq[MatchPattern]) extends BooleanMatchPattern +{ + def apply(node: AST, scope: Scope): Option[Scope] = + children.foldLeft(Some(scope):Option[Scope]){ (scope, child) => + scope.flatMap { child(node, _) } + } +} + +/** + * Boolean or + */ +case class MatchOr(children: Seq[MatchPattern]) extends BooleanMatchPattern +{ + def apply(node: AST, scope: Scope): Option[Scope] = + { + children.foldLeft(None: Option[Scope]) { (retScope, child) => + if(retScope.isDefined) { retScope } + else { child(node, scope) } + } + } +} + +/** + * Match a node type. + * + * Return a match if the current AST element is a node of the labeled type, + * and if all child matchers match as well. + */ +case class MatchNode(nodeLabel: String, children: Seq[MatchPattern]) extends MatchPattern +{ + def apply(node: AST, scope: Scope): Option[Scope] = + node match { + case ANode(label, elements) if label == nodeLabel + && elements.size == children.size => + children.zip(elements).foldLeft(Some(scope):Option[Scope]) { case (scope, (child, element)) => + scope.flatMap { child(element, _) } + } + case _ => None + } +} + +/** + * Match a node, solely based on its type + * + * Return a match if the current AST element is a node of the labeled type. + * Ignore the children. + */ +case class MatchNodeType(nodeLabel: String) extends MatchPattern +{ + def apply(node: AST, scope: Scope): Option[Scope] = + node match { + case ANode(label, elements) if label == nodeLabel => Some(scope) + case _ => None + } + +} + +/** + * Match a node exactly. + * + * Return a match if the current AST element is identical to the AST provided + * in the pattern + */ +case class MatchExact(pattern: AST) extends MatchPattern +{ + def apply(node: AST, scope: Scope): Option[Scope] = + if(node == pattern){ Some(scope) } + else { None } +} + +/** + * Match a descendant identified by a provided path of indices + * + * The provided path specifies descendants by their positions. The 0th + * index is the first child of a Node or a List AST. The matcher returns + * a match if the specified path exists, and if the provided pattern matches + * the element at the specified position. + */ +case class MatchPath(path: Seq[Int], pattern: MatchPattern) extends MatchPattern +{ + def apply(node: AST, scope: Scope): Option[Scope] = + path.foldLeft( Some(node):Option[AST] ) { (target, idx) => + target match { + case Some(ANode(_, elements)) if elements.size > idx => Some(elements(idx)) + case Some(AList(elements)) if elements.size > idx => Some(elements(idx)) + case _ => None + } + } match { + case None => None + case Some(target) => pattern(target, scope) + } +} + +/** + * Match some descendent. + * + * The pattern is applied to all descendents of the current node. The matcher + * returns the first match it finds. + */ +case class MatchRecursive(pattern: MatchPattern) extends MatchPattern +{ + def apply(node: AST, scope: Scope): Option[Scope] = + node.find { pattern(_, scope) } +} + +/** + * Bind the current node to the scope + * + * The current node is bound to the scope and the provided pattern is applied + * as normal. + */ +case class MatchBind(symbol: String, pattern: MatchPattern) extends MatchPattern +{ + def apply(node: AST, scope: Scope): Option[Scope] = + pattern(node, scope ++ Map(symbol -> node)) +} + +/** + * Always return a match + */ +case object MatchAny extends MatchPattern +{ + def apply(node: AST, scope: Scope): Option[Scope] = + Some(scope) +} + +/** + * Never return a match + */ +case object MatchNone extends MatchPattern +{ + def apply(node: AST, scope: Scope): Option[Scope] = + None +} + +/** + * Match the current node against an element of the scope. + * + * Returns a match if the provided symbol exists in the scope and the current + * node is exactly equal to the symbol's value in the scope. + */ +case class MatchLookup(symbol: String) extends MatchPattern +{ + def apply(node: AST, scope: Scope): Option[Scope] = + scope.get(symbol).flatMap { scopeNode => + if(node == scopeNode){ Some(scope) } else { None } + } +} + +/** + * Apply a matcher to an element of the scope + * + * Returns a match if the provided symbol exists in the scope and the provided + * pattern exactly matches the symbol's value in the scope. + */ +case class MatchScope(symbol: String, pattern: MatchPattern) extends MatchPattern +{ + def apply(node: AST, scope: Scope): Option[Scope] = + scope.get(symbol).flatMap { scopeNode => + pattern(scopeNode, scope) + } +} + +/** + * Apply a matcher to every element of a list + * + * Returns a match if the current element is a list and every element of the + * list matches the provided pattern + */ +case class MatchList(pattern: MatchPattern) extends MatchPattern +{ + def apply(node: AST, scope: Scope): Option[Scope] = + node match { + case AList(elements) => + elements.foldLeft( Some(scope): Option[Scope] ){ (scope, element) => + scope.flatMap { pattern(element, _) } + } + case _ => None + } +} diff --git a/astral/src/com/astraldb/spec/Statement.scala b/astral/src/com/astraldb/spec/Statement.scala new file mode 100755 index 0000000..0403a40 --- /dev/null +++ b/astral/src/com/astraldb/spec/Statement.scala @@ -0,0 +1,128 @@ +package com.astraldb.spec + +import scala.language.implicitConversions + +object StatementConversions { + implicit def seq2block(s:Seq[Statement]): Block = Block(s) +} + +sealed abstract class Statement +{ + def disasssembleStatement: Seq[Statement] + def reassembleStatement(in: Seq[Statement]): Statement + def rebuildStatement(fn: Statement => Statement): Statement = + reassembleStatement(disasssembleStatement.map { fn(_) }) + def disasssembleExpression: Seq[Expression] + def reassembleExpression(in: Seq[Expression]): Statement + def rebuildExpression(fn: Expression => Expression): Statement = + reassembleExpression(disasssembleExpression.map { fn(_) }) + .rebuildStatement { _.rebuildExpression(fn) } + + def toString(prefix: String): String + override def toString: String = toString("") + + def blockSeq: Seq[Statement] = Seq(this) + def ++(other: Statement): Block = Block(this.blockSeq ++ other.blockSeq) +} + +case class IfThenElse(condition:Expression, ifTrue:Statement, ifFalse:Statement = Block(Seq())) extends Statement +{ + def disasssembleStatement: Seq[Statement] = Seq(ifTrue, ifFalse) + def reassembleStatement(in: Seq[Statement]): Statement = IfThenElse(condition, in(0), in(1)) + def disasssembleExpression: Seq[Expression] = Seq(condition) + def reassembleExpression(in: Seq[Expression]): Statement = IfThenElse(in(0), ifTrue, ifFalse) + def toString(prefix: String) = s"${prefix}if($condition)\n"+ifTrue.toString(prefix+" ")+s"\n${prefix}else\n"+ifFalse.toString(prefix+" ") +} +case class Declare(name:String, t:Option[Type], v:Expression) extends Statement +{ + def disasssembleStatement: Seq[Statement] = Seq() + def reassembleStatement(in: Seq[Statement]): Statement = this + def disasssembleExpression: Seq[Expression] = Seq(v) + def reassembleExpression(in: Seq[Expression]): Statement = Declare(name, t, in(0)) + def toString(prefix: String) = s"${prefix}var $name${t.map { x => ":"+Type.toString(x)}.getOrElse("")} = $v" +} + +case class Assign(name:String, v:Expression, atomic:Boolean = false) extends Statement +{ + def disasssembleStatement: Seq[Statement] = Seq() + def reassembleStatement(in: Seq[Statement]): Statement = this + def disasssembleExpression: Seq[Expression] = Seq(v) + def reassembleExpression(in: Seq[Expression]): Statement = Assign(name, in(0), atomic) + def toString(prefix: String) = s"${prefix}$name = $v" +} + +case class ExtractNode(name:String, v:Expression, nodeHandlers: Seq[(String, Statement)], onFail: Statement) extends Statement +{ + def disasssembleStatement: Seq[Statement] = Seq(onFail) ++ nodeHandlers.map { _._2 } + def reassembleStatement(in: Seq[Statement]): Statement = ExtractNode(name, v, nodeHandlers.zip(in.tail).map { case ((name, _), handler) => (name, handler) }, in(0)) + def disasssembleExpression: Seq[Expression] = Seq(v) + def reassembleExpression(in: Seq[Expression]): Statement = ExtractNode(name, in(0), nodeHandlers, onFail) + def toString(prefix: String) = s"${prefix}extract $v into $name { \n"+nodeHandlers.map { + case (nodeType, onMatch) => s"${prefix} case $nodeType -> \n${onMatch.toString(prefix+" ")}\n" + }.mkString+s"\n${prefix} else -> \n${onFail.toString(prefix+" ")}\n${prefix}}" +} +case class Block(statements:Seq[Statement]) extends Statement +{ + def disasssembleStatement: Seq[Statement] = statements + def reassembleStatement(in: Seq[Statement]): Statement = Block(in) + def disasssembleExpression: Seq[Expression] = Seq() + def reassembleExpression(in: Seq[Expression]): Statement = this + def toString(prefix: String) = (Seq(prefix+"{")++statements.map{ _.toString(prefix+" ") }++Seq(prefix+"}")).mkString("\n") + override def blockSeq = statements +} +case class ForEach(loopvar:String, over:Expression, body:Statement) extends Statement +{ + def disasssembleStatement: Seq[Statement] = Seq(body) + def reassembleStatement(in: Seq[Statement]): Statement = ForEach(loopvar, over, in(0)) + def disasssembleExpression: Seq[Expression] = Seq(over) + def reassembleExpression(in: Seq[Expression]): Statement = ForEach(loopvar, in(0), body) + def toString(prefix: String) = s"${prefix}for($loopvar in $over)\n${body.toString(prefix+" ")}" +} +case class Void(v:Expression) extends Statement +{ + def disasssembleStatement: Seq[Statement] = Seq() + def reassembleStatement(in: Seq[Statement]): Statement = this + def disasssembleExpression: Seq[Expression] = Seq(v) + def reassembleExpression(in: Seq[Expression]): Statement = Void(in(0)) + def toString(prefix: String) = s"${prefix}void $v" +} +case class Return(v:Expression) extends Statement +{ + def disasssembleStatement: Seq[Statement] = Seq() + def reassembleStatement(in: Seq[Statement]): Statement = this + def disasssembleExpression: Seq[Expression] = Seq(v) + def reassembleExpression(in: Seq[Expression]): Statement = Return(in(0)) + def toString(prefix: String) = s"${prefix}return $v" +} +case class Error(msg:String) extends Statement +{ + def disasssembleStatement: Seq[Statement] = Seq() + def reassembleStatement(in: Seq[Statement]): Statement = this + def disasssembleExpression: Seq[Expression] = Seq() + def reassembleExpression(in: Seq[Expression]): Statement = this + def toString(prefix: String) = prefix+"error: "+msg +} +case class Comment(msg:String) extends Statement +{ + def disasssembleStatement: Seq[Statement] = Seq() + def reassembleStatement(in: Seq[Statement]): Statement = this + def disasssembleExpression: Seq[Expression] = Seq() + def reassembleExpression(in: Seq[Expression]): Statement = this + def toString(prefix: String) = prefix+"rem: "+msg +} +case class SetRemoveFunction(name:String,nodeType:String,v:Expression) extends Statement +{ + def disasssembleStatement: Seq[Statement] = Seq() + def reassembleStatement(in: Seq[Statement]): Statement = this + def disasssembleExpression: Seq[Expression] = Seq() + def reassembleExpression(in: Seq[Expression]): Statement = this + def toString(prefix: String) = prefix+"removing from set: "+name +} +case class SetAddFunction(name:String,nodeType:String,v:Expression) extends Statement +{ + def disasssembleStatement: Seq[Statement] = Seq() + def reassembleStatement(in: Seq[Statement]): Statement = this + def disasssembleExpression: Seq[Expression] = Seq() + def reassembleExpression(in: Seq[Expression]): Statement = this + def toString(prefix: String) = s"${prefix} adding to set: $v" +} \ No newline at end of file diff --git a/astral/src/com/astraldb/spec/Type.scala b/astral/src/com/astraldb/spec/Type.scala new file mode 100755 index 0000000..ca04fb5 --- /dev/null +++ b/astral/src/com/astraldb/spec/Type.scala @@ -0,0 +1,37 @@ +package com.astraldb.spec + +sealed abstract class Type { + def array = TArray(this) +} + +sealed abstract class PrimType extends Type + +case class TKey() extends PrimType { override def toString = "key" } +case class TRecord() extends Type + +case class TInt() extends PrimType +case class TFloat() extends PrimType +case class TBool() extends PrimType +case class TArray(t:Type) extends Type +case class TStruct(fields:Seq[Field]) extends Type +case class TNodeRef() extends Type +case class TNode(t:String) extends Type +case class TIterator() extends Type +case class THandleRef() extends Type +object Type +{ + def toString(t: Type): String = + t match { + case TKey() => "key" + case TRecord() => "record" + case TInt() => "int" + case TFloat() => "float" + case TBool() => "bool" + case TArray(nested) => s"array[${Type.toString(nested)}]" + case TStruct(fields) => s"struct[${fields.map { _.toString }.mkString{", "}}]" + case TNodeRef() => "noderef" + case THandleRef() => "handleref" + case TNode(t) => s"node[$t]" + case TIterator() => "iterator" + } +} \ No newline at end of file diff --git a/astral/src/com/astraldb/typecheck/FunctionSignature.scala b/astral/src/com/astraldb/typecheck/FunctionSignature.scala new file mode 100755 index 0000000..e6960c7 --- /dev/null +++ b/astral/src/com/astraldb/typecheck/FunctionSignature.scala @@ -0,0 +1,33 @@ +package com.astraldb.typecheck + +import com.astraldb.spec._ + +abstract class FunctionSignature() +{ + def apply(args:Seq[Type]): Option[Type] + def name: String +} + +class FunctionArgError(function:FunctionSignature, args:Seq[Type]) extends Exception +{ + override def toString = + function.toString + " <- (" + args.map { Type.toString(_) }.mkString(", ") + ")" +} + +class SimpleFunctionSignature(val name:String, args:Seq[Type], ret:Option[Type]) extends FunctionSignature +{ + def apply(cmpArgs:Seq[Type]) = + if(args == cmpArgs){ ret; } + else { throw new FunctionArgError(this, cmpArgs) } + override def toString: String = + ret.getOrElse("void").toString+" "+name+"("+args.map { Type.toString(_) }.mkString(", ")+")" +} + +object FunctionSignature +{ + def apply(name:String, args:Seq[Type], ret:Type) = + new SimpleFunctionSignature(name, args, Some(ret)) + def apply(name:String, args:Seq[Type]) = + new SimpleFunctionSignature(name, args, None) + +} \ No newline at end of file diff --git a/astral/src/com/astraldb/typecheck/Typechecker.scala b/astral/src/com/astraldb/typecheck/Typechecker.scala new file mode 100755 index 0000000..57bae59 --- /dev/null +++ b/astral/src/com/astraldb/typecheck/Typechecker.scala @@ -0,0 +1,295 @@ +package com.astraldb.typecheck + +import com.astraldb.spec._ + +class TypeError(msg: String, ctx: Expression, scope:Map[String, Type]) extends Exception(msg+" in "+ctx.toString) +{ + override def toString:String = { + val scopeLen = scope.map { case (k, _ ) => k.length }.max + "-----------------------\n"+msg + " in " + ctx + "\n\n"+"---- Current Scope ----\n"+ + scope.map { case (k, t) => " " + k.padTo(scopeLen, " ").mkString + " <- " + t }.mkString("\n")+ + "\n-----------------------" + } + + def rebind(newCtx:Expression) = new TypeError(msg, newCtx, scope) + def rebind(newCtx:Statement) = new StatementError(msg, Seq(newCtx), scope) +} + +class StatementError(msg: String, ctx: Seq[Statement], scope:Map[String, Type]) extends Exception(msg+" in "+ctx(0).toString) +{ + override def toString:String = { + val scopeLen = scope.map { case (k, _ ) => k.length }.max + "-----------------------\n"+msg + " in " + ctx.head + "\n\n"+"---- Current Scope ----\n"+ + scope.map { case (k, t) => " " + k.padTo(scopeLen, " ").mkString + " <- " + t }.mkString("\n")+ + "\n-----------------------"+ + (if(ctx.length > 1) { "\n -- in -- \n"+ctx.tail.map { _.toString }.mkString("\n -- in -- \n")} else { "" }) + } + def trace(stmt: Statement) = new StatementError(msg, ctx :+ stmt, scope) +} + + +class Typechecker(functions: Map[String, FunctionSignature], nodeTypes: Map[String, Node]) { + + def comparisonCompatible(t1:Type, t2:Type): Boolean = + { + + if(t1 == t2) { return true; } + (t1, t2) match { + case (TIterator(), _) => return comparisonCompatible(TRecord(), t2) + case (_, TIterator()) => return comparisonCompatible(t1, TRecord()) + case (TKey(), TRecord()) => return true + case (TRecord(), TKey()) => return true + case _ => return false + } + } + + def typeOf(e: Expression, scope: Map[String, Type]): Type = + { + val error = (msg:String) => throw new TypeError(msg, e, scope) + val recur = (r:Expression) => try { typeOf(r, scope) } catch { case t: TypeError => throw t.rebind(e) } + e match { + case c:Constant => c.t + case ArraySubscript(arr, _) => { + recur(arr) match { + case TArray(nested) => nested + case _ => error("Subscript of Non-Array: "+arr) + } + } + case StructSubscript(struct, subscript) => { + recur(struct) match { + case TStruct(fields) => + fields.find { _.name.equals(subscript) } match { + case Some(field) => field.t + case None => error("Invalid Struct Subscript: "+subscript) + } + case _ => error("Subscript of Non-Struct: "+struct) + } + } + case NodeSubscript(node, subscript) => { + recur(node) match { + case TNode(nodeType) => + nodeTypes(nodeType).fields.find { _.name.equals(subscript) } match { + case Some(field) => field.t + case None => error("Invalid Node Subscript: "+subscript) + } + case _ => error("Subscript of Non-Node: "+node) + } + } + case Cmp(op, a, b) => { + if(comparisonCompatible(recur(a), recur(b))){ + return TBool() + } else { + error("Invalid Comparison") + } + } + case Arith(_, a, b) => { + (recur(a), recur(b)) match { + case (TInt(), TInt()) => return TInt() + case (TInt(), TFloat()) + | (TFloat(), TInt()) + | (TFloat(), TFloat()) => return TFloat() + case _ => error("Invalid Arithmetic") + } + } + case FunctionCall(name, args) => { + try { + functions.getOrElse(name, { + error("Undefined function") + })(args.map { recur }).getOrElse { error("Using return value from void function") } + } catch { + case e:FunctionArgError => error(e.toString) + } + } + case FunctionalIfThenElse(c, t, e) => { + (recur(c), recur(t), recur(e)) match { + case (TBool(), a, b) if a == b => a + case (TBool(), _, _) => error("Incompatible functional then-else clauses") + case (_, _, _) => error("Non-Boolean if-then-else condition") + } + } + case Var(name) => { + scope.get(name) match { + case Some(t) => t + case None => error(s"Variable '$name' not in scope") + } + } + case WrapNode(target) => { + recur(target) match { + case TNode(_) => TNodeRef() + case _ => error("Can't wrap a non-node") + } + } + case UnWrapHandle(target) => { + recur(target) match { + case THandleRef() => TNodeRef() + //case TNodeRef() => TNodeRef() + case _ => error("Can't unwrap a non-handle") + } + } + case WrapNodeRef(target) => { + recur(target) match { + case TNodeRef() => THandleRef() + case _ => error("Can't wrap a non-node-ref") + } + } + case MakeNode(nodeType, fields) => { + for( (field, expr) <- nodeTypes(nodeType).fields.zip(fields)){ + if(recur(expr) != field.t) { + error("Invalid node constructor") + } + } + TNode(nodeType) + } + case NodeCast(nodeType,node, subscript) => { + recur(node) match { + case TNode(nodeType) => + nodeTypes(nodeType).fields.find { _.name.equals(subscript) } match { + case Some(field) => field.t + case None => error("Invalid Node Subscript: "+subscript) + } + case _ => error("Subscript of Non-Node: "+node) + } + + } + } + } + + def check(stmt: Statement, scope: Map[String, Type], returnType:Option[Type]): Map[String, Type] = + { + val exprType = (e:Expression) => try { typeOf(e, scope) } catch { case e:TypeError => throw e.rebind(stmt) } + val error = (msg:String) => throw new StatementError(msg, Seq(stmt), scope) + val recur = (rstmt: Statement, rscope: Map[String, Type]) => try { check(rstmt, rscope, returnType) } catch { case e:StatementError => throw e.trace(stmt) } + stmt match { + case Block(elems) => { + elems.foldLeft(scope) { (currScope, currStmt) => + recur(currStmt, currScope) + } + scope + } + case Assign(tgt, expr, false) => { + val tgtType = scope.getOrElse(tgt, { error("Assignment to undefined variable: "+tgt) }) + + if(exprType(expr) != tgtType){ + error("Assignment to "+tgt+" of incorrect type") + } + if(tgtType == TNodeRef()) + { + error("Non-atomic assignment to a NodeRef") + } + if(tgtType == THandleRef()){ + error("Non-atomic assignment to a NodeHandle") + } + scope + } + + case Assign(tgt,expr,true) => { + val tgtType = scope.getOrElse(tgt, { error("Assignment to undefined variable: "+tgt) }) + if(tgtType != THandleRef() || exprType(expr)!= TNodeRef()) + { + error("Atomic assignment into wrong types ") + } + scope + } + + case Declare(tgt, tOption, expr) => { + if(scope contains tgt) { + error("Overriding existing variable") + } + val tRet = tOption match { + case Some(t) => if(t != exprType(expr)){ + error("Declaring expression of incorrect type") + } else { t } + case None => exprType(expr) + } + scope + (tgt -> tRet) + } + + case ExtractNode(name, expr, matchers, onFail) => { + if(scope contains name) { + error("Overriding existing variable") + } + val exp = exprType(expr) + if(exp != THandleRef()) { + error("Doesn't evaluate to an (extractable) node/Handle reference") + } + for( (nodeType, handler) <- matchers ){ + if(!(nodeTypes contains nodeType)) { + error(s"Invalid Node Type: '$nodeType'") + } + recur(handler, scope + (name -> TNode(nodeType))) + } + recur(onFail, scope) + scope + } + + case Return(expr) => { + if(exprType(expr) != returnType.getOrElse { + error(s"Invalid Return Type (Found: ${exprType(expr)}; Void Function)") + }) { + error(s"Invalid Return Type (Found: ${exprType(expr)}; Expected: $returnType)") + } + scope + } + case Void(expr @ FunctionCall(name, args)) => { + try { + functions.getOrElse(name, { + error("Undefined function") + })(args.map { exprType(_) }) + } catch { + case e:FunctionArgError => error(e.toString) + } + scope + } + case Void(expr) => { + exprType(expr) + scope + } + case ForEach(loopvar, expr, body) => { + if(scope contains loopvar) { + error("Overriding existing variable") + } + exprType(expr) match { + case TArray(nested) => + recur(body, scope + (loopvar -> nested)) + case _ => + error("Invalid loop target") + } + scope + } + case IfThenElse(c, t, e) => { + if(exprType(c) != TBool()){ + error("Invalid if-then-else condition") + } + recur(t, scope) + recur(e, scope) + scope + } + case Error(_) => scope + case Comment(_) => scope + case SetRemoveFunction(_,_,_) => scope + case SetAddFunction(_,_,_) => scope + + } + } + + def check(globals: Map[String,Type])(fn: FunctionDefinition): FunctionDefinition = + { + check(fn.body, globals ++ fn.args.map { case (name, t, _) => name -> t }.toMap, fn.ret) + return fn + } + def check(globals: (String,Type)*)(fn: FunctionDefinition): FunctionDefinition = + check(globals.toMap)(fn) + + def check(fn: FunctionDefinition): FunctionDefinition = + { + check(fn.body, fn.args.map { case (name, t, _) => name -> t }.toMap, fn.ret) + return fn + } + + def withFunctions(newFuncs: Map[String, FunctionSignature]): Typechecker = + new Typechecker(functions ++ newFuncs, nodeTypes) + + def withFunctions(newFuncs: (String, FunctionSignature)*): Typechecker = + withFunctions(newFuncs.toMap) + +} \ No newline at end of file diff --git a/build.sc b/build.sc new file mode 100644 index 0000000..2b7095c --- /dev/null +++ b/build.sc @@ -0,0 +1,30 @@ +import mill._ +import mill.scalalib._ +import mill.scalalib.publish._ + +object astral extends ScalaModule with PublishModule { + val VERSION = "0.0.1-SNAPSHOT" + + def scalaVersion = "3.2.1" + + def mainClass = Some("com.astraldb.Astral") + + /************************************************* + *** Backend Dependencies + *************************************************/ + // def ivyDeps = Agg( + // ) + + def publishVersion = VERSION + override def pomSettings = PomSettings( + description = "The Astral Compiler", + organization = "com.astraldb", + url = "http://astraldb.com", + licenses = Seq(License.`Apache-2.0`), + versionControl = VersionControl.github("UBOdin", "astral"), + developers = Seq( + Developer("okennedy", "Oliver Kennedy", "https://odin.cse.buffalo.edu"), + ) + ) +} +