diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 955bc9768c..77de1cc182 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -3043,26 +3043,35 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM "Force to index label whether it is numeric or string", typeConverter=TypeConverters.toBoolean) + stringIndexerOrderType = Param(Params._dummy(), "stringIndexerOrderType", + "How to order categories of a string feature column used by " + + "StringIndexer. The last category after ordering is dropped " + + "when encoding strings. Supported options: frequencyDesc, " + + "frequencyAsc, alphabetDesc, alphabetAsc. The default value " + + "is frequencyDesc. When the ordering is set to alphabetDesc, " + + "RFormula drops the same category as R when encoding strings.", + typeConverter=TypeConverters.toString) + @keyword_only def __init__(self, formula=None, featuresCol="features", labelCol="label", - forceIndexLabel=False): + forceIndexLabel=False, stringIndexerOrderType="frequencyDesc"): """ __init__(self, formula=None, featuresCol="features", labelCol="label", \ - forceIndexLabel=False) + forceIndexLabel=False, stringIndexerOrderType="frequencyDesc") """ super(RFormula, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid) - self._setDefault(forceIndexLabel=False) + self._setDefault(forceIndexLabel=False, stringIndexerOrderType="frequencyDesc") kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.5.0") def setParams(self, formula=None, featuresCol="features", labelCol="label", - forceIndexLabel=False): + forceIndexLabel=False, stringIndexerOrderType="frequencyDesc"): """ setParams(self, formula=None, featuresCol="features", labelCol="label", \ - forceIndexLabel=False) + forceIndexLabel=False, stringIndexerOrderType="frequencyDesc") Sets params for RFormula. """ kwargs = self._input_kwargs @@ -3096,6 +3105,20 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM """ return self.getOrDefault(self.forceIndexLabel) + @since("2.3.0") + def setStringIndexerOrderType(self, value): + """ + Sets the value of :py:attr:`stringIndexerOrderType`. + """ + return self._set(stringIndexerOrderType=value) + + @since("2.3.0") + def getStringIndexerOrderType(self): + """ + Gets the value of :py:attr:`stringIndexerOrderType` or its default value 'frequencyDesc'. + """ + return self.getOrDefault(self.stringIndexerOrderType) + def _create_model(self, java_model): return RFormulaModel(java_model) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 0daf29d59c..17a39472e1 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -538,6 +538,19 @@ class FeatureTests(SparkSessionTestCase): transformedDF2 = model2.transform(df) self.assertEqual(transformedDF2.head().label, 0.0) + def test_rformula_string_indexer_order_type(self): + df = self.spark.createDataFrame([ + (1.0, 1.0, "a"), + (0.0, 2.0, "b"), + (1.0, 0.0, "a")], ["y", "x", "s"]) + rf = RFormula(formula="y ~ x + s", stringIndexerOrderType="alphabetDesc") + self.assertEqual(rf.getStringIndexerOrderType(), 'alphabetDesc') + transformedDF = rf.fit(df).transform(df) + observed = transformedDF.select("features").collect() + expected = [[1.0, 0.0], [2.0, 1.0], [0.0, 0.0]] + for i in range(0, len(expected)): + self.assertTrue(all(observed[i]["features"].toArray() == expected[i])) + class HasInducedError(Params):