package org.mimirdb.pip

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.scalatest.flatspec.AnyFlatSpec

import distribution._

/* Spark session shared across all test suites */
object DbServer {
  val spark = SparkSession.builder
    .appName("pip")
    .master("local[*]")
    .getOrCreate()

  // Quiet Spark's default INFO logging so test output stays readable
  spark.sparkContext.setLogLevel("WARN")

  // Register Pip with the session before any suite uses its functions
  Pip.init(spark)
}

import DbServer._

/* Creating and retrieving GaussType objects in a Spark dataframe */
class CreateGaussObject extends AnyFlatSpec {

  val df = spark.range(10)
  val sd_window = 3
  val mean_window = 9
  // Draw a random mean in [0, mean_window) and std-dev in [0, sd_window) per row
  val dfRandData = df.select(rand() * mean_window as "mean", rand() * sd_window as "sigma")
  val dfGaussObj =
    dfRandData.select(
      Gaussian(dfRandData("mean"), dfRandData("sigma")) as "gObj"
    )

  "A dataframe of Gauss objects" should "have values of type Gauss" in {
    assert(dfGaussObj.schema("gObj").dataType == udt.RandomVariableType)
  }
}

/* Adding Gauss objects */
class GaussPlusTests extends AnyFlatSpec {
  "Adding two Gauss objects" should "return a correctly computed Gauss object" in {
    val firstParams = Gaussian.Params(2.0, 1.0)
    val secondParams = Gaussian.Params(1.0, 1.0)
    val sum = Gaussian.plus(firstParams, secondParams)
    // For independent Gaussians, means add and variances add:
    // mean = 2.0 + 1.0 = 3.0, sd = sqrt(1.0^2 + 1.0^2) = sqrt(2.0)
    assert(sum.mean == 3.0)
    assert(sum.sd == math.sqrt(2.0))
  }
}
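
/*
 * A minimal sketch, not part of the original suite: assuming Gaussian.plus
 * follows the closed form checked above (means add, variances add), folding
 * it over several Params should give mean = sum of means and
 * sd = sqrt(sum of squared sds). The suite name is hypothetical.
 */
class GaussPlusFoldSketch extends AnyFlatSpec {
  "Folding Gaussian.plus over several Gauss objects" should "accumulate mean and variance" in {
    val gs = Seq(Gaussian.Params(1.0, 2.0), Gaussian.Params(2.0, 3.0), Gaussian.Params(3.0, 6.0))
    val total = gs.reduce(Gaussian.plus(_, _))
    assert(total.mean == 6.0)               // 1 + 2 + 3
    assert(math.abs(total.sd - 7.0) < 1e-9) // sqrt(4 + 9 + 36) = sqrt(49)
  }
}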

/* Subtracting Gauss objects */
class GaussMinusTests extends AnyFlatSpec {
  "Subtracting two Gauss objects" should "return a correctly computed Gauss object" in {
    val firstParams = Gaussian.Params(2.0, 1.0)
    val secondParams = Gaussian.Params(1.0, 1.0)
    val diff = Gaussian.minus(firstParams, secondParams)
    // Means subtract, but variances of independent Gaussians still add:
    // mean = 2.0 - 1.0 = 1.0, sd = sqrt(1.0^2 + 1.0^2) = sqrt(2.0)
    assert(diff.mean == 1.0)
    assert(diff.sd == math.sqrt(2.0))
  }
}
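
/*
 * A minimal sketch, not part of the original suite: because variance
 * accumulates under both plus and minus, subtracting b after adding it does
 * NOT recover the original sd. With a = (2, 3) and b = (5, 2):
 * plus  -> mean 7, sd sqrt(9 + 4)  = sqrt(13)
 * minus -> mean 2, sd sqrt(13 + 4) = sqrt(17), not 3. Suite name is new here.
 */
class GaussPlusMinusRoundTripSketch extends AnyFlatSpec {
  "Subtracting after adding" should "restore the mean but widen the std-dev" in {
    val a = Gaussian.Params(2.0, 3.0)
    val b = Gaussian.Params(5.0, 2.0)
    val roundTrip = Gaussian.minus(Gaussian.plus(a, b), b)
    assert(roundTrip.mean == 2.0)
    assert(math.abs(roundTrip.sd - math.sqrt(17.0)) < 1e-9)
  }
}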

/* Multiplying Gauss objects */
class GaussMultTests extends AnyFlatSpec {
  "Multiplying two Gauss objects" should "return a correctly computed (according to code documentation) Gauss object" in {
    /*
     * NOTE: this test is tautological. It re-derives the expected value with
     * the same computation Gaussian.mult performs internally, so it catches
     * regressions but cannot detect a wrong formula.
     */
    val firstParams = Gaussian.Params(2.0, 1.0)
    val secondParams = Gaussian.Params(1.0, 1.0)
    val divisor = math.pow(firstParams.sd, 2) + math.pow(secondParams.sd, 2)
    val m_dividend = firstParams.mean * math.pow(secondParams.sd, 2) +
      secondParams.mean * math.pow(firstParams.sd, 2)
    val s_dividend = math.pow(firstParams.sd, 2) * math.pow(secondParams.sd, 2)
    assert(Gaussian.mult(firstParams, secondParams) === Gaussian.Params(m_dividend / divisor, s_dividend / divisor))
  }
}
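
/*
 * A minimal non-tautological alternative (new here): check Gaussian.mult
 * against hand-computed literals instead of re-running its formula. For
 * params (mean 2, sd 1) and (mean 1, sd 1):
 * divisor = 1^2 + 1^2 = 2, mean = (2*1^2 + 1*1^2) / 2 = 1.5,
 * second parameter = (1^2 * 1^2) / 2 = 0.5. Suite name is hypothetical.
 */
class GaussMultLiteralSketch extends AnyFlatSpec {
  "Multiplying two Gauss objects" should "match hand-computed parameters" in {
    val product = Gaussian.mult(Gaussian.Params(2.0, 1.0), Gaussian.Params(1.0, 1.0))
    assert(product === Gaussian.Params(1.5, 0.5))
  }
}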

/* Converting a Gauss object to a string */
class GaussToStringTests extends AnyFlatSpec {
  "Calling describe on a Gauss object" should "produce a correct string representation" in {
    val gaussParams = Gaussian.Params(2.0, 1.0)
    assert(Gaussian.describe(gaussParams) === s"Gauss(mean: ${gaussParams.mean}, std-dev: ${gaussParams.sd})")
  }
}

/* Histogram aggregator output */
class HistAggTest extends AnyFlatSpec {

  /* Generate synthetic test data */
  val id = 1 // the histogram currently assumes all tuples share one location id
  def speedGen = scala.util.Random.nextInt(100) // random speed in [0, 100)
  def probGen = scala.util.Random.nextDouble()
  val intvls = List.range(10, 110, 10) // bin upper bounds: 10, 20, ..., 100

  /* Build a synthetic dataframe: each row gets a random speed, the smallest
     bin bound strictly above that speed, and a random probability mass */
  val idName = "location_id"
  val speedName = "speed"
  val intvlName = "bin"
  val probName = "probability_mass"
  val numRows = 10
  val df = spark.createDataFrame(
    spark.sparkContext.parallelize(
      Seq.fill(numRows) {
        val speed = speedGen
        val itvl = intvls.filter(_ > speed).min
        (id, speed, itvl, probGen)
      }
    )
  ).toDF(idName, speedName, intvlName, probName)

  "Calling the histogram aggregate on a dataframe" should "produce a correct Array[Double] encoding of the histogram" in {
    df.createOrReplaceTempView("testData")
    val out = spark.sql(s"SELECT histogram($idName, $speedName, $intvlName, $probName) AS hist FROM testData")
    val outArr = out.take(1).flatMap { _.getAs[Seq[Double]]("hist") }

    // Validate against a plain SQL aggregation: the per-bin sums of
    // probability mass, in ascending bin order, should match the encoding
    val valid = spark.sql(s"SELECT $intvlName, SUM($probName) AS s_prob FROM testData GROUP BY $intvlName ORDER BY $intvlName")
    val list_v = valid.collect().map { _.getAs[Double]("s_prob") }

    assert(outArr.sameElements(list_v))
  }
}
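
/*
 * A minimal deterministic companion sketch (not in the original suite): with
 * fixed rows the expected per-bin masses can be written down by hand. It
 * assumes the same histogram(...) SQL function exercised above, and that the
 * result lists one entry per distinct bin in ascending bin order, as the
 * validation query in HistAggTest implies. Suite and view names are new here.
 */
class HistAggFixedDataSketch extends AnyFlatSpec {
  "Calling the histogram aggregate on fixed data" should "match hand-computed bin masses" in {
    val rows = Seq(
      (1, 5, 10, 0.25),  // speed 5  -> bin 10
      (1, 15, 20, 0.5),  // speed 15 -> bin 20
      (1, 17, 20, 0.125) // speed 17 -> bin 20, so bin 20 holds 0.5 + 0.125
    )
    val fixed = spark.createDataFrame(spark.sparkContext.parallelize(rows))
      .toDF("location_id", "speed", "bin", "probability_mass")
    fixed.createOrReplaceTempView("fixedData")
    val out = spark.sql("SELECT histogram(location_id, speed, bin, probability_mass) AS hist FROM fixedData")
    val hist = out.take(1).flatMap { _.getAs[Seq[Double]]("hist") }
    // Masses are exact binary fractions, so exact equality is safe here
    assert(hist.sameElements(Array(0.25, 0.625)))
  }
}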