[SPARK-22644][ML][TEST] Make ML testsuite support StructuredStreaming test
## What changes were proposed in this pull request? We need to add some helper code to make testing ML transformers & models easier with streaming data. These tests might help us catch any remaining issues and we could encourage future PRs to use these tests to prevent new Models & Transformers from having issues. I add a `MLTest` trait which extends `StreamTest` trait, and override `createSparkSession`. So ML testsuite can only extend `MLTest`, to use both ML & Stream test util functions. I only modify one testcase in `LinearRegressionSuite`, for first pass review. Link to #19746 ## How was this patch tested? `MLTestSuite` added. Author: WeichenXu <weichen.xu@databricks.com> Closes #19843 from WeichenXu123/ml_stream_test_helper.
This commit is contained in:
parent
c7d0148615
commit
0e36ba6212
|
@ -60,6 +60,20 @@
|
|||
<artifactId>spark-sql_${scala.binary.version}</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.spark</groupId>
|
||||
<artifactId>spark-catalyst_${scala.binary.version}</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<type>test-jar</type>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.spark</groupId>
|
||||
<artifactId>spark-sql_${scala.binary.version}</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<type>test-jar</type>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.spark</groupId>
|
||||
<artifactId>spark-graphx_${scala.binary.version}</artifactId>
|
||||
|
|
|
@ -24,13 +24,12 @@ import org.apache.spark.ml.feature.Instance
|
|||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors}
|
||||
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
|
||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
|
||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
|
||||
import org.apache.spark.ml.util.TestingUtils._
|
||||
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
|
||||
import org.apache.spark.sql.{DataFrame, Row}
|
||||
|
||||
class LinearRegressionSuite
|
||||
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
|
||||
class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
|
||||
|
||||
import testImplicits._
|
||||
|
||||
|
@ -233,7 +232,8 @@ class LinearRegressionSuite
|
|||
assert(model2.intercept ~== interceptR relTol 1E-3)
|
||||
assert(model2.coefficients ~= coefficientsR relTol 1E-3)
|
||||
|
||||
model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach {
|
||||
testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
|
||||
"features", "prediction") {
|
||||
case Row(features: DenseVector, prediction1: Double) =>
|
||||
val prediction2 =
|
||||
features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) +
|
||||
|
|
91
mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
Normal file
91
mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
Normal file
|
@ -0,0 +1,91 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.util
|
||||
|
||||
import java.io.File
|
||||
|
||||
import org.scalatest.Suite
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.ml.{PipelineModel, Transformer}
|
||||
import org.apache.spark.sql.{DataFrame, Encoder, Row}
|
||||
import org.apache.spark.sql.execution.streaming.MemoryStream
|
||||
import org.apache.spark.sql.streaming.StreamTest
|
||||
import org.apache.spark.sql.test.TestSparkSession
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
trait MLTest extends StreamTest with TempDirectory { self: Suite =>
|
||||
|
||||
@transient var sc: SparkContext = _
|
||||
@transient var checkpointDir: String = _
|
||||
|
||||
protected override def createSparkSession: TestSparkSession = {
|
||||
new TestSparkSession(new SparkContext("local[2]", "MLlibUnitTest", sparkConf))
|
||||
}
|
||||
|
||||
override def beforeAll(): Unit = {
|
||||
super.beforeAll()
|
||||
sc = spark.sparkContext
|
||||
checkpointDir = Utils.createDirectory(tempDir.getCanonicalPath, "checkpoints").toString
|
||||
sc.setCheckpointDir(checkpointDir)
|
||||
}
|
||||
|
||||
override def afterAll() {
|
||||
try {
|
||||
Utils.deleteRecursively(new File(checkpointDir))
|
||||
} finally {
|
||||
super.afterAll()
|
||||
}
|
||||
}
|
||||
|
||||
def testTransformerOnStreamData[A : Encoder](
|
||||
dataframe: DataFrame,
|
||||
transformer: Transformer,
|
||||
firstResultCol: String,
|
||||
otherResultCols: String*)
|
||||
(checkFunction: Row => Unit): Unit = {
|
||||
|
||||
val columnNames = dataframe.schema.fieldNames
|
||||
val stream = MemoryStream[A]
|
||||
val streamDF = stream.toDS().toDF(columnNames: _*)
|
||||
|
||||
val data = dataframe.as[A].collect()
|
||||
|
||||
val streamOutput = transformer.transform(streamDF)
|
||||
.select(firstResultCol, otherResultCols: _*)
|
||||
testStream(streamOutput) (
|
||||
AddData(stream, data: _*),
|
||||
CheckAnswer(checkFunction)
|
||||
)
|
||||
}
|
||||
|
||||
def testTransformer[A : Encoder](
|
||||
dataframe: DataFrame,
|
||||
transformer: Transformer,
|
||||
firstResultCol: String,
|
||||
otherResultCols: String*)
|
||||
(checkFunction: Row => Unit): Unit = {
|
||||
testTransformerOnStreamData(dataframe, transformer, firstResultCol,
|
||||
otherResultCols: _*)(checkFunction)
|
||||
|
||||
val dfOutput = transformer.transform(dataframe)
|
||||
dfOutput.select(firstResultCol, otherResultCols: _*).collect().foreach { row =>
|
||||
checkFunction(row)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.util
|
||||
|
||||
import org.apache.spark.ml.{PipelineModel, Transformer}
|
||||
import org.apache.spark.ml.feature.StringIndexer
|
||||
import org.apache.spark.sql.Row
|
||||
|
||||
class MLTestSuite extends MLTest {
|
||||
|
||||
import testImplicits._
|
||||
|
||||
test("test transformer on stream data") {
|
||||
|
||||
val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d"), (4, "e"), (5, "f"))
|
||||
.toDF("id", "label")
|
||||
val indexer = new StringIndexer().setStringOrderType("alphabetAsc")
|
||||
.setInputCol("label").setOutputCol("indexed")
|
||||
val indexerModel = indexer.fit(data)
|
||||
testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
|
||||
case Row(id: Int, indexed: Double) =>
|
||||
assert(id === indexed.toInt)
|
||||
}
|
||||
|
||||
intercept[Exception] {
|
||||
testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
|
||||
case Row(id: Int, indexed: Double) =>
|
||||
assert(id != indexed.toInt)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -133,6 +133,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
|
|||
}
|
||||
|
||||
def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false, false)
|
||||
|
||||
def apply(checkFunction: Row => Unit): CheckAnswerRowsByFunc =
|
||||
CheckAnswerRowsByFunc(checkFunction, false)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -154,6 +157,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
|
|||
}
|
||||
|
||||
def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true, false)
|
||||
|
||||
def apply(checkFunction: Row => Unit): CheckAnswerRowsByFunc =
|
||||
CheckAnswerRowsByFunc(checkFunction, true)
|
||||
}
|
||||
|
||||
case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean, isSorted: Boolean)
|
||||
|
@ -162,6 +168,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
|
|||
private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer"
|
||||
}
|
||||
|
||||
case class CheckAnswerRowsByFunc(checkFunction: Row => Unit, lastOnly: Boolean)
|
||||
extends StreamAction with StreamMustBeRunning {
|
||||
override def toString: String = s"$operatorName: ${checkFunction.toString()}"
|
||||
private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc"
|
||||
}
|
||||
|
||||
/** Stops the stream. It must currently be running. */
|
||||
case object StopStream extends StreamAction with StreamMustBeRunning
|
||||
|
||||
|
@ -352,6 +364,29 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
|
|||
""".stripMargin)
|
||||
}
|
||||
|
||||
def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = {
|
||||
verify(currentStream != null, "stream not running")
|
||||
// Get the map of source index to the current source objects
|
||||
val indexToSource = currentStream
|
||||
.logicalPlan
|
||||
.collect { case StreamingExecutionRelation(s, _) => s }
|
||||
.zipWithIndex
|
||||
.map(_.swap)
|
||||
.toMap
|
||||
|
||||
// Block until all data added has been processed for all the source
|
||||
awaiting.foreach { case (sourceIndex, offset) =>
|
||||
failAfter(streamingTimeout) {
|
||||
currentStream.awaitOffset(indexToSource(sourceIndex), offset)
|
||||
}
|
||||
}
|
||||
|
||||
try if (lastOnly) sink.latestBatchData else sink.allData catch {
|
||||
case e: Exception =>
|
||||
failTest("Exception while getting data from sink", e)
|
||||
}
|
||||
}
|
||||
|
||||
var manualClockExpectedTime = -1L
|
||||
val defaultCheckpointLocation =
|
||||
Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
|
||||
|
@ -552,30 +587,20 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
|
|||
e.runAction()
|
||||
|
||||
case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) =>
|
||||
verify(currentStream != null, "stream not running")
|
||||
// Get the map of source index to the current source objects
|
||||
val indexToSource = currentStream
|
||||
.logicalPlan
|
||||
.collect { case StreamingExecutionRelation(s, _) => s }
|
||||
.zipWithIndex
|
||||
.map(_.swap)
|
||||
.toMap
|
||||
|
||||
// Block until all data added has been processed for all the source
|
||||
awaiting.foreach { case (sourceIndex, offset) =>
|
||||
failAfter(streamingTimeout) {
|
||||
currentStream.awaitOffset(indexToSource(sourceIndex), offset)
|
||||
}
|
||||
}
|
||||
|
||||
val sparkAnswer = try if (lastOnly) sink.latestBatchData else sink.allData catch {
|
||||
case e: Exception =>
|
||||
failTest("Exception while getting data from sink", e)
|
||||
}
|
||||
|
||||
val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
|
||||
QueryTest.sameRows(expectedAnswer, sparkAnswer, isSorted).foreach {
|
||||
error => failTest(error)
|
||||
}
|
||||
|
||||
case CheckAnswerRowsByFunc(checkFunction, lastOnly) =>
|
||||
val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
|
||||
sparkAnswer.foreach { row =>
|
||||
try {
|
||||
checkFunction(row)
|
||||
} catch {
|
||||
case e: Throwable => failTest(e.toString)
|
||||
}
|
||||
}
|
||||
}
|
||||
pos += 1
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@ import org.apache.spark.sql.internal.{SessionState, SessionStateBuilder, SQLConf
|
|||
/**
|
||||
* A special `SparkSession` prepared for testing.
|
||||
*/
|
||||
private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self =>
|
||||
private[spark] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self =>
|
||||
def this(sparkConf: SparkConf) {
|
||||
this(new SparkContext("local[2]", "test-sql-context",
|
||||
sparkConf.set("spark.sql.testkey", "true")))
|
||||
|
|
Loading…
Reference in a new issue