diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 9f6e7b6b6b..63475780a6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -102,10 +102,12 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod * This is a temporary fix for the case when target labels do not exist during prediction. */ @Experimental -class StringIndexerModel private[ml] ( +class StringIndexerModel ( override val uid: String, labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { + def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels) + private val labelToIndex: OpenHashMap[String, Double] = { val n = labels.length val map = new OpenHashMap[String, Double](n) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index fa918ce648..0b4c8ba71e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -30,7 +30,9 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { test("params") { ParamsSuite.checkParams(new StringIndexer) val model = new StringIndexerModel("indexer", Array("a", "b")) + val modelWithoutUid = new StringIndexerModel(Array("a", "b")) ParamsSuite.checkParams(model) + ParamsSuite.checkParams(modelWithoutUid) } test("StringIndexer") {