[SPARK-22796][PYTHON][ML] Add multiple columns support to PySpark QuantileDiscretizer
### What changes were proposed in this pull request?
Add multiple columns support to PySpark QuantileDiscretizer.

### Why are the changes needed?
Multiple columns support for QuantileDiscretizer was added on the Scala side a while ago. We need to add multiple columns support to Python too.

### Does this PR introduce any user-facing change?
Yes. A new Python API is added.

### How was this patch tested?
Added doctests.

Closes #25812 from huaxingao/spark-22796.

Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: Liang-Chi Hsieh <liangchi@uber.com>
commit db9e0fda6b (parent b4b2e958ce)
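For reference, a minimal sketch of the new multi-column usage, adapted from the doctests added below; the SparkSession setup and column names are illustrative, not part of the patch:

```python
from pyspark.sql import SparkSession
from pyspark.ml.feature import QuantileDiscretizer

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(0.1, 0.0), (0.4, 1.0), (1.2, 1.3), (1.5, 1.5)], ["input1", "input2"])

# numBuckets applies to every input column; numBucketsArray (not shown here)
# gives per-column bucket counts instead.
qds = QuantileDiscretizer(numBuckets=2, relativeError=0.01, handleInvalid="error",
                          inputCols=["input1", "input2"],
                          outputCols=["output1", "output2"])
qds.fit(df).transform(df).show()
```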
@@ -1959,17 +1959,22 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLRead
 
 
 @inherit_doc
-class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
-                          JavaMLReadable, JavaMLWritable):
+class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols,
+                          HasHandleInvalid, JavaMLReadable, JavaMLWritable):
     """
-    `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
-    categorical features. The number of bins can be set using the :py:attr:`numBuckets` parameter.
-    It is possible that the number of buckets used will be less than this value, for example, if
-    there are too few distinct values of the input to create enough distinct quantiles.
+    :py:class:`QuantileDiscretizer` takes a column with continuous features and outputs a column
+    with binned categorical features. The number of bins can be set using the :py:attr:`numBuckets`
+    parameter. It is possible that the number of buckets used will be less than this value, for
+    example, if there are too few distinct values of the input to create enough distinct quantiles.
+    Since 3.0.0, :py:class:`QuantileDiscretizer` can map multiple columns at once by setting the
+    :py:attr:`inputCols` parameter. If both of the :py:attr:`inputCol` and :py:attr:`inputCols`
+    parameters are set, an Exception will be thrown. To specify the number of buckets for each
+    column, the :py:attr:`numBucketsArray` parameter can be set, or if the number of buckets
+    should be the same across columns, :py:attr:`numBuckets` can be set as a convenience.
 
     NaN handling: Note also that
-    QuantileDiscretizer will raise an error when it finds NaN values in the dataset, but the user
-    can also choose to either keep or remove NaN values within the dataset by setting
+    :py:class:`QuantileDiscretizer` will raise an error when it finds NaN values in the dataset,
+    but the user can also choose to either keep or remove NaN values within the dataset by setting
     :py:attr:`handleInvalid` parameter. If the user chooses to keep NaN values, they will be
     handled specially and placed into their own bucket, for example, if 4 buckets are used, then
     non-NaN data will be put into buckets[0-3], but NaNs will be counted in a special bucket[4].
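The docstring above notes that setting both inputCol and inputCols raises an exception. A rough caller-side sketch, reusing the spark session and df from the first sketch; the exact exception type and the point at which validation fires are assumptions here, not spelled out in this hunk:

```python
from pyspark.ml.feature import QuantileDiscretizer

both = QuantileDiscretizer(numBuckets=2,
                           inputCol="input1", outputCol="buckets",
                           inputCols=["input1", "input2"],
                           outputCols=["output1", "output2"])
try:
    both.fit(df)  # df from the first sketch
except Exception as err:  # broad catch: the concrete error comes from the JVM side
    print("inputCol and inputCols are mutually exclusive:", err)
```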
@@ -1981,29 +1986,61 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInv
     The lower and upper bin bounds will be `-Infinity` and `+Infinity`, covering all real values.
 
     >>> values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)]
-    >>> df = spark.createDataFrame(values, ["values"])
-    >>> qds = QuantileDiscretizer(numBuckets=2,
+    >>> df1 = spark.createDataFrame(values, ["values"])
+    >>> qds1 = QuantileDiscretizer(numBuckets=2,
     ...     inputCol="values", outputCol="buckets", relativeError=0.01, handleInvalid="error")
-    >>> qds.getRelativeError()
+    >>> qds1.getRelativeError()
     0.01
-    >>> bucketizer = qds.fit(df)
-    >>> qds.setHandleInvalid("keep").fit(df).transform(df).count()
+    >>> bucketizer = qds1.fit(df1)
+    >>> qds1.setHandleInvalid("keep").fit(df1).transform(df1).count()
     6
-    >>> qds.setHandleInvalid("skip").fit(df).transform(df).count()
+    >>> qds1.setHandleInvalid("skip").fit(df1).transform(df1).count()
     4
     >>> splits = bucketizer.getSplits()
     >>> splits[0]
     -inf
     >>> print("%2.1f" % round(splits[1], 1))
     0.4
-    >>> bucketed = bucketizer.transform(df).head()
+    >>> bucketed = bucketizer.transform(df1).head()
     >>> bucketed.buckets
     0.0
     >>> quantileDiscretizerPath = temp_path + "/quantile-discretizer"
-    >>> qds.save(quantileDiscretizerPath)
+    >>> qds1.save(quantileDiscretizerPath)
     >>> loadedQds = QuantileDiscretizer.load(quantileDiscretizerPath)
-    >>> loadedQds.getNumBuckets() == qds.getNumBuckets()
+    >>> loadedQds.getNumBuckets() == qds1.getNumBuckets()
     True
+    >>> inputs = [(0.1, 0.0), (0.4, 1.0), (1.2, 1.3), (1.5, 1.5),
+    ...     (float("nan"), float("nan")), (float("nan"), float("nan"))]
+    >>> df2 = spark.createDataFrame(inputs, ["input1", "input2"])
+    >>> qds2 = QuantileDiscretizer(relativeError=0.01, handleInvalid="error", numBuckets=2,
+    ...     inputCols=["input1", "input2"], outputCols=["output1", "output2"])
+    >>> qds2.getRelativeError()
+    0.01
+    >>> qds2.setHandleInvalid("keep").fit(df2).transform(df2).show()
+    +------+------+-------+-------+
+    |input1|input2|output1|output2|
+    +------+------+-------+-------+
+    |   0.1|   0.0|    0.0|    0.0|
+    |   0.4|   1.0|    1.0|    1.0|
+    |   1.2|   1.3|    1.0|    1.0|
+    |   1.5|   1.5|    1.0|    1.0|
+    |   NaN|   NaN|    2.0|    2.0|
+    |   NaN|   NaN|    2.0|    2.0|
+    +------+------+-------+-------+
+    ...
+    >>> qds3 = QuantileDiscretizer(relativeError=0.01, handleInvalid="error",
+    ...     numBucketsArray=[5, 10], inputCols=["input1", "input2"],
+    ...     outputCols=["output1", "output2"])
+    >>> qds3.setHandleInvalid("skip").fit(df2).transform(df2).show()
+    +------+------+-------+-------+
+    |input1|input2|output1|output2|
+    +------+------+-------+-------+
+    |   0.1|   0.0|    1.0|    1.0|
+    |   0.4|   1.0|    2.0|    2.0|
+    |   1.2|   1.3|    3.0|    3.0|
+    |   1.5|   1.5|    4.0|    4.0|
+    +------+------+-------+-------+
+    ...
 
     .. versionadded:: 2.0.0
     """
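The doctests above assume a `spark` session and a `temp_path` directory in the doctest globals, as set up by the module's doctest runner. A rough standalone equivalent, with names chosen here for illustration:

```python
import tempfile

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").appName("qds-doctests").getOrCreate()
temp_path = tempfile.mkdtemp()

values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)]
df1 = spark.createDataFrame(values, ["values"])
```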
@@ -2021,15 +2058,26 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInv
     handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " +
                           "Options are skip (filter out rows with invalid values), " +
                           "error (throw an error), or keep (keep invalid values in a special " +
-                          "additional bucket).",
+                          "additional bucket). Note that in the multiple columns " +
+                          "case, the invalid handling is applied to all columns. That said " +
+                          "for 'error' it will throw an error if any invalids are found in " +
+                          "any columns, for 'skip' it will skip rows with any invalids in " +
+                          "any columns, etc.",
                           typeConverter=TypeConverters.toString)
 
+    numBucketsArray = Param(Params._dummy(), "numBucketsArray", "Array of number of buckets " +
+                            "(quantiles, or categories) into which data points are grouped. " +
+                            "This is for multiple columns input. If transforming multiple " +
+                            "columns and numBucketsArray is not set, but numBuckets is set, " +
+                            "then numBuckets will be applied across all columns.",
+                            typeConverter=TypeConverters.toListInt)
+
     @keyword_only
     def __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001,
-                 handleInvalid="error"):
+                 handleInvalid="error", numBucketsArray=None, inputCols=None, outputCols=None):
         """
         __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001, \
-                 handleInvalid="error")
+                 handleInvalid="error", numBucketsArray=None, inputCols=None, outputCols=None)
         """
         super(QuantileDiscretizer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.QuantileDiscretizer",
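A short sketch of the interaction documented in the numBucketsArray Param above (column names are illustrative): with multiple input columns, numBucketsArray gives per-column bucket counts, and when it is unset the single numBuckets value is applied to every column.

```python
from pyspark.ml.feature import QuantileDiscretizer

# Per-column bucket counts: 5 buckets for input1, 10 for input2.
per_column = QuantileDiscretizer(numBucketsArray=[5, 10],
                                 inputCols=["input1", "input2"],
                                 outputCols=["output1", "output2"])

# numBucketsArray not set: numBuckets=2 is applied to both columns.
shared = QuantileDiscretizer(numBuckets=2,
                             inputCols=["input1", "input2"],
                             outputCols=["output1", "output2"])
```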
@@ -2041,10 +2089,10 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInv
     @keyword_only
     @since("2.0.0")
     def setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001,
-                  handleInvalid="error"):
+                  handleInvalid="error", numBucketsArray=None, inputCols=None, outputCols=None):
         """
         setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001, \
-                  handleInvalid="error")
+                  handleInvalid="error", numBucketsArray=None, inputCols=None, outputCols=None)
         Set the params for the QuantileDiscretizer
         """
         kwargs = self._input_kwargs
@@ -2064,6 +2112,20 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInv
         """
         return self.getOrDefault(self.numBuckets)
 
+    @since("3.0.0")
+    def setNumBucketsArray(self, value):
+        """
+        Sets the value of :py:attr:`numBucketsArray`.
+        """
+        return self._set(numBucketsArray=value)
+
+    @since("3.0.0")
+    def getNumBucketsArray(self):
+        """
+        Gets the value of numBucketsArray or its default value.
+        """
+        return self.getOrDefault(self.numBucketsArray)
+
     @since("2.0.0")
     def setRelativeError(self, value):
         """
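Usage of the new 3.0.0 accessors added above (the estimator instance is illustrative):

```python
from pyspark.ml.feature import QuantileDiscretizer

qds = QuantileDiscretizer(inputCols=["input1", "input2"],
                          outputCols=["output1", "output2"])
qds.setNumBucketsArray([5, 10])
print(qds.getNumBucketsArray())  # [5, 10]
```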
@@ -2082,10 +2144,17 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInv
         """
         Private method to convert the java_model to a Python model.
         """
-        return Bucketizer(splits=list(java_model.getSplits()),
-                          inputCol=self.getInputCol(),
-                          outputCol=self.getOutputCol(),
-                          handleInvalid=self.getHandleInvalid())
+        if (self.isSet(self.inputCol)):
+            return Bucketizer(splits=list(java_model.getSplits()),
+                              inputCol=self.getInputCol(),
+                              outputCol=self.getOutputCol(),
+                              handleInvalid=self.getHandleInvalid())
+        else:
+            splitsArrayList = [list(x) for x in list(java_model.getSplitsArray())]
+            return Bucketizer(splitsArray=splitsArrayList,
+                              inputCols=self.getInputCols(),
+                              outputCols=self.getOutputCols(),
+                              handleInvalid=self.getHandleInvalid())
 
 
 @inherit_doc
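What the branch in _create_model means for callers, sketched with the df2 frame as constructed in the doctest above: fitting with inputCols now yields a Python Bucketizer configured with splitsArray and the multi-column params, while the single-column path is unchanged.

```python
from pyspark.ml.feature import QuantileDiscretizer

model = QuantileDiscretizer(numBuckets=2,
                            inputCols=["input1", "input2"],
                            outputCols=["output1", "output2"]).fit(df2)
print(type(model).__name__)   # Bucketizer
model.transform(df2).show()   # appends the output1/output2 bucket columns
```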