[SPARK-22644][ML][TEST] Make ML testsuite support StructuredStreaming test
## What changes were proposed in this pull request?

We need helper code to make it easier to test ML transformers and models with streaming data. These tests may help us catch remaining issues, and future PRs can be encouraged to use them so that new models and transformers do not break on streaming input.

This adds an `MLTest` trait that extends the `StreamTest` trait and overrides `createSparkSession`, so an ML test suite only needs to extend `MLTest` to get both the ML and streaming test utilities. Only one test case in `LinearRegressionSuite` is migrated, for first-pass review.

Link to #19746

## How was this patch tested?

`MLTestSuite` added.

Author: WeichenXu <weichen.xu@databricks.com>

Closes #19843 from WeichenXu123/ml_stream_test_helper.
Commit 0e36ba6212 (parent c7d0148615)
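For a sense of how the new helper is meant to be used, here is a minimal, hypothetical suite (the `Binarizer` choice and the suite name are illustrative, not part of this patch; `testTransformer`'s signature comes from the `MLTest` trait added below). It runs the same row-level check against both the batch and the streaming execution of the transformer:

```scala
import org.apache.spark.ml.feature.Binarizer
import org.apache.spark.ml.util.MLTest
import org.apache.spark.sql.Row

// Hypothetical suite sketching the intended usage of the new MLTest trait.
class MyTransformerSuite extends MLTest {

  import testImplicits._

  test("binarizer output on batch and stream data") {
    val df = Seq(0.1, 0.4, 0.9).toDF("feature")
    val binarizer = new Binarizer()
      .setInputCol("feature")
      .setOutputCol("binarized")
      .setThreshold(0.5)
    // testTransformer replays the data through a MemoryStream and also
    // runs a plain batch transform, applying the same check to each row.
    testTransformer[Double](df, binarizer, "feature", "binarized") {
      case Row(feature: Double, binarized: Double) =>
        assert(binarized === (if (feature > 0.5) 1.0 else 0.0))
    }
  }
}
```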
mllib/pom.xml
```diff
@@ -60,6 +60,20 @@
       <artifactId>spark-sql_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-graphx_${scala.binary.version}</artifactId>
```
mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
```diff
@@ -24,13 +24,12 @@ import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors}
 import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.sql.{DataFrame, Row}

-class LinearRegressionSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {

   import testImplicits._

```
```diff
@@ -233,7 +232,8 @@ class LinearRegressionSuite
     assert(model2.intercept ~== interceptR relTol 1E-3)
     assert(model2.coefficients ~= coefficientsR relTol 1E-3)

-    model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach {
+    testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
+      "features", "prediction") {
       case Row(features: DenseVector, prediction1: Double) =>
         val prediction2 =
           features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) +
```
mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala (new file, 91 lines)
```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.util

import java.io.File

import org.scalatest.Suite

import org.apache.spark.SparkContext
import org.apache.spark.ml.{PipelineModel, Transformer}
import org.apache.spark.sql.{DataFrame, Encoder, Row}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.test.TestSparkSession
import org.apache.spark.util.Utils

trait MLTest extends StreamTest with TempDirectory { self: Suite =>

  @transient var sc: SparkContext = _
  @transient var checkpointDir: String = _

  protected override def createSparkSession: TestSparkSession = {
    new TestSparkSession(new SparkContext("local[2]", "MLlibUnitTest", sparkConf))
  }

  override def beforeAll(): Unit = {
    super.beforeAll()
    sc = spark.sparkContext
    checkpointDir = Utils.createDirectory(tempDir.getCanonicalPath, "checkpoints").toString
    sc.setCheckpointDir(checkpointDir)
  }

  override def afterAll() {
    try {
      Utils.deleteRecursively(new File(checkpointDir))
    } finally {
      super.afterAll()
    }
  }

  def testTransformerOnStreamData[A : Encoder](
      dataframe: DataFrame,
      transformer: Transformer,
      firstResultCol: String,
      otherResultCols: String*)
      (checkFunction: Row => Unit): Unit = {

    val columnNames = dataframe.schema.fieldNames
    val stream = MemoryStream[A]
    val streamDF = stream.toDS().toDF(columnNames: _*)

    val data = dataframe.as[A].collect()

    val streamOutput = transformer.transform(streamDF)
      .select(firstResultCol, otherResultCols: _*)
    testStream(streamOutput)(
      AddData(stream, data: _*),
      CheckAnswer(checkFunction)
    )
  }

  def testTransformer[A : Encoder](
      dataframe: DataFrame,
      transformer: Transformer,
      firstResultCol: String,
      otherResultCols: String*)
      (checkFunction: Row => Unit): Unit = {
    testTransformerOnStreamData(dataframe, transformer, firstResultCol,
      otherResultCols: _*)(checkFunction)

    val dfOutput = transformer.transform(dataframe)
    dfOutput.select(firstResultCol, otherResultCols: _*).collect().foreach { row =>
      checkFunction(row)
    }
  }
}
```
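To illustrate what the streaming pass can catch, a hedged, hypothetical sketch (this class is illustrative and not part of the patch): a transformer whose `transform()` eagerly executes an action would pass every batch-only test, but `testTransformerOnStreamData` would surface the problem, because actions such as `count()` are not allowed on a streaming Dataset.

```scala
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types.StructType

// Hypothetical transformer that batch-only tests would not flag:
// the eager count() is legal on a batch Dataset but throws an
// AnalysisException when the input is backed by a streaming source.
class EagerCountTransformer(override val uid: String) extends Transformer {
  def this() = this(Identifiable.randomUID("eagerCount"))

  override def transform(dataset: Dataset[_]): DataFrame = {
    val n = dataset.count()  // eager action: fails under streaming execution
    dataset.toDF().withColumn("n", lit(n))
  }

  override def transformSchema(schema: StructType): StructType = schema
  override def copy(extra: ParamMap): EagerCountTransformer = defaultCopy(extra)
}
```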
mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala (new file, 47 lines)
```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.util

import org.apache.spark.ml.{PipelineModel, Transformer}
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.sql.Row

class MLTestSuite extends MLTest {

  import testImplicits._

  test("test transformer on stream data") {

    val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d"), (4, "e"), (5, "f"))
      .toDF("id", "label")
    val indexer = new StringIndexer().setStringOrderType("alphabetAsc")
      .setInputCol("label").setOutputCol("indexed")
    val indexerModel = indexer.fit(data)
    testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
      case Row(id: Int, indexed: Double) =>
        assert(id === indexed.toInt)
    }

    intercept[Exception] {
      testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
        case Row(id: Int, indexed: Double) =>
          assert(id != indexed.toInt)
      }
    }
  }
}
```
sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
```diff
@@ -133,6 +133,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     }

     def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false, false)
+
+    def apply(checkFunction: Row => Unit): CheckAnswerRowsByFunc =
+      CheckAnswerRowsByFunc(checkFunction, false)
   }

   /**
```
```diff
@@ -154,6 +157,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     }

     def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true, false)
+
+    def apply(checkFunction: Row => Unit): CheckAnswerRowsByFunc =
+      CheckAnswerRowsByFunc(checkFunction, true)
   }

   case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean, isSorted: Boolean)
```
```diff
@@ -162,6 +168,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer"
   }

+  case class CheckAnswerRowsByFunc(checkFunction: Row => Unit, lastOnly: Boolean)
+    extends StreamAction with StreamMustBeRunning {
+    override def toString: String = s"$operatorName: ${checkFunction.toString()}"
+    private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc"
+  }
+
   /** Stops the stream. It must currently be running. */
   case object StopStream extends StreamAction with StreamMustBeRunning
```
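A minimal sketch of the new overload in action (a hypothetical query; assumes it runs inside a `StreamTest`-derived suite, where `MemoryStream`, `testStream`, `AddData`, and the implicit `SQLContext` are in scope). Instead of comparing against a fixed expected answer, the check asserts a property of every output row:

```scala
// Sketch: CheckAnswer with a Row => Unit function becomes a
// CheckAnswerRowsByFunc action rather than an exact-answer comparison.
val input = MemoryStream[Int]
val doubled = input.toDS().map(_ * 2).toDF("doubled")
val checkEven: Row => Unit = { row =>
  assert(row.getInt(0) % 2 == 0, s"expected an even value, got $row")
}
testStream(doubled)(
  AddData(input, 1, 2, 3),
  CheckAnswer(checkEven)
)
```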
```diff
@@ -352,6 +364,29 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
         """.stripMargin)
     }

+    def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = {
+      verify(currentStream != null, "stream not running")
+      // Get the map of source index to the current source objects
+      val indexToSource = currentStream
+        .logicalPlan
+        .collect { case StreamingExecutionRelation(s, _) => s }
+        .zipWithIndex
+        .map(_.swap)
+        .toMap
+
+      // Block until all data added has been processed for all the source
+      awaiting.foreach { case (sourceIndex, offset) =>
+        failAfter(streamingTimeout) {
+          currentStream.awaitOffset(indexToSource(sourceIndex), offset)
+        }
+      }
+
+      try if (lastOnly) sink.latestBatchData else sink.allData catch {
+        case e: Exception =>
+          failTest("Exception while getting data from sink", e)
+      }
+    }
+
     var manualClockExpectedTime = -1L
     val defaultCheckpointLocation =
       Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
```
```diff
@@ -552,30 +587,20 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
             e.runAction()

           case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) =>
-            verify(currentStream != null, "stream not running")
-            // Get the map of source index to the current source objects
-            val indexToSource = currentStream
-              .logicalPlan
-              .collect { case StreamingExecutionRelation(s, _) => s }
-              .zipWithIndex
-              .map(_.swap)
-              .toMap
-
-            // Block until all data added has been processed for all the source
-            awaiting.foreach { case (sourceIndex, offset) =>
-              failAfter(streamingTimeout) {
-                currentStream.awaitOffset(indexToSource(sourceIndex), offset)
-              }
-            }
-
-            val sparkAnswer = try if (lastOnly) sink.latestBatchData else sink.allData catch {
-              case e: Exception =>
-                failTest("Exception while getting data from sink", e)
-            }
-
+            val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
             QueryTest.sameRows(expectedAnswer, sparkAnswer, isSorted).foreach {
               error => failTest(error)
             }

+          case CheckAnswerRowsByFunc(checkFunction, lastOnly) =>
+            val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
+            sparkAnswer.foreach { row =>
+              try {
+                checkFunction(row)
+              } catch {
+                case e: Throwable => failTest(e.toString)
+              }
+            }
         }
         pos += 1
       }
```
sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
```diff
@@ -24,7 +24,7 @@ import org.apache.spark.sql.internal.{SessionState, SessionStateBuilder, SQLConf
 /**
  * A special `SparkSession` prepared for testing.
  */
-private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self =>
+private[spark] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self =>
   def this(sparkConf: SparkConf) {
     this(new SparkContext("local[2]", "test-sql-context",
       sparkConf.set("spark.sql.testkey", "true")))
```
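The widened visibility is what lets code outside `org.apache.spark.sql`, such as the new `MLTest` trait, construct the test session directly; its override (repeated from the new file above for context) is the consumer of this change:

```scala
// In mllib's MLTest (package org.apache.spark.ml.util), now legal
// because TestSparkSession is private[spark] instead of private[sql]:
protected override def createSparkSession: TestSparkSession = {
  new TestSparkSession(new SparkContext("local[2]", "MLlibUnitTest", sparkConf))
}
```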