[SPARK-22644][ML][TEST] Make ML testsuite support StructuredStreaming test

## What changes were proposed in this pull request?

We need to add some helper code to make it easier to test ML transformers & models against streaming data. These tests should help us catch any remaining issues, and future PRs can be encouraged to use them so that new Models & Transformers do not introduce streaming problems.

I add an `MLTest` trait that extends the `StreamTest` trait and overrides `createSparkSession`. An ML test suite then only needs to extend `MLTest` to get both the ML and streaming test utility functions.

For this first-pass review, I only modify one test case in `LinearRegressionSuite`; a sketch of the intended usage follows.
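
A minimal sketch of the intended usage pattern, adapted from the `MLTestSuite` and `LinearRegressionSuite` changes below (the suite name is hypothetical): a suite mixes in `MLTest` and calls `testTransformer`, which runs the transformer on the batch DataFrame and on the same data replayed through a `MemoryStream`, applying the same row-level check to both outputs.

```scala
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.util.MLTest
import org.apache.spark.sql.Row

// Hypothetical suite illustrating the helpers added in this PR.
class MyTransformerSuite extends MLTest {

  import testImplicits._

  test("StringIndexerModel is consistent on batch and streaming data") {
    val data = Seq((0, "a"), (1, "b"), (2, "c")).toDF("id", "label")
    val indexerModel = new StringIndexer()
      .setStringOrderType("alphabetAsc")
      .setInputCol("label")
      .setOutputCol("indexed")
      .fit(data)

    // The type parameter supplies the Encoder used to replay `data` through a MemoryStream.
    testTransformer[(Int, String)](data, indexerModel, "id", "indexed") {
      case Row(id: Int, indexed: Double) =>
        assert(id === indexed.toInt)
    }
  }
}
```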

Link to #19746

## How was this patch tested?

`MLTestSuite` added.

Author: WeichenXu <weichen.xu@databricks.com>

Closes #19843 from WeichenXu123/ml_stream_test_helper.
Authored by WeichenXu on 2017-12-12 21:28:24 -08:00, committed by Joseph K. Bradley.
Commit 0e36ba6212 (parent c7d0148615).
6 changed files with 203 additions and 26 deletions.

mllib/pom.xml

@@ -60,6 +60,20 @@
<artifactId>spark-sql_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-catalyst_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-graphx_${scala.binary.version}</artifactId>

mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala

@@ -24,13 +24,12 @@ import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors}
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.sql.{DataFrame, Row}
class LinearRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -233,7 +232,8 @@ class LinearRegressionSuite
assert(model2.intercept ~== interceptR relTol 1E-3)
assert(model2.coefficients ~= coefficientsR relTol 1E-3)
model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach {
testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
"features", "prediction") {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) +

mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala (new file)

@@ -0,0 +1,91 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.util
import java.io.File
import org.scalatest.Suite
import org.apache.spark.SparkContext
import org.apache.spark.ml.{PipelineModel, Transformer}
import org.apache.spark.sql.{DataFrame, Encoder, Row}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.test.TestSparkSession
import org.apache.spark.util.Utils
trait MLTest extends StreamTest with TempDirectory { self: Suite =>
@transient var sc: SparkContext = _
@transient var checkpointDir: String = _
protected override def createSparkSession: TestSparkSession = {
new TestSparkSession(new SparkContext("local[2]", "MLlibUnitTest", sparkConf))
}
override def beforeAll(): Unit = {
super.beforeAll()
sc = spark.sparkContext
checkpointDir = Utils.createDirectory(tempDir.getCanonicalPath, "checkpoints").toString
sc.setCheckpointDir(checkpointDir)
}
override def afterAll() {
try {
Utils.deleteRecursively(new File(checkpointDir))
} finally {
super.afterAll()
}
}
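/**
 * Replays `dataframe` through a MemoryStream, applies `transformer` to the streaming
 * DataFrame, and runs `checkFunction` against every row of the selected result columns.
 */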
def testTransformerOnStreamData[A : Encoder](
dataframe: DataFrame,
transformer: Transformer,
firstResultCol: String,
otherResultCols: String*)
(checkFunction: Row => Unit): Unit = {
val columnNames = dataframe.schema.fieldNames
val stream = MemoryStream[A]
val streamDF = stream.toDS().toDF(columnNames: _*)
val data = dataframe.as[A].collect()
val streamOutput = transformer.transform(streamDF)
.select(firstResultCol, otherResultCols: _*)
testStream(streamOutput) (
AddData(stream, data: _*),
CheckAnswer(checkFunction)
)
}
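/**
 * Runs `checkFunction` over the transformer output both for the streaming replay of
 * `dataframe` (via testTransformerOnStreamData) and for the batch `dataframe` itself.
 */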
def testTransformer[A : Encoder](
dataframe: DataFrame,
transformer: Transformer,
firstResultCol: String,
otherResultCols: String*)
(checkFunction: Row => Unit): Unit = {
testTransformerOnStreamData(dataframe, transformer, firstResultCol,
otherResultCols: _*)(checkFunction)
val dfOutput = transformer.transform(dataframe)
dfOutput.select(firstResultCol, otherResultCols: _*).collect().foreach { row =>
checkFunction(row)
}
}
}

mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala (new file)

@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.util
import org.apache.spark.ml.{PipelineModel, Transformer}
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.sql.Row
class MLTestSuite extends MLTest {
import testImplicits._
test("test transformer on stream data") {
val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d"), (4, "e"), (5, "f"))
.toDF("id", "label")
val indexer = new StringIndexer().setStringOrderType("alphabetAsc")
.setInputCol("label").setOutputCol("indexed")
val indexerModel = indexer.fit(data)
testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
case Row(id: Int, indexed: Double) =>
assert(id === indexed.toInt)
}
intercept[Exception] {
testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
case Row(id: Int, indexed: Double) =>
assert(id != indexed.toInt)
}
}
}
}

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala

@@ -133,6 +133,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
}
def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false, false)
def apply(checkFunction: Row => Unit): CheckAnswerRowsByFunc =
CheckAnswerRowsByFunc(checkFunction, false)
}
/**
@@ -154,6 +157,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
}
def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true, false)
def apply(checkFunction: Row => Unit): CheckAnswerRowsByFunc =
CheckAnswerRowsByFunc(checkFunction, true)
}
case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean, isSorted: Boolean)
@@ -162,6 +168,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer"
}
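/** Checks the stream output by running `checkFunction` against each row in the sink. */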
case class CheckAnswerRowsByFunc(checkFunction: Row => Unit, lastOnly: Boolean)
extends StreamAction with StreamMustBeRunning {
override def toString: String = s"$operatorName: ${checkFunction.toString()}"
private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc"
}
/** Stops the stream. It must currently be running. */
case object StopStream extends StreamAction with StreamMustBeRunning
@@ -352,6 +364,29 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
""".stripMargin)
}
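// Blocks until every source has processed all added data, then returns either the latest
// batch or all data from the sink, failing the test if the sink cannot be read.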
def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = {
verify(currentStream != null, "stream not running")
// Get the map of source index to the current source objects
val indexToSource = currentStream
.logicalPlan
.collect { case StreamingExecutionRelation(s, _) => s }
.zipWithIndex
.map(_.swap)
.toMap
// Block until all data added has been processed for all the sources
awaiting.foreach { case (sourceIndex, offset) =>
failAfter(streamingTimeout) {
currentStream.awaitOffset(indexToSource(sourceIndex), offset)
}
}
try if (lastOnly) sink.latestBatchData else sink.allData catch {
case e: Exception =>
failTest("Exception while getting data from sink", e)
}
}
var manualClockExpectedTime = -1L
val defaultCheckpointLocation =
Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
@@ -552,30 +587,20 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
e.runAction()
case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) =>
verify(currentStream != null, "stream not running")
// Get the map of source index to the current source objects
val indexToSource = currentStream
.logicalPlan
.collect { case StreamingExecutionRelation(s, _) => s }
.zipWithIndex
.map(_.swap)
.toMap
// Block until all data added has been processed for all the source
awaiting.foreach { case (sourceIndex, offset) =>
failAfter(streamingTimeout) {
currentStream.awaitOffset(indexToSource(sourceIndex), offset)
}
}
val sparkAnswer = try if (lastOnly) sink.latestBatchData else sink.allData catch {
case e: Exception =>
failTest("Exception while getting data from sink", e)
}
val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
QueryTest.sameRows(expectedAnswer, sparkAnswer, isSorted).foreach {
error => failTest(error)
}
case CheckAnswerRowsByFunc(checkFunction, lastOnly) =>
val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
sparkAnswer.foreach { row =>
try {
checkFunction(row)
} catch {
case e: Throwable => failTest(e.toString)
}
}
}
pos += 1
}

sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala

@@ -24,7 +24,7 @@ import org.apache.spark.sql.internal.{SessionState, SessionStateBuilder, SQLConf
/**
* A special `SparkSession` prepared for testing.
*/
private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self =>
private[spark] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self =>
def this(sparkConf: SparkConf) {
this(new SparkContext("local[2]", "test-sql-context",
sparkConf.set("spark.sql.testkey", "true")))