[SPARK-20899][PYSPARK] PySpark supports stringIndexerOrderType in RFormula
## What changes were proposed in this pull request?

PySpark supports stringIndexerOrderType in RFormula as in #17967.

## How was this patch tested?

docstring test

Author: actuaryzhang <actuaryzhang10@gmail.com>

Closes #18122 from actuaryzhang/PythonRFormula.
This commit is contained in:
parent
35b644bd03
commit
ff5676b01f
|
@ -3043,26 +3043,35 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
|
|||
"Force to index label whether it is numeric or string",
|
||||
typeConverter=TypeConverters.toBoolean)
|
||||
|
||||
# Param controlling how categories of string feature columns are ordered
# before encoding; the doc text below is what users see for this param.
stringIndexerOrderType = Param(
    Params._dummy(),
    "stringIndexerOrderType",
    ("How to order categories of a string feature column used by "
     "StringIndexer. The last category after ordering is dropped "
     "when encoding strings. Supported options: frequencyDesc, "
     "frequencyAsc, alphabetDesc, alphabetAsc. The default value "
     "is frequencyDesc. When the ordering is set to alphabetDesc, "
     "RFormula drops the same category as R when encoding strings."),
    typeConverter=TypeConverters.toString)
|
||||
|
||||
@keyword_only
|
||||
def __init__(self, formula=None, featuresCol="features", labelCol="label",
|
||||
forceIndexLabel=False):
|
||||
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc"):
|
||||
"""
|
||||
__init__(self, formula=None, featuresCol="features", labelCol="label", \
|
||||
forceIndexLabel=False)
|
||||
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc")
|
||||
"""
|
||||
super(RFormula, self).__init__()
|
||||
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid)
|
||||
self._setDefault(forceIndexLabel=False)
|
||||
self._setDefault(forceIndexLabel=False, stringIndexerOrderType="frequencyDesc")
|
||||
kwargs = self._input_kwargs
|
||||
self.setParams(**kwargs)
|
||||
|
||||
@keyword_only
|
||||
@since("1.5.0")
|
||||
def setParams(self, formula=None, featuresCol="features", labelCol="label",
|
||||
forceIndexLabel=False):
|
||||
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc"):
|
||||
"""
|
||||
setParams(self, formula=None, featuresCol="features", labelCol="label", \
|
||||
forceIndexLabel=False)
|
||||
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc")
|
||||
Sets params for RFormula.
|
||||
"""
|
||||
kwargs = self._input_kwargs
|
||||
|
@ -3096,6 +3105,20 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
|
|||
"""
|
||||
return self.getOrDefault(self.forceIndexLabel)
|
||||
|
||||
@since("2.3.0")
def setStringIndexerOrderType(self, value):
    """
    Sets the value of :py:attr:`stringIndexerOrderType`.

    :param value: the ordering name to store for the param.
    :return: the result of ``self._set``.
    """
    updated = self._set(stringIndexerOrderType=value)
    return updated
|
||||
|
||||
@since("2.3.0")
def getStringIndexerOrderType(self):
    """
    Gets the value of :py:attr:`stringIndexerOrderType` or its default value 'frequencyDesc'.

    :return: the explicitly set value if there is one, otherwise the default.
    """
    order_param = self.stringIndexerOrderType
    return self.getOrDefault(order_param)
|
||||
|
||||
def _create_model(self, java_model):
    """
    Wraps ``java_model`` in a :py:class:`RFormulaModel`.

    :param java_model: object handed to the ``RFormulaModel`` constructor.
    :return: the newly constructed ``RFormulaModel``.
    """
    model = RFormulaModel(java_model)
    return model
|
||||
|
||||
|
|
|
@ -538,6 +538,19 @@ class FeatureTests(SparkSessionTestCase):
|
|||
transformedDF2 = model2.transform(df)
|
||||
self.assertEqual(transformedDF2.head().label, 0.0)
|
||||
|
||||
def test_rformula_string_indexer_order_type(self):
    """
    RFormula honours ``stringIndexerOrderType`` when encoding a string column.

    Builds a three-row frame with a numeric column ``x`` and a string column
    ``s``, fits ``y ~ x + s`` with ``alphabetDesc`` ordering, and checks both
    the getter and the resulting feature vectors.
    """
    df = self.spark.createDataFrame([
        (1.0, 1.0, "a"),
        (0.0, 2.0, "b"),
        (1.0, 0.0, "a")], ["y", "x", "s"])
    rf = RFormula(formula="y ~ x + s", stringIndexerOrderType="alphabetDesc")
    self.assertEqual(rf.getStringIndexerOrderType(), 'alphabetDesc')
    transformedDF = rf.fit(df).transform(df)
    observed = transformedDF.select("features").collect()
    # Each expected row is [x, encoded s]: "b" rows encode as 1.0 and "a"
    # rows as 0.0 under alphabetDesc ordering.
    expected = [[1.0, 0.0], [2.0, 1.0], [0.0, 0.0]]
    # Check the row count first: zip() stops at the shorter input and would
    # otherwise hide missing or extra rows.
    self.assertEqual(len(observed), len(expected))
    for row, want in zip(observed, expected):
        # Compare plain lists so a failure reports the differing values
        # instead of a bare "False is not true".
        self.assertEqual(row["features"].toArray().tolist(), want)
|
||||
|
||||
|
||||
class HasInducedError(Params):
|
||||
|
||||
|
|
Loading…
Reference in a new issue