[SPARK-7648] [MLLIB] Add weights and intercept to GLM wrappers in spark.ml

Otherwise, users can only use `transform` on the models. brkyvz

Author: Xiangrui Meng <meng@databricks.com>

Closes #6156 from mengxr/SPARK-7647 and squashes the following commits:

1ae3d2d [Xiangrui Meng] add weights and intercept to LogisticRegression in Python
f49eb46 [Xiangrui Meng] add weights and intercept to LinearRegressionModel

(cherry picked from commit 723853edab)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
This commit is contained in:
Xiangrui Meng 2015-05-14 18:13:58 -07:00
parent 79983f17d9
commit f91bb57efa
3 changed files with 43 additions and 1 deletions

View file

@ -43,6 +43,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF()
>>> model.transform(test0).head().prediction >>> model.transform(test0).head().prediction
0.0 0.0
>>> model.weights
DenseVector([5.5...])
>>> model.intercept
-2.68...
>>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF()
>>> model.transform(test1).head().prediction >>> model.transform(test1).head().prediction
1.0 1.0
@ -148,6 +152,20 @@ class LogisticRegressionModel(JavaModel):
Model fitted by LogisticRegression. Model fitted by LogisticRegression.
""" """
@property
def weights(self):
"""
Model weights.
"""
return self._call_java("weights")
@property
def intercept(self):
"""
Model intercept.
"""
return self._call_java("intercept")
class TreeClassifierParams(object): class TreeClassifierParams(object):
""" """

View file

@ -51,6 +51,10 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction >>> model.transform(test0).head().prediction
-1.0 -1.0
>>> model.weights
DenseVector([1.0])
>>> model.intercept
0.0
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction >>> model.transform(test1).head().prediction
1.0 1.0
@ -117,6 +121,20 @@ class LinearRegressionModel(JavaModel):
Model fitted by LinearRegression. Model fitted by LinearRegression.
""" """
@property
def weights(self):
"""
Model weights.
"""
return self._call_java("weights")
@property
def intercept(self):
"""
Model intercept.
"""
return self._call_java("intercept")
class TreeRegressorParams(object): class TreeRegressorParams(object):
""" """

View file

@ -21,7 +21,7 @@ from pyspark import SparkContext
from pyspark.sql import DataFrame from pyspark.sql import DataFrame
from pyspark.ml.param import Params from pyspark.ml.param import Params
from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model
from pyspark.mllib.common import inherit_doc from pyspark.mllib.common import inherit_doc, _java2py, _py2java
def _jvm(): def _jvm():
@ -149,6 +149,12 @@ class JavaModel(Model, JavaTransformer):
def _java_obj(self): def _java_obj(self):
return self._java_model return self._java_model
def _call_java(self, name, *args):
m = getattr(self._java_model, name)
sc = SparkContext._active_spark_context
java_args = [_py2java(sc, arg) for arg in args]
return _java2py(sc, m(*java_args))
@inherit_doc @inherit_doc
class JavaEvaluator(Evaluator, JavaWrapper): class JavaEvaluator(Evaluator, JavaWrapper):