[SPARK-14671][ML] Pipeline setStages should handle subclasses of PipelineStage

## What changes were proposed in this pull request?

Pipeline.setStages failed for some code examples which worked in 1.5 but fail in 1.6.  This tends to occur when using a mix of transformers from ml.feature. It is because Java Arrays are non-covariant and the addition of MLWritable to some transformers means the stages0/1 arrays above are not of type Array[PipelineStage].  This PR modifies the following to accept subclasses of PipelineStage:
* Pipeline.setStages()
* Params.w()

## How was this patch tested?

Unit test which fails to compile before this fix.

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #12430 from jkbradley/pipeline-setstages.
This commit is contained in:
Joseph K. Bradley 2016-04-27 16:11:12 -07:00
parent 6466d6c8a4
commit f5ebb18c45
2 changed files with 12 additions and 2 deletions

View file

@ -103,7 +103,10 @@ class Pipeline @Since("1.4.0") (
/** @group setParam */
@Since("1.2.0")
def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
def setStages(value: Array[_ <: PipelineStage]): this.type = {
set(stages, value.asInstanceOf[Array[PipelineStage]])
this
}
// Below, we clone stages so that modifications to the list of stages will not change
// the Param value in the Pipeline.

View file

@ -27,7 +27,7 @@ import org.scalatest.mock.MockitoSugar.mock
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.Pipeline.SharedReadWrite
import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler}
import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamPair}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
@ -201,6 +201,13 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
pipeline.fit(df)
}
}
test("Pipeline.setStages should handle Java Arrays being non-covariant") {
val stages0 = Array(new UnWritableStage("b"))
val stages1 = Array(new WritableStage("a"))
val steps = stages0 ++ stages1
val p = new Pipeline().setStages(steps)
}
}