Initial stab, borrowing stuff from the JITD compiler

This commit is contained in:
Oliver Kennedy 2023-07-02 17:53:56 -04:00
commit 16d51b71a5
Signed by: okennedy
GPG key ID: 3E5F9B3ABD3FDB60
14 changed files with 1038 additions and 0 deletions

3
.gitignore vendored Normal file
View file

@ -0,0 +1,3 @@
.bloop
.metals
/out

View file

@ -0,0 +1,9 @@
package com.astraldb
object Astral
{
def main(args: Array[String]): Unit =
{
println("Hello!")
}
}

View file

@ -0,0 +1,32 @@
package com.astraldb.ast
import com.astraldb.spec.Constant
sealed trait AST
{
def children: Seq[AST]
def find[T](f: AST => Option[T]): Option[T] =
{
f(this).orElse {
children.foldLeft(None:Option[T]){ (ret, child) =>
if(ret.isEmpty) { child.find(f) }
else { ret }
}
}
}
}
case class AList(elements: Seq[AST]) extends AST
{
def children: Seq[AST] = elements
}
case class ANode(label: String, elements: Seq[AST]) extends AST
{
def children: Seq[AST] = elements
}
case class Leaf(value: Constant) extends AST
{
def children: Seq[AST] = Seq.empty
}

View file

@ -0,0 +1,10 @@
package com.astraldb.spec;
import com.astraldb.typecheck._
case class Definition(
nodes:Seq[Node],
// rules:Seq[Rule]
) {
}

View file

@ -0,0 +1,178 @@
package com.astraldb.spec
object CmpTypes extends Enumeration {
type T = Value
val Eq, Neq, Lt, Lte, Gt, Gte = Value
def opString(op:T):String =
{
op match {
case Eq => "=="
case Neq => "!="
case Lt => "<"
case Lte => "<="
case Gt => ">"
case Gte => ">="
}
}
}
object ArithTypes extends Enumeration {
type T = Value
val Add, Sub, Mul, Div, And, Or = Value
def opString(op:T):String =
{
op match {
case Add => "+"
case Sub => "-"
case Mul => "*"
case Div => "/"
case And => "&&"
case Or => "||"
}
}
}
sealed abstract class Expression
{
def eq(other:Expression) = Cmp(CmpTypes.Eq, this, other)
def neq(other:Expression) = Cmp(CmpTypes.Neq, this, other)
def lt(other:Expression) = Cmp(CmpTypes.Lt, this, other)
def lte(other:Expression) = Cmp(CmpTypes.Lte, this, other)
def gt(other:Expression) = Cmp(CmpTypes.Gt, this, other)
def gte(other:Expression) = Cmp(CmpTypes.Gte, this, other)
def and(other:Expression) =
{
(this, other) match {
case (BoolConstant(true), _) => other
case (BoolConstant(false), _) => this
case (_, BoolConstant(true)) => this
case (_, BoolConstant(false)) => other
case _ => Arith(ArithTypes.And, this, other)
}
}
def or(other:Expression) =
{
(this, other) match {
case (BoolConstant(false), _) => other
case (BoolConstant(true), _) => this
case (_, BoolConstant(false)) => this
case (_, BoolConstant(true)) => other
case _ => Arith(ArithTypes.Or, this, other)
}
}
def plus(other:Expression) = Arith(ArithTypes.Add, this, other)
def minus(other:Expression) = Arith(ArithTypes.Sub, this, other)
def times(other:Expression) = Arith(ArithTypes.Mul, this, other)
def dividedBy(other:Expression) = Arith(ArithTypes.Div, this, other)
def get(field:String) = StructSubscript(this, field)
def get(index:Int) = ArraySubscript(this, index)
def disassemble: Seq[Expression]
def reassemble(in: Seq[Expression]): Expression
def rebuild(fn:(Expression => Expression)): Expression =
reassemble(disassemble.map { fn(_) })
}
sealed abstract class Constant(val t:Type) extends Expression
{
def disassemble = Seq[Expression]()
def reassemble(in: Seq[Expression]): Expression = this
}
case class IntConstant(i:Integer) extends Constant(TInt())
{
override def toString:String = i.toString
}
case class FloatConstant(f:Double) extends Constant(TFloat())
{
override def toString:String = f.toString
}
case class BoolConstant(b:Boolean) extends Constant(TBool())
{
override def toString:String = if(b){ "true" } else { "false" }
}
case class Var(v:String) extends Expression
{
def disassemble = Seq[Expression]()
def reassemble(in: Seq[Expression]): Expression = this
override def toString = v.toString
}
case class Cmp(t: CmpTypes.T, lhs:Expression, rhs:Expression) extends Expression
{
def disassemble = Seq[Expression](lhs, rhs)
def reassemble(in: Seq[Expression]): Expression = Cmp(t, in(0), in(1))
override def toString = s"($lhs) ${CmpTypes.opString(t)} ($rhs)"
}
case class Arith(t: ArithTypes.T, lhs:Expression, rhs:Expression) extends Expression
{
def disassemble = Seq[Expression](lhs, rhs)
def reassemble(in: Seq[Expression]): Expression = Arith(t, in(0), in(1))
override def toString = s"($lhs) ${ArithTypes.opString(t)} ($rhs)"
}
case class FunctionalIfThenElse(condition:Expression, ifTrue:Expression, ifFalse:Expression) extends Expression
{
def disassemble = Seq[Expression](condition, ifTrue, ifFalse)
def reassemble(in: Seq[Expression]): Expression = FunctionalIfThenElse(in(0), in(1), in(2))
override def toString = s"($condition) ? ($ifTrue) : ($ifFalse)"
}
case class ArraySubscript(target:Expression, index:Integer) extends Expression
{
def disassemble = Seq[Expression](target)
def reassemble(in: Seq[Expression]): Expression = ArraySubscript(in(0), index)
override def toString = s"$target[$index]"
}
case class StructSubscript(target:Expression, field:String) extends Expression
{
def disassemble = Seq[Expression](target)
def reassemble(in: Seq[Expression]): Expression = StructSubscript(in(0), field)
override def toString = s"$target.$field"
}
case class NodeSubscript(target:Expression, field:String) extends Expression
{
def disassemble = Seq[Expression](target)
def reassemble(in: Seq[Expression]): Expression = NodeSubscript(in(0), field)
override def toString = s"$target->$field"
}
case class NodeCast(nodeType: String,target:Expression, field:String) extends Expression
{
def disassemble = Seq[Expression](target)
def reassemble(in: Seq[Expression]): Expression = NodeSubscript(in(0), field)
override def toString = s"($nodeType *)$target->$field"
}
case class FunctionCall(name:String, args:Seq[Expression]) extends Expression
{
def disassemble = args
def reassemble(in: Seq[Expression]): Expression = FunctionCall(name, in)
override def toString = s"$name(${args.mkString(", ")})"
}
case class WrapNode(target: Expression) extends Expression
{
def disassemble = Seq(target)
def reassemble(in: Seq[Expression]): Expression = WrapNode(in(0))
override def toString = s"wrap ${target.toString}"
}
case class UnWrapHandle(target: Expression) extends Expression
{
def disassemble = Seq(target)
def reassemble(in: Seq[Expression]): Expression = UnWrapHandle(in(0))
override def toString = s"unwraphandleref ${target.toString}"
}
case class WrapNodeRef(target: Expression) extends Expression
{
def disassemble = Seq(target)
def reassemble(in: Seq[Expression]): Expression = WrapNodeRef(in(0))
override def toString = s"wrapnoderef ${target.toString}"
}
case class MakeNode(nodeType: String, fields: Seq[Expression]) extends Expression
{
def disassemble = fields
def reassemble(in: Seq[Expression]): Expression = MakeNode(nodeType, in)
override def toString = s"allocate ${nodeType}(${fields.mkString(",")})"
}

View file

@ -0,0 +1,12 @@
package com.astraldb.spec
import scala.language.implicitConversions
case class Field(name:String, t:Type)
{
override def toString = s"$name:${Type.toString(t)}"
}
object FieldConversions {
implicit def tuple2Field(t:(String, Type)): Field = Field(t._1, t._2)
}

View file

@ -0,0 +1,36 @@
package com.astraldb.spec
import com.astraldb.typecheck.FunctionSignature
object FunctionArgType extends Enumeration
{
type T = Value
val Input, OutputRef, ConstInputRef = Value
def isConst(t: T): Boolean =
t match {
case ConstInputRef => true
case Input | OutputRef => false
}
def isByRef(t: T): Boolean =
t match {
case ConstInputRef | OutputRef => true
case Input => false
}
}
case class FunctionDefinition(
name: String,
ret: Option[Type],
args: Seq[(String, Type, FunctionArgType.T)],
body: Statement
)
{
def signature =
ret match {
case Some(t) =>
FunctionSignature(name, args.map { _._2 }, t)
case None =>
FunctionSignature(name, args.map { _._2 })
}
}

View file

@ -0,0 +1,12 @@
package com.astraldb.spec;
case class Node(val name:String, val fields:Seq[Field])
{
def renderName = name+"Node"
def enumName = "JITD_NODE_"+name
def scope = fields.map { f => f.name -> f.t }.toMap
override def toString =
name + "(" + fields.map { _.toString }.mkString(", ") + ")"
}

View file

@ -0,0 +1,223 @@
package com.astraldb.spec
import com.astraldb.ast._
import com.astraldb.spec.MatchPattern.Scope
/**
* A rule for matching patterns
*/
sealed trait MatchPattern
{
/**
* Apply the match pattern to a specified abstract syntax tree
*
* Return the updated scope, or None if the pattern does not match
*/
def apply(node: AST, scope: Scope): Option[Scope]
}
object MatchPattern
{
type Scope = Map[String, AST]
}
/**
* Boolean logic over match patterns
*/
sealed trait BooleanMatchPattern extends MatchPattern
/**
* Boolean not.
*/
case class MatchNot(child: MatchPattern) extends BooleanMatchPattern
{
def apply(node: AST, scope: Scope): Option[Scope] =
child(node, scope) match {
case None => Some(Map.empty)
case Some(_) => None
}
}
/**
* Boolean and
*/
case class MatchAnd(children: Seq[MatchPattern]) extends BooleanMatchPattern
{
def apply(node: AST, scope: Scope): Option[Scope] =
children.foldLeft(Some(scope):Option[Scope]){ (scope, child) =>
scope.flatMap { child(node, _) }
}
}
/**
* Boolean or
*/
case class MatchOr(children: Seq[MatchPattern]) extends BooleanMatchPattern
{
def apply(node: AST, scope: Scope): Option[Scope] =
{
children.foldLeft(None: Option[Scope]) { (retScope, child) =>
if(retScope.isDefined) { retScope }
else { child(node, scope) }
}
}
}
/**
* Match a node type.
*
* Return a match if the current AST element is a node of the labeled type,
* and if all child matchers match as well.
*/
case class MatchNode(nodeLabel: String, children: Seq[MatchPattern]) extends MatchPattern
{
def apply(node: AST, scope: Scope): Option[Scope] =
node match {
case ANode(label, elements) if label == nodeLabel
&& elements.size == children.size =>
children.zip(elements).foldLeft(Some(scope):Option[Scope]) { case (scope, (child, element)) =>
scope.flatMap { child(element, _) }
}
case _ => None
}
}
/**
* Match a node, solely based on its type
*
* Return a match if the current AST element is a node of the labeled type.
* Ignore the children.
*/
case class MatchNodeType(nodeLabel: String) extends MatchPattern
{
def apply(node: AST, scope: Scope): Option[Scope] =
node match {
case ANode(label, elements) if label == nodeLabel => Some(scope)
case _ => None
}
}
/**
* Match a node exactly.
*
* Return a match if the current AST element is identical to the AST provided
* in the pattern
*/
case class MatchExact(pattern: AST) extends MatchPattern
{
def apply(node: AST, scope: Scope): Option[Scope] =
if(node == pattern){ Some(scope) }
else { None }
}
/**
* Match a descendant identified by a provided path of indices
*
* The provided path specifies descendants by their positions. The 0th
* index is the first child of a Node or a List AST. The matcher returns
* a match if the specified path exists, and if the provided pattern matches
* the element at the specified position.
*/
case class MatchPath(path: Seq[Int], pattern: MatchPattern) extends MatchPattern
{
def apply(node: AST, scope: Scope): Option[Scope] =
path.foldLeft( Some(node):Option[AST] ) { (target, idx) =>
target match {
case Some(ANode(_, elements)) if elements.size > idx => Some(elements(idx))
case Some(AList(elements)) if elements.size > idx => Some(elements(idx))
case _ => None
}
} match {
case None => None
case Some(target) => pattern(target, scope)
}
}
/**
* Match some descendent.
*
* The pattern is applied to all descendents of the current node. The matcher
* returns the first match it finds.
*/
case class MatchRecursive(pattern: MatchPattern) extends MatchPattern
{
def apply(node: AST, scope: Scope): Option[Scope] =
node.find { pattern(_, scope) }
}
/**
* Bind the current node to the scope
*
* The current node is bound to the scope and the provided pattern is applied
* as normal.
*/
case class MatchBind(symbol: String, pattern: MatchPattern) extends MatchPattern
{
def apply(node: AST, scope: Scope): Option[Scope] =
pattern(node, scope ++ Map(symbol -> node))
}
/**
* Always return a match
*/
case object MatchAny extends MatchPattern
{
def apply(node: AST, scope: Scope): Option[Scope] =
Some(scope)
}
/**
* Never return a match
*/
case object MatchNone extends MatchPattern
{
def apply(node: AST, scope: Scope): Option[Scope] =
None
}
/**
* Match the current node against an element of the scope.
*
* Returns a match if the provided symbol exists in the scope and the current
* node is exactly equal to the symbol's value in the scope.
*/
case class MatchLookup(symbol: String) extends MatchPattern
{
def apply(node: AST, scope: Scope): Option[Scope] =
scope.get(symbol).flatMap { scopeNode =>
if(node == scopeNode){ Some(scope) } else { None }
}
}
/**
* Apply a matcher to an element of the scope
*
* Returns a match if the provided symbol exists in the scope and the provided
* pattern exactly matches the symbol's value in the scope.
*/
case class MatchScope(symbol: String, pattern: MatchPattern) extends MatchPattern
{
def apply(node: AST, scope: Scope): Option[Scope] =
scope.get(symbol).flatMap { scopeNode =>
pattern(scopeNode, scope)
}
}
/**
* Apply a matcher to every element of a list
*
* Returns a match if the current element is a list and every element of the
* list matches the provided pattern
*/
case class MatchList(pattern: MatchPattern) extends MatchPattern
{
def apply(node: AST, scope: Scope): Option[Scope] =
node match {
case AList(elements) =>
elements.foldLeft( Some(scope): Option[Scope] ){ (scope, element) =>
scope.flatMap { pattern(element, _) }
}
case _ => None
}
}

View file

@ -0,0 +1,128 @@
package com.astraldb.spec
import scala.language.implicitConversions
object StatementConversions {
implicit def seq2block(s:Seq[Statement]): Block = Block(s)
}
sealed abstract class Statement
{
def disasssembleStatement: Seq[Statement]
def reassembleStatement(in: Seq[Statement]): Statement
def rebuildStatement(fn: Statement => Statement): Statement =
reassembleStatement(disasssembleStatement.map { fn(_) })
def disasssembleExpression: Seq[Expression]
def reassembleExpression(in: Seq[Expression]): Statement
def rebuildExpression(fn: Expression => Expression): Statement =
reassembleExpression(disasssembleExpression.map { fn(_) })
.rebuildStatement { _.rebuildExpression(fn) }
def toString(prefix: String): String
override def toString: String = toString("")
def blockSeq: Seq[Statement] = Seq(this)
def ++(other: Statement): Block = Block(this.blockSeq ++ other.blockSeq)
}
case class IfThenElse(condition:Expression, ifTrue:Statement, ifFalse:Statement = Block(Seq())) extends Statement
{
def disasssembleStatement: Seq[Statement] = Seq(ifTrue, ifFalse)
def reassembleStatement(in: Seq[Statement]): Statement = IfThenElse(condition, in(0), in(1))
def disasssembleExpression: Seq[Expression] = Seq(condition)
def reassembleExpression(in: Seq[Expression]): Statement = IfThenElse(in(0), ifTrue, ifFalse)
def toString(prefix: String) = s"${prefix}if($condition)\n"+ifTrue.toString(prefix+" ")+s"\n${prefix}else\n"+ifFalse.toString(prefix+" ")
}
case class Declare(name:String, t:Option[Type], v:Expression) extends Statement
{
def disasssembleStatement: Seq[Statement] = Seq()
def reassembleStatement(in: Seq[Statement]): Statement = this
def disasssembleExpression: Seq[Expression] = Seq(v)
def reassembleExpression(in: Seq[Expression]): Statement = Declare(name, t, in(0))
def toString(prefix: String) = s"${prefix}var $name${t.map { x => ":"+Type.toString(x)}.getOrElse("")} = $v"
}
case class Assign(name:String, v:Expression, atomic:Boolean = false) extends Statement
{
def disasssembleStatement: Seq[Statement] = Seq()
def reassembleStatement(in: Seq[Statement]): Statement = this
def disasssembleExpression: Seq[Expression] = Seq(v)
def reassembleExpression(in: Seq[Expression]): Statement = Assign(name, in(0), atomic)
def toString(prefix: String) = s"${prefix}$name = $v"
}
case class ExtractNode(name:String, v:Expression, nodeHandlers: Seq[(String, Statement)], onFail: Statement) extends Statement
{
def disasssembleStatement: Seq[Statement] = Seq(onFail) ++ nodeHandlers.map { _._2 }
def reassembleStatement(in: Seq[Statement]): Statement = ExtractNode(name, v, nodeHandlers.zip(in.tail).map { case ((name, _), handler) => (name, handler) }, in(0))
def disasssembleExpression: Seq[Expression] = Seq(v)
def reassembleExpression(in: Seq[Expression]): Statement = ExtractNode(name, in(0), nodeHandlers, onFail)
def toString(prefix: String) = s"${prefix}extract $v into $name { \n"+nodeHandlers.map {
case (nodeType, onMatch) => s"${prefix} case $nodeType -> \n${onMatch.toString(prefix+" ")}\n"
}.mkString+s"\n${prefix} else -> \n${onFail.toString(prefix+" ")}\n${prefix}}"
}
case class Block(statements:Seq[Statement]) extends Statement
{
def disasssembleStatement: Seq[Statement] = statements
def reassembleStatement(in: Seq[Statement]): Statement = Block(in)
def disasssembleExpression: Seq[Expression] = Seq()
def reassembleExpression(in: Seq[Expression]): Statement = this
def toString(prefix: String) = (Seq(prefix+"{")++statements.map{ _.toString(prefix+" ") }++Seq(prefix+"}")).mkString("\n")
override def blockSeq = statements
}
case class ForEach(loopvar:String, over:Expression, body:Statement) extends Statement
{
def disasssembleStatement: Seq[Statement] = Seq(body)
def reassembleStatement(in: Seq[Statement]): Statement = ForEach(loopvar, over, in(0))
def disasssembleExpression: Seq[Expression] = Seq(over)
def reassembleExpression(in: Seq[Expression]): Statement = ForEach(loopvar, in(0), body)
def toString(prefix: String) = s"${prefix}for($loopvar in $over)\n${body.toString(prefix+" ")}"
}
case class Void(v:Expression) extends Statement
{
def disasssembleStatement: Seq[Statement] = Seq()
def reassembleStatement(in: Seq[Statement]): Statement = this
def disasssembleExpression: Seq[Expression] = Seq(v)
def reassembleExpression(in: Seq[Expression]): Statement = Void(in(0))
def toString(prefix: String) = s"${prefix}void $v"
}
case class Return(v:Expression) extends Statement
{
def disasssembleStatement: Seq[Statement] = Seq()
def reassembleStatement(in: Seq[Statement]): Statement = this
def disasssembleExpression: Seq[Expression] = Seq(v)
def reassembleExpression(in: Seq[Expression]): Statement = Return(in(0))
def toString(prefix: String) = s"${prefix}return $v"
}
case class Error(msg:String) extends Statement
{
def disasssembleStatement: Seq[Statement] = Seq()
def reassembleStatement(in: Seq[Statement]): Statement = this
def disasssembleExpression: Seq[Expression] = Seq()
def reassembleExpression(in: Seq[Expression]): Statement = this
def toString(prefix: String) = prefix+"error: "+msg
}
case class Comment(msg:String) extends Statement
{
def disasssembleStatement: Seq[Statement] = Seq()
def reassembleStatement(in: Seq[Statement]): Statement = this
def disasssembleExpression: Seq[Expression] = Seq()
def reassembleExpression(in: Seq[Expression]): Statement = this
def toString(prefix: String) = prefix+"rem: "+msg
}
case class SetRemoveFunction(name:String,nodeType:String,v:Expression) extends Statement
{
def disasssembleStatement: Seq[Statement] = Seq()
def reassembleStatement(in: Seq[Statement]): Statement = this
def disasssembleExpression: Seq[Expression] = Seq()
def reassembleExpression(in: Seq[Expression]): Statement = this
def toString(prefix: String) = prefix+"removing from set: "+name
}
case class SetAddFunction(name:String,nodeType:String,v:Expression) extends Statement
{
def disasssembleStatement: Seq[Statement] = Seq()
def reassembleStatement(in: Seq[Statement]): Statement = this
def disasssembleExpression: Seq[Expression] = Seq()
def reassembleExpression(in: Seq[Expression]): Statement = this
def toString(prefix: String) = s"${prefix} adding to set: $v"
}

View file

@ -0,0 +1,37 @@
package com.astraldb.spec
sealed abstract class Type {
def array = TArray(this)
}
sealed abstract class PrimType extends Type
case class TKey() extends PrimType { override def toString = "key" }
case class TRecord() extends Type
case class TInt() extends PrimType
case class TFloat() extends PrimType
case class TBool() extends PrimType
case class TArray(t:Type) extends Type
case class TStruct(fields:Seq[Field]) extends Type
case class TNodeRef() extends Type
case class TNode(t:String) extends Type
case class TIterator() extends Type
case class THandleRef() extends Type
object Type
{
def toString(t: Type): String =
t match {
case TKey() => "key"
case TRecord() => "record"
case TInt() => "int"
case TFloat() => "float"
case TBool() => "bool"
case TArray(nested) => s"array[${Type.toString(nested)}]"
case TStruct(fields) => s"struct[${fields.map { _.toString }.mkString{", "}}]"
case TNodeRef() => "noderef"
case THandleRef() => "handleref"
case TNode(t) => s"node[$t]"
case TIterator() => "iterator"
}
}

View file

@ -0,0 +1,33 @@
package com.astraldb.typecheck
import com.astraldb.spec._
abstract class FunctionSignature()
{
def apply(args:Seq[Type]): Option[Type]
def name: String
}
class FunctionArgError(function:FunctionSignature, args:Seq[Type]) extends Exception
{
override def toString =
function.toString + " <- (" + args.map { Type.toString(_) }.mkString(", ") + ")"
}
class SimpleFunctionSignature(val name:String, args:Seq[Type], ret:Option[Type]) extends FunctionSignature
{
def apply(cmpArgs:Seq[Type]) =
if(args == cmpArgs){ ret; }
else { throw new FunctionArgError(this, cmpArgs) }
override def toString: String =
ret.getOrElse("void").toString+" "+name+"("+args.map { Type.toString(_) }.mkString(", ")+")"
}
object FunctionSignature
{
def apply(name:String, args:Seq[Type], ret:Type) =
new SimpleFunctionSignature(name, args, Some(ret))
def apply(name:String, args:Seq[Type]) =
new SimpleFunctionSignature(name, args, None)
}

View file

@ -0,0 +1,295 @@
package com.astraldb.typecheck
import com.astraldb.spec._
class TypeError(msg: String, ctx: Expression, scope:Map[String, Type]) extends Exception(msg+" in "+ctx.toString)
{
override def toString:String = {
val scopeLen = scope.map { case (k, _ ) => k.length }.max
"-----------------------\n"+msg + " in " + ctx + "\n\n"+"---- Current Scope ----\n"+
scope.map { case (k, t) => " " + k.padTo(scopeLen, " ").mkString + " <- " + t }.mkString("\n")+
"\n-----------------------"
}
def rebind(newCtx:Expression) = new TypeError(msg, newCtx, scope)
def rebind(newCtx:Statement) = new StatementError(msg, Seq(newCtx), scope)
}
class StatementError(msg: String, ctx: Seq[Statement], scope:Map[String, Type]) extends Exception(msg+" in "+ctx(0).toString)
{
override def toString:String = {
val scopeLen = scope.map { case (k, _ ) => k.length }.max
"-----------------------\n"+msg + " in " + ctx.head + "\n\n"+"---- Current Scope ----\n"+
scope.map { case (k, t) => " " + k.padTo(scopeLen, " ").mkString + " <- " + t }.mkString("\n")+
"\n-----------------------"+
(if(ctx.length > 1) { "\n -- in -- \n"+ctx.tail.map { _.toString }.mkString("\n -- in -- \n")} else { "" })
}
def trace(stmt: Statement) = new StatementError(msg, ctx :+ stmt, scope)
}
class Typechecker(functions: Map[String, FunctionSignature], nodeTypes: Map[String, Node]) {
def comparisonCompatible(t1:Type, t2:Type): Boolean =
{
if(t1 == t2) { return true; }
(t1, t2) match {
case (TIterator(), _) => return comparisonCompatible(TRecord(), t2)
case (_, TIterator()) => return comparisonCompatible(t1, TRecord())
case (TKey(), TRecord()) => return true
case (TRecord(), TKey()) => return true
case _ => return false
}
}
def typeOf(e: Expression, scope: Map[String, Type]): Type =
{
val error = (msg:String) => throw new TypeError(msg, e, scope)
val recur = (r:Expression) => try { typeOf(r, scope) } catch { case t: TypeError => throw t.rebind(e) }
e match {
case c:Constant => c.t
case ArraySubscript(arr, _) => {
recur(arr) match {
case TArray(nested) => nested
case _ => error("Subscript of Non-Array: "+arr)
}
}
case StructSubscript(struct, subscript) => {
recur(struct) match {
case TStruct(fields) =>
fields.find { _.name.equals(subscript) } match {
case Some(field) => field.t
case None => error("Invalid Struct Subscript: "+subscript)
}
case _ => error("Subscript of Non-Struct: "+struct)
}
}
case NodeSubscript(node, subscript) => {
recur(node) match {
case TNode(nodeType) =>
nodeTypes(nodeType).fields.find { _.name.equals(subscript) } match {
case Some(field) => field.t
case None => error("Invalid Node Subscript: "+subscript)
}
case _ => error("Subscript of Non-Node: "+node)
}
}
case Cmp(op, a, b) => {
if(comparisonCompatible(recur(a), recur(b))){
return TBool()
} else {
error("Invalid Comparison")
}
}
case Arith(_, a, b) => {
(recur(a), recur(b)) match {
case (TInt(), TInt()) => return TInt()
case (TInt(), TFloat())
| (TFloat(), TInt())
| (TFloat(), TFloat()) => return TFloat()
case _ => error("Invalid Arithmetic")
}
}
case FunctionCall(name, args) => {
try {
functions.getOrElse(name, {
error("Undefined function")
})(args.map { recur }).getOrElse { error("Using return value from void function") }
} catch {
case e:FunctionArgError => error(e.toString)
}
}
case FunctionalIfThenElse(c, t, e) => {
(recur(c), recur(t), recur(e)) match {
case (TBool(), a, b) if a == b => a
case (TBool(), _, _) => error("Incompatible functional then-else clauses")
case (_, _, _) => error("Non-Boolean if-then-else condition")
}
}
case Var(name) => {
scope.get(name) match {
case Some(t) => t
case None => error(s"Variable '$name' not in scope")
}
}
case WrapNode(target) => {
recur(target) match {
case TNode(_) => TNodeRef()
case _ => error("Can't wrap a non-node")
}
}
case UnWrapHandle(target) => {
recur(target) match {
case THandleRef() => TNodeRef()
//case TNodeRef() => TNodeRef()
case _ => error("Can't unwrap a non-handle")
}
}
case WrapNodeRef(target) => {
recur(target) match {
case TNodeRef() => THandleRef()
case _ => error("Can't wrap a non-node-ref")
}
}
case MakeNode(nodeType, fields) => {
for( (field, expr) <- nodeTypes(nodeType).fields.zip(fields)){
if(recur(expr) != field.t) {
error("Invalid node constructor")
}
}
TNode(nodeType)
}
case NodeCast(nodeType,node, subscript) => {
recur(node) match {
case TNode(nodeType) =>
nodeTypes(nodeType).fields.find { _.name.equals(subscript) } match {
case Some(field) => field.t
case None => error("Invalid Node Subscript: "+subscript)
}
case _ => error("Subscript of Non-Node: "+node)
}
}
}
}
def check(stmt: Statement, scope: Map[String, Type], returnType:Option[Type]): Map[String, Type] =
{
val exprType = (e:Expression) => try { typeOf(e, scope) } catch { case e:TypeError => throw e.rebind(stmt) }
val error = (msg:String) => throw new StatementError(msg, Seq(stmt), scope)
val recur = (rstmt: Statement, rscope: Map[String, Type]) => try { check(rstmt, rscope, returnType) } catch { case e:StatementError => throw e.trace(stmt) }
stmt match {
case Block(elems) => {
elems.foldLeft(scope) { (currScope, currStmt) =>
recur(currStmt, currScope)
}
scope
}
case Assign(tgt, expr, false) => {
val tgtType = scope.getOrElse(tgt, { error("Assignment to undefined variable: "+tgt) })
if(exprType(expr) != tgtType){
error("Assignment to "+tgt+" of incorrect type")
}
if(tgtType == TNodeRef())
{
error("Non-atomic assignment to a NodeRef")
}
if(tgtType == THandleRef()){
error("Non-atomic assignment to a NodeHandle")
}
scope
}
case Assign(tgt,expr,true) => {
val tgtType = scope.getOrElse(tgt, { error("Assignment to undefined variable: "+tgt) })
if(tgtType != THandleRef() || exprType(expr)!= TNodeRef())
{
error("Atomic assignment into wrong types ")
}
scope
}
case Declare(tgt, tOption, expr) => {
if(scope contains tgt) {
error("Overriding existing variable")
}
val tRet = tOption match {
case Some(t) => if(t != exprType(expr)){
error("Declaring expression of incorrect type")
} else { t }
case None => exprType(expr)
}
scope + (tgt -> tRet)
}
case ExtractNode(name, expr, matchers, onFail) => {
if(scope contains name) {
error("Overriding existing variable")
}
val exp = exprType(expr)
if(exp != THandleRef()) {
error("Doesn't evaluate to an (extractable) node/Handle reference")
}
for( (nodeType, handler) <- matchers ){
if(!(nodeTypes contains nodeType)) {
error(s"Invalid Node Type: '$nodeType'")
}
recur(handler, scope + (name -> TNode(nodeType)))
}
recur(onFail, scope)
scope
}
case Return(expr) => {
if(exprType(expr) != returnType.getOrElse {
error(s"Invalid Return Type (Found: ${exprType(expr)}; Void Function)")
}) {
error(s"Invalid Return Type (Found: ${exprType(expr)}; Expected: $returnType)")
}
scope
}
case Void(expr @ FunctionCall(name, args)) => {
try {
functions.getOrElse(name, {
error("Undefined function")
})(args.map { exprType(_) })
} catch {
case e:FunctionArgError => error(e.toString)
}
scope
}
case Void(expr) => {
exprType(expr)
scope
}
case ForEach(loopvar, expr, body) => {
if(scope contains loopvar) {
error("Overriding existing variable")
}
exprType(expr) match {
case TArray(nested) =>
recur(body, scope + (loopvar -> nested))
case _ =>
error("Invalid loop target")
}
scope
}
case IfThenElse(c, t, e) => {
if(exprType(c) != TBool()){
error("Invalid if-then-else condition")
}
recur(t, scope)
recur(e, scope)
scope
}
case Error(_) => scope
case Comment(_) => scope
case SetRemoveFunction(_,_,_) => scope
case SetAddFunction(_,_,_) => scope
}
}
def check(globals: Map[String,Type])(fn: FunctionDefinition): FunctionDefinition =
{
check(fn.body, globals ++ fn.args.map { case (name, t, _) => name -> t }.toMap, fn.ret)
return fn
}
def check(globals: (String,Type)*)(fn: FunctionDefinition): FunctionDefinition =
check(globals.toMap)(fn)
def check(fn: FunctionDefinition): FunctionDefinition =
{
check(fn.body, fn.args.map { case (name, t, _) => name -> t }.toMap, fn.ret)
return fn
}
def withFunctions(newFuncs: Map[String, FunctionSignature]): Typechecker =
new Typechecker(functions ++ newFuncs, nodeTypes)
def withFunctions(newFuncs: (String, FunctionSignature)*): Typechecker =
withFunctions(newFuncs.toMap)
}

30
build.sc Normal file
View file

@ -0,0 +1,30 @@
import mill._
import mill.scalalib._
import mill.scalalib.publish._
object astral extends ScalaModule with PublishModule {
val VERSION = "0.0.1-SNAPSHOT"
def scalaVersion = "3.2.1"
def mainClass = Some("com.astraldb.Astral")
/*************************************************
*** Backend Dependencies
*************************************************/
// def ivyDeps = Agg(
// )
def publishVersion = VERSION
override def pomSettings = PomSettings(
description = "The Astral Compiler",
organization = "com.astraldb",
url = "http://astraldb.com",
licenses = Seq(License.`Apache-2.0`),
versionControl = VersionControl.github("UBOdin", "astral"),
developers = Seq(
Developer("okennedy", "Oliver Kennedy", "https://odin.cse.buffalo.edu"),
)
)
}