[SPARK-7648] [MLLIB] Add weights and intercept to GLM wrappers in spark.ml
Otherwise, users can only use `transform` on the models. brkyvz
Author: Xiangrui Meng <meng@databricks.com>
Closes #6156 from mengxr/SPARK-7647 and squashes the following commits:
1ae3d2d [Xiangrui Meng] add weights and intercept to LogisticRegression in Python
f49eb46 [Xiangrui Meng] add weights and intercept to LinearRegressionModel
(cherry picked from commit 723853edab
)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
This commit is contained in:
parent
79983f17d9
commit
f91bb57efa
|
@ -43,6 +43,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
|
||||||
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF()
|
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF()
|
||||||
>>> model.transform(test0).head().prediction
|
>>> model.transform(test0).head().prediction
|
||||||
0.0
|
0.0
|
||||||
|
>>> model.weights
|
||||||
|
DenseVector([5.5...])
|
||||||
|
>>> model.intercept
|
||||||
|
-2.68...
|
||||||
>>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF()
|
>>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF()
|
||||||
>>> model.transform(test1).head().prediction
|
>>> model.transform(test1).head().prediction
|
||||||
1.0
|
1.0
|
||||||
|
@ -148,6 +152,20 @@ class LogisticRegressionModel(JavaModel):
|
||||||
Model fitted by LogisticRegression.
|
Model fitted by LogisticRegression.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def weights(self):
|
||||||
|
"""
|
||||||
|
Model weights.
|
||||||
|
"""
|
||||||
|
return self._call_java("weights")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intercept(self):
|
||||||
|
"""
|
||||||
|
Model intercept.
|
||||||
|
"""
|
||||||
|
return self._call_java("intercept")
|
||||||
|
|
||||||
|
|
||||||
class TreeClassifierParams(object):
|
class TreeClassifierParams(object):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -51,6 +51,10 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
|
||||||
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
|
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
|
||||||
>>> model.transform(test0).head().prediction
|
>>> model.transform(test0).head().prediction
|
||||||
-1.0
|
-1.0
|
||||||
|
>>> model.weights
|
||||||
|
DenseVector([1.0])
|
||||||
|
>>> model.intercept
|
||||||
|
0.0
|
||||||
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
|
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
|
||||||
>>> model.transform(test1).head().prediction
|
>>> model.transform(test1).head().prediction
|
||||||
1.0
|
1.0
|
||||||
|
@ -117,6 +121,20 @@ class LinearRegressionModel(JavaModel):
|
||||||
Model fitted by LinearRegression.
|
Model fitted by LinearRegression.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def weights(self):
|
||||||
|
"""
|
||||||
|
Model weights.
|
||||||
|
"""
|
||||||
|
return self._call_java("weights")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intercept(self):
|
||||||
|
"""
|
||||||
|
Model intercept.
|
||||||
|
"""
|
||||||
|
return self._call_java("intercept")
|
||||||
|
|
||||||
|
|
||||||
class TreeRegressorParams(object):
|
class TreeRegressorParams(object):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -21,7 +21,7 @@ from pyspark import SparkContext
|
||||||
from pyspark.sql import DataFrame
|
from pyspark.sql import DataFrame
|
||||||
from pyspark.ml.param import Params
|
from pyspark.ml.param import Params
|
||||||
from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model
|
from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model
|
||||||
from pyspark.mllib.common import inherit_doc
|
from pyspark.mllib.common import inherit_doc, _java2py, _py2java
|
||||||
|
|
||||||
|
|
||||||
def _jvm():
|
def _jvm():
|
||||||
|
@ -149,6 +149,12 @@ class JavaModel(Model, JavaTransformer):
|
||||||
def _java_obj(self):
|
def _java_obj(self):
|
||||||
return self._java_model
|
return self._java_model
|
||||||
|
|
||||||
|
def _call_java(self, name, *args):
|
||||||
|
m = getattr(self._java_model, name)
|
||||||
|
sc = SparkContext._active_spark_context
|
||||||
|
java_args = [_py2java(sc, arg) for arg in args]
|
||||||
|
return _java2py(sc, m(*java_args))
|
||||||
|
|
||||||
|
|
||||||
@inherit_doc
|
@inherit_doc
|
||||||
class JavaEvaluator(Evaluator, JavaWrapper):
|
class JavaEvaluator(Evaluator, JavaWrapper):
|
||||||
|
|
Loading…
Reference in a new issue