[SPARK-22734][ML][PYSPARK] Added Python API for VectorSizeHint.
(Please fill in changes proposed in this fix) Python API for VectorSizeHint Transformer. (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) doc-tests. Author: Bago Amirbekian <bago@databricks.com> Closes #20112 from MrBago/vectorSizeHint-PythonAPI.
This commit is contained in:
parent
30fcdc0380
commit
816963043a
|
@ -35,6 +35,7 @@ import org.apache.spark.sql.types.StructType
|
|||
* VectorAssembler needs size information for its input columns and cannot be used on streaming
|
||||
* dataframes without this metadata.
|
||||
*
|
||||
* Note: VectorSizeHint modifies `inputCol` to include size metadata and does not have an outputCol.
|
||||
*/
|
||||
@Experimental
|
||||
@Since("2.3.0")
|
||||
|
|
|
@ -57,6 +57,7 @@ __all__ = ['Binarizer',
|
|||
'Tokenizer',
|
||||
'VectorAssembler',
|
||||
'VectorIndexer', 'VectorIndexerModel',
|
||||
'VectorSizeHint',
|
||||
'VectorSlicer',
|
||||
'Word2Vec', 'Word2VecModel']
|
||||
|
||||
|
@ -3466,6 +3467,84 @@ class ChiSqSelectorModel(JavaModel, JavaMLReadable, JavaMLWritable):
|
|||
return self._call_java("selectedFeatures")
|
||||
|
||||
|
||||
@inherit_doc
|
||||
class VectorSizeHint(JavaTransformer, HasInputCol, HasHandleInvalid, JavaMLReadable,
|
||||
JavaMLWritable):
|
||||
"""
|
||||
.. note:: Experimental
|
||||
|
||||
A feature transformer that adds size information to the metadata of a vector column.
|
||||
VectorAssembler needs size information for its input columns and cannot be used on streaming
|
||||
dataframes without this metadata.
|
||||
|
||||
.. note:: VectorSizeHint modifies `inputCol` to include size metadata and does not have an
|
||||
outputCol.
|
||||
|
||||
>>> from pyspark.ml.linalg import Vectors
|
||||
>>> from pyspark.ml import Pipeline, PipelineModel
|
||||
>>> data = [(Vectors.dense([1., 2., 3.]), 4.)]
|
||||
>>> df = spark.createDataFrame(data, ["vector", "float"])
|
||||
>>>
|
||||
>>> sizeHint = VectorSizeHint(inputCol="vector", size=3, handleInvalid="skip")
|
||||
>>> vecAssembler = VectorAssembler(inputCols=["vector", "float"], outputCol="assembled")
|
||||
>>> pipeline = Pipeline(stages=[sizeHint, vecAssembler])
|
||||
>>>
|
||||
>>> pipelineModel = pipeline.fit(df)
|
||||
>>> pipelineModel.transform(df).head().assembled
|
||||
DenseVector([1.0, 2.0, 3.0, 4.0])
|
||||
>>> vectorSizeHintPath = temp_path + "/vector-size-hint-pipeline"
|
||||
>>> pipelineModel.save(vectorSizeHintPath)
|
||||
>>> loadedPipeline = PipelineModel.load(vectorSizeHintPath)
|
||||
>>> loaded = loadedPipeline.transform(df).head().assembled
|
||||
>>> expected = pipelineModel.transform(df).head().assembled
|
||||
>>> loaded == expected
|
||||
True
|
||||
|
||||
.. versionadded:: 2.3.0
|
||||
"""
|
||||
|
||||
size = Param(Params._dummy(), "size", "Size of vectors in column.",
|
||||
typeConverter=TypeConverters.toInt)
|
||||
|
||||
handleInvalid = Param(Params._dummy(), "handleInvalid",
|
||||
"How to handle invalid vectors in inputCol. Invalid vectors include "
|
||||
"nulls and vectors with the wrong size. The options are `skip` (filter "
|
||||
"out rows with invalid vectors), `error` (throw an error) and "
|
||||
"`optimistic` (do not check the vector size, and keep all rows). "
|
||||
"`error` by default.",
|
||||
TypeConverters.toString)
|
||||
|
||||
@keyword_only
|
||||
def __init__(self, inputCol=None, size=None, handleInvalid="error"):
|
||||
"""
|
||||
__init__(self, inputCol=None, size=None, handleInvalid="error")
|
||||
"""
|
||||
super(VectorSizeHint, self).__init__()
|
||||
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorSizeHint", self.uid)
|
||||
self._setDefault(handleInvalid="error")
|
||||
self.setParams(**self._input_kwargs)
|
||||
|
||||
@keyword_only
|
||||
@since("2.3.0")
|
||||
def setParams(self, inputCol=None, size=None, handleInvalid="error"):
|
||||
"""
|
||||
setParams(self, inputCol=None, size=None, handleInvalid="error")
|
||||
Sets params for this VectorSizeHint.
|
||||
"""
|
||||
kwargs = self._input_kwargs
|
||||
return self._set(**kwargs)
|
||||
|
||||
@since("2.3.0")
|
||||
def getSize(self):
|
||||
""" Gets size param, the size of vectors in `inputCol`."""
|
||||
self.getOrDefault(self.size)
|
||||
|
||||
@since("2.3.0")
|
||||
def setSize(self, value):
|
||||
""" Sets size param, the size of vectors in `inputCol`."""
|
||||
self._set(size=value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
import tempfile
|
||||
|
|
Loading…
Reference in a new issue