// mimir-pip/lib/test/src/org/mimirdb/pip/GaussTests.scala

package org.mimirdb.pip

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.scalatest.flatspec.AnyFlatSpec
import scala.util.Random

import distribution._

/* Spark Session used across tests */
object DbServer {
  val spark = SparkSession.builder
    .appName("pip")
    .master("local[*]")
    .getOrCreate()
  spark.sparkContext.setLogLevel("WARN")
  Pip.init(spark)
}

import DbServer._

/* Creating and retrieving GaussType objects in a Spark dataframe */
class CreateGaussObject extends AnyFlatSpec {
  val df = spark.range(10)
  val sd_window = 3
  val mean_window = 9
  val dfRandData = df.select(rand() * mean_window as "mean", rand() * sd_window as "sigma")
  val dfGaussObj =
    dfRandData.select(
      Gaussian(dfRandData("mean"), dfRandData("sigma")) as "gObj"
    )

  "A dataframe of Gauss objects" should "have values of type Gauss" in {
    assert(dfGaussObj.schema("gObj").dataType == udt.RandomVariableType)
  }
}

/* Adding Gauss objects */
class GaussPlusTests extends AnyFlatSpec {
  "Adding two Gauss objects" should "return a correctly computed Gauss object" in {
    val firstParams = Gaussian.Params(2.0, 1.0)
    val secondParams = Gaussian.Params(1.0, 1.0)
    val sum = Gaussian.plus(firstParams, secondParams)
    assert(sum.mean == 3.0)
    assert(sum.sd == math.sqrt(2.0))
  }
}
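
/*
  A sketch of one additional check of Gaussian.plus, not part of the original
  suite: with unequal standard deviations, the sum of two independent
  Gaussians has mean m1 + m2 and variance s1^2 + s2^2, so the standard
  deviations combine as sqrt(s1^2 + s2^2) rather than s1 + s2. This assumes
  Gaussian.plus keeps the independent-sum semantics exercised above.
*/
class GaussPlusUnequalSdTests extends AnyFlatSpec {
  "Adding Gauss objects with unequal standard deviations" should "sum means and combine variances" in {
    val firstParams = Gaussian.Params(5.0, 3.0)
    val secondParams = Gaussian.Params(-2.0, 4.0)
    val sum = Gaussian.plus(firstParams, secondParams)
    assert(sum.mean == 3.0)
    assert(sum.sd == math.sqrt(3.0 * 3.0 + 4.0 * 4.0)) // = 5.0 for this 3-4-5 case
  }
}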

/* Subtracting Gauss objects */
class GaussMinusTests extends AnyFlatSpec {
  "Subtracting two Gauss objects" should "return a correctly computed Gauss object" in {
    val firstParams = Gaussian.Params(2.0, 1.0)
    val secondParams = Gaussian.Params(1.0, 1.0)
    val diff = Gaussian.minus(firstParams, secondParams)
    assert(diff.mean == 1.0)
    assert(diff.sd == math.sqrt(2.0))
  }
}
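
/*
  A companion sketch for Gaussian.minus, also not part of the original suite:
  for independent Gaussians the variances still add under subtraction, so the
  difference's standard deviation is sqrt(s1^2 + s2^2), not s1 - s2. This
  assumes Gaussian.minus follows the convention the equal-sd test above
  suggests.
*/
class GaussMinusUnequalSdTests extends AnyFlatSpec {
  "Subtracting Gauss objects with unequal standard deviations" should "still add variances" in {
    val firstParams = Gaussian.Params(5.0, 3.0)
    val secondParams = Gaussian.Params(-2.0, 4.0)
    val diff = Gaussian.minus(firstParams, secondParams)
    assert(diff.mean == 7.0)
    assert(diff.sd == math.sqrt(3.0 * 3.0 + 4.0 * 4.0))
  }
}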

/* Multiplying Gauss objects */
class GaussMultTests extends AnyFlatSpec {
  "Multiplying two Gauss objects" should "return a correctly computed (according to code documentation) Gauss object" in {
    /*
      NOTE:
      This code seems redundant. It is just performing the same computation as found in Gauss.gaussMult.
    */
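    /*
      For reference (a property assumed here, not quoted from the library
      docs): the product of two Gaussian densities N(m1, s1^2) and N(m2, s2^2)
      is proportional to a Gaussian with mean
        (m1 * s2^2 + m2 * s1^2) / (s1^2 + s2^2)
      and variance
        (s1^2 * s2^2) / (s1^2 + s2^2).
      Note that the second argument handed to Gaussian.Params below is this
      combined variance, mirroring the computation in Gauss.gaussMult that the
      test duplicates.
    */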
    val firstParams = Gaussian.Params(2.0, 1.0)
    val secondParams = Gaussian.Params(1.0, 1.0)
    val divisor = math.pow(firstParams.sd, 2) + math.pow(secondParams.sd, 2)
    val m_dividend = firstParams.mean * math.pow(secondParams.sd, 2) +
                     secondParams.mean * math.pow(firstParams.sd, 2)
    val s_dividend = math.pow(firstParams.sd, 2) * math.pow(secondParams.sd, 2)
    assert(
      Gaussian.mult(firstParams, secondParams) ===
        Gaussian.Params(m_dividend / divisor, s_dividend / divisor)
    )
  }
}

/* Converting Gauss object to string */
class GaussToStringTests extends AnyFlatSpec {
  "Calling toString on Gauss object" should "produce a correct string representation" in {
    val gaussParams = Gaussian.Params(2.0, 1.0)
    assert(Gaussian.describe(gaussParams) === s"Gauss(mean: ${gaussParams.mean}, std-dev: ${gaussParams.sd})")
  }
}

/* Histogram Aggregator output */
class HistAggTest extends AnyFlatSpec {
  /* Generate (appropriate) test data */
  val id = 1                                     // currently the histogram assumes all tuples have the same location id
  def speedGen = scala.util.Random.nextInt(100)  // generate a random speed in [0, 100)
  def probGen = scala.util.Random.nextDouble()
  val intvls = List.range(10, 110, 10)           // synthetic intervals: (10, 20, 30, 40, 50, 60, 70, 80, 90, 100)
  def r_idx = scala.util.Random.nextInt(intvls.length)

  /* Generate a synthetic data frame */
  val idName = "location_id"
  val speedName = "speed"
  val intvlName = "bin"
  val probName = "probability_mass"
  val numRows = 10
  val df = spark.createDataFrame(
    spark.sparkContext.parallelize(
      Seq.fill(numRows) {
        val speed = speedGen
        val itvl = intvls.filter(x => x > speed).min
        (id, speed, itvl, probGen)
      }
    )
  ).toDF(idName, speedName, intvlName, probName)
  // df.show()
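
  /*
    Illustrative only (values vary per run): a generated row might look like
    (1, 37, 40, 0.52), i.e. location id 1, a random speed of 37, the smallest
    interval bound strictly greater than that speed (40), and that row's
    random probability mass.
  */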
"Calling histogram aggregate on dataframe" should "produce correct Array[Double] encoding of histogram" in {
df.createOrReplaceTempView("testData")
val out = spark.sql(s"SELECT histogram($idName, $speedName, $intvlName, $probName) AS hist FROM testData")
/*
out.show()
out.take(1).foreach(x => println(x.getAs[Seq[Double]]("hist").mkString(" ")))
out.foreach(x => println(x))
*/
val outArr = out.take(1).map{_.getAs[Seq[Double]]("hist")}.flatten
//validation
val valid = spark.sql(s"SELECT $intvlName, SUM($probName) AS s_prob FROM testData GROUP BY $intvlName ORDER BY $intvlName")
//valid.show()
val list_v = valid.collect().map{_.getAs[Double]("s_prob")}
//list_v.map(println(_))
assert(outArr.sameElements(list_v))
}
}