diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 82066726a0..b02aea92b7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -103,7 +103,10 @@ class Pipeline @Since("1.4.0") ( /** @group setParam */ @Since("1.2.0") - def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } + def setStages(value: Array[_ <: PipelineStage]): this.type = { + set(stages, value.asInstanceOf[Array[PipelineStage]]) + this + } // Below, we clone stages so that modifications to the list of stages will not change // the Param value in the Pipeline. diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index a8c4ac6d05..1de638f245 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.SparkFunSuite import org.apache.spark.ml.Pipeline.SharedReadWrite import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler} -import org.apache.spark.ml.param.{IntParam, ParamMap} +import org.apache.spark.ml.param.{IntParam, ParamMap, ParamPair} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -201,6 +201,13 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul pipeline.fit(df) } } + + test("Pipeline.setStages should handle Java Arrays being non-covariant") { + val stages0 = Array(new UnWritableStage("b")) + val stages1 = Array(new WritableStage("a")) + val steps = stages0 ++ stages1 + val p = new Pipeline().setStages(steps) + } }