[SPARK-14671][ML] Pipeline setStages should handle subclasses of PipelineStage
## What changes were proposed in this pull request? Pipeline.setStages failed for some code examples which worked in 1.5 but fail in 1.6. This tends to occur when using a mix of transformers from ml.feature. It is because Java Arrays are non-covariant and the addition of MLWritable to some transformers means the stages0/1 arrays above are not of type Array[PipelineStage]. This PR modifies the following to accept subclasses of PipelineStage: * Pipeline.setStages() * Params.w() ## How was this patch tested? Unit test which fails to compile before this fix. Author: Joseph K. Bradley <joseph@databricks.com> Closes #12430 from jkbradley/pipeline-setstages.
This commit is contained in:
parent
6466d6c8a4
commit
f5ebb18c45
|
@ -103,7 +103,10 @@ class Pipeline @Since("1.4.0") (
|
|||
|
||||
/** @group setParam */
|
||||
@Since("1.2.0")
|
||||
def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
|
||||
def setStages(value: Array[_ <: PipelineStage]): this.type = {
|
||||
set(stages, value.asInstanceOf[Array[PipelineStage]])
|
||||
this
|
||||
}
|
||||
|
||||
// Below, we clone stages so that modifications to the list of stages will not change
|
||||
// the Param value in the Pipeline.
|
||||
|
|
|
@ -27,7 +27,7 @@ import org.scalatest.mock.MockitoSugar.mock
|
|||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.Pipeline.SharedReadWrite
|
||||
import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler}
|
||||
import org.apache.spark.ml.param.{IntParam, ParamMap}
|
||||
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamPair}
|
||||
import org.apache.spark.ml.util._
|
||||
import org.apache.spark.mllib.linalg.Vectors
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
|
@ -201,6 +201,13 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
|||
pipeline.fit(df)
|
||||
}
|
||||
}
|
||||
|
||||
test("Pipeline.setStages should handle Java Arrays being non-covariant") {
|
||||
val stages0 = Array(new UnWritableStage("b"))
|
||||
val stages1 = Array(new WritableStage("a"))
|
||||
val steps = stages0 ++ stages1
|
||||
val p = new Pipeline().setStages(steps)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue