[SPARK-7388] [SPARK-7383] wrapper for VectorAssembler in Python
The wrapper required the implementation of the `ArrayParam`, because `Array[T]` is hard to obtain from Python. `ArrayParam` has an extra function called `wCast` which is an internal function to obtain `Array[T]` from `Seq[T]`
Author: Burak Yavuz <brkyvz@gmail.com>
Author: Xiangrui Meng <meng@databricks.com>
Closes #5930 from brkyvz/ml-feat and squashes the following commits:
73e745f [Burak Yavuz] Merge pull request #3 from mengxr/SPARK-7388
c221db9 [Xiangrui Meng] overload StringArrayParam.w
c81072d [Burak Yavuz] addressed comments
99c2ebf [Burak Yavuz] add to python_shared_params
39ecb07 [Burak Yavuz] fix scalastyle
7f7ea2a [Burak Yavuz] [SPARK-7388][SPARK-7383] wrapper for VectorAssembler in Python
(cherry picked from commit 9e2ffb1328
)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
This commit is contained in:
parent
84ee348bce
commit
6b9737a830
|
@ -30,7 +30,7 @@ import org.apache.spark.sql.types._
|
|||
|
||||
/**
|
||||
* :: AlphaComponent ::
|
||||
* A feature transformer than merge multiple columns into a vector column.
|
||||
* A feature transformer that merges multiple columns into a vector column.
|
||||
*/
|
||||
@AlphaComponent
|
||||
class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
|
||||
|
|
|
@ -22,6 +22,7 @@ import java.util.NoSuchElementException
|
|||
|
||||
import scala.annotation.varargs
|
||||
import scala.collection.mutable
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import org.apache.spark.annotation.AlphaComponent
|
||||
import org.apache.spark.ml.util.Identifiable
|
||||
|
@ -218,6 +219,19 @@ class BooleanParam(parent: Params, name: String, doc: String) // No need for isV
|
|||
override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
|
||||
}
|
||||
|
||||
/** Specialized version of [[Param[Array[T]]]] for Java. */
|
||||
class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array[String] => Boolean)
|
||||
extends Param[Array[String]](parent, name, doc, isValid) {
|
||||
|
||||
def this(parent: Params, name: String, doc: String) =
|
||||
this(parent, name, doc, ParamValidators.alwaysTrue)
|
||||
|
||||
override def w(value: Array[String]): ParamPair[Array[String]] = super.w(value)
|
||||
|
||||
/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
|
||||
def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray)
|
||||
}
|
||||
|
||||
/**
|
||||
* A param amd its value.
|
||||
*/
|
||||
|
@ -310,9 +324,7 @@ trait Params extends Identifiable with Serializable {
|
|||
* Sets a parameter in the embedded param map.
|
||||
*/
|
||||
protected final def set[T](param: Param[T], value: T): this.type = {
|
||||
shouldOwn(param)
|
||||
paramMap.put(param.asInstanceOf[Param[Any]], value)
|
||||
this
|
||||
set(param -> value)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -322,6 +334,15 @@ trait Params extends Identifiable with Serializable {
|
|||
set(getParam(param), value)
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets a parameter in the embedded param map.
|
||||
*/
|
||||
protected final def set(paramPair: ParamPair[_]): this.type = {
|
||||
shouldOwn(paramPair.param)
|
||||
paramMap.put(paramPair)
|
||||
this
|
||||
}
|
||||
|
||||
/**
|
||||
* Optionally returns the user-supplied value of a param.
|
||||
*/
|
||||
|
|
|
@ -85,6 +85,7 @@ private[shared] object SharedParamsCodeGen {
|
|||
case _ if c == classOf[Float] => "FloatParam"
|
||||
case _ if c == classOf[Double] => "DoubleParam"
|
||||
case _ if c == classOf[Boolean] => "BooleanParam"
|
||||
case _ if c.isArray && c.getComponentType == classOf[String] => s"StringArrayParam"
|
||||
case _ => s"Param[${getTypeString(c)}]"
|
||||
}
|
||||
}
|
||||
|
|
|
@ -178,7 +178,7 @@ private[ml] trait HasInputCols extends Params {
|
|||
* Param for input column names.
|
||||
* @group param
|
||||
*/
|
||||
final val inputCols: Param[Array[String]] = new Param[Array[String]](this, "inputCols", "input column names")
|
||||
final val inputCols: StringArrayParam = new StringArrayParam(this, "inputCols", "input column names")
|
||||
|
||||
/** @group getParam */
|
||||
final def getInputCols: Array[String] = $(inputCols)
|
||||
|
|
|
@ -16,12 +16,12 @@
|
|||
#
|
||||
|
||||
from pyspark.rdd import ignore_unicode_prefix
|
||||
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures
|
||||
from pyspark.ml.param.shared import HasInputCol, HasInputCols, HasOutputCol, HasNumFeatures
|
||||
from pyspark.ml.util import keyword_only
|
||||
from pyspark.ml.wrapper import JavaTransformer
|
||||
from pyspark.mllib.common import inherit_doc
|
||||
|
||||
__all__ = ['Tokenizer', 'HashingTF']
|
||||
__all__ = ['Tokenizer', 'HashingTF', 'VectorAssembler']
|
||||
|
||||
|
||||
@inherit_doc
|
||||
|
@ -112,6 +112,45 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
|
|||
return self._set(**kwargs)
|
||||
|
||||
|
||||
@inherit_doc
|
||||
class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol):
|
||||
"""
|
||||
A feature transformer that merges multiple columns into a vector column.
|
||||
|
||||
>>> from pyspark.sql import Row
|
||||
>>> df = sc.parallelize([Row(a=1, b=0, c=3)]).toDF()
|
||||
>>> vecAssembler = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features")
|
||||
>>> vecAssembler.transform(df).head().features
|
||||
SparseVector(3, {0: 1.0, 2: 3.0})
|
||||
>>> vecAssembler.setParams(outputCol="freqs").transform(df).head().freqs
|
||||
SparseVector(3, {0: 1.0, 2: 3.0})
|
||||
>>> params = {vecAssembler.inputCols: ["b", "a"], vecAssembler.outputCol: "vector"}
|
||||
>>> vecAssembler.transform(df, params).head().vector
|
||||
SparseVector(2, {1: 1.0})
|
||||
"""
|
||||
|
||||
_java_class = "org.apache.spark.ml.feature.VectorAssembler"
|
||||
|
||||
@keyword_only
|
||||
def __init__(self, inputCols=None, outputCol=None):
|
||||
"""
|
||||
__init__(self, inputCols=None, outputCol=None)
|
||||
"""
|
||||
super(VectorAssembler, self).__init__()
|
||||
self._setDefault()
|
||||
kwargs = self.__init__._input_kwargs
|
||||
self.setParams(**kwargs)
|
||||
|
||||
@keyword_only
|
||||
def setParams(self, inputCols=None, outputCol=None):
|
||||
"""
|
||||
setParams(self, inputCols=None, outputCol=None)
|
||||
Sets params for this VectorAssembler.
|
||||
"""
|
||||
kwargs = self.setParams._input_kwargs
|
||||
return self._set(**kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
from pyspark.context import SparkContext
|
||||
|
|
|
@ -95,6 +95,7 @@ if __name__ == "__main__":
|
|||
("predictionCol", "prediction column name", "'prediction'"),
|
||||
("rawPredictionCol", "raw prediction column name", "'rawPrediction'"),
|
||||
("inputCol", "input column name", None),
|
||||
("inputCols", "input column names", None),
|
||||
("outputCol", "output column name", None),
|
||||
("numFeatures", "number of features", None)]
|
||||
code = []
|
||||
|
|
|
@ -223,6 +223,35 @@ class HasInputCol(Params):
|
|||
return self.getOrDefault(self.inputCol)
|
||||
|
||||
|
||||
class HasInputCols(Params):
|
||||
"""
|
||||
Mixin for param inputCols: input column names.
|
||||
"""
|
||||
|
||||
# a placeholder to make it appear in the generated doc
|
||||
inputCols = Param(Params._dummy(), "inputCols", "input column names")
|
||||
|
||||
def __init__(self):
|
||||
super(HasInputCols, self).__init__()
|
||||
#: param for input column names
|
||||
self.inputCols = Param(self, "inputCols", "input column names")
|
||||
if None is not None:
|
||||
self._setDefault(inputCols=None)
|
||||
|
||||
def setInputCols(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`inputCols`.
|
||||
"""
|
||||
self.paramMap[self.inputCols] = value
|
||||
return self
|
||||
|
||||
def getInputCols(self):
|
||||
"""
|
||||
Gets the value of inputCols or its default value.
|
||||
"""
|
||||
return self.getOrDefault(self.inputCols)
|
||||
|
||||
|
||||
class HasOutputCol(Params):
|
||||
"""
|
||||
Mixin for param outputCol: output column name.
|
||||
|
|
|
@ -67,7 +67,9 @@ class JavaWrapper(Params):
|
|||
paramMap = self.extractParamMap(params)
|
||||
for param in self.params:
|
||||
if param in paramMap:
|
||||
java_obj.set(param.name, paramMap[param])
|
||||
value = paramMap[param]
|
||||
java_param = java_obj.getParam(param.name)
|
||||
java_obj.set(java_param.w(value))
|
||||
|
||||
def _empty_java_param_map(self):
|
||||
"""
|
||||
|
@ -79,7 +81,8 @@ class JavaWrapper(Params):
|
|||
paramMap = self._empty_java_param_map()
|
||||
for param, value in params.items():
|
||||
if param.parent is self:
|
||||
paramMap.put(java_obj.getParam(param.name), value)
|
||||
java_param = java_obj.getParam(param.name)
|
||||
paramMap.put(java_param.w(value))
|
||||
return paramMap
|
||||
|
||||
|
||||
|
@ -126,10 +129,8 @@ class JavaTransformer(Transformer, JavaWrapper):
|
|||
|
||||
def transform(self, dataset, params={}):
|
||||
java_obj = self._java_obj()
|
||||
self._transfer_params_to_java({}, java_obj)
|
||||
java_param_map = self._create_java_param_map(params, java_obj)
|
||||
return DataFrame(java_obj.transform(dataset._jdf, java_param_map),
|
||||
dataset.sql_ctx)
|
||||
self._transfer_params_to_java(params, java_obj)
|
||||
return DataFrame(java_obj.transform(dataset._jdf), dataset.sql_ctx)
|
||||
|
||||
|
||||
@inherit_doc
|
||||
|
|
Loading…
Reference in a new issue