[SPARK-15113][PYSPARK][ML] Add missing num features num classes

## What changes were proposed in this pull request?

Add missing `numFeatures` and `numClasses` to the wrapped Java models in PySpark ML pipelines. Also tag `DecisionTreeClassificationModel` as Expiremental to match Scala doc.

## How was this patch tested?

Extended doctests

Author: Holden Karau <holden@us.ibm.com>

Closes #12889 from holdenk/SPARK-15113-add-missing-numFeatures-numClasses.
This commit is contained in:
Holden Karau 2016-08-22 12:21:22 +02:00 committed by Nick Pentreath
parent bd9655063b
commit b264cbb16f
4 changed files with 66 additions and 11 deletions

View file

@ -788,6 +788,8 @@ class GeneralizedLinearRegressionModel private[ml] (
@Since("2.0.0")
override def write: MLWriter =
new GeneralizedLinearRegressionModel.GeneralizedLinearRegressionModelWriter(this)
override val numFeatures: Int = coefficients.size
}
@Since("2.0.0")

View file

@ -43,6 +43,23 @@ __all__ = ['LogisticRegression', 'LogisticRegressionModel',
'OneVsRest', 'OneVsRestModel']
@inherit_doc
class JavaClassificationModel(JavaPredictionModel):
"""
(Private) Java Model produced by a ``Classifier``.
Classes are indexed {0, 1, ..., numClasses - 1}.
To be mixed in with class:`pyspark.ml.JavaModel`
"""
@property
@since("2.1.0")
def numClasses(self):
"""
Number of classes (values which the label can take).
"""
return self._call_java("numClasses")
@inherit_doc
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
@ -212,7 +229,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
" threshold (%g) and thresholds (equivalent to %g)" % (t2, t))
class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by LogisticRegression.
@ -522,6 +539,10 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
1
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> model.numFeatures
1
>>> model.numClasses
2
>>> print(model.toDebugString)
DecisionTreeClassificationModel (uid=...) of depth 1 with 3 nodes...
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@ -595,7 +616,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
@inherit_doc
class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable):
class DecisionTreeClassificationModel(DecisionTreeModel, JavaClassificationModel, JavaMLWritable,
JavaMLReadable):
"""
Model fitted by DecisionTreeClassifier.
@ -722,7 +744,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
return RandomForestClassificationModel(java_model)
class RandomForestClassificationModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable):
class RandomForestClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable,
JavaMLReadable):
"""
Model fitted by RandomForestClassifier.
@ -873,7 +896,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
return self.getOrDefault(self.lossType)
class GBTClassificationModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable):
class GBTClassificationModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
JavaMLReadable):
"""
Model fitted by GBTClassifier.
@ -1027,7 +1051,7 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
return self.getOrDefault(self.modelType)
class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable):
class NaiveBayesModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by NaiveBayes.
@ -1226,7 +1250,8 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
return self.getOrDefault(self.initialWeights)
class MultilayerPerceptronClassificationModel(JavaModel, JavaMLWritable, JavaMLReadable):
class MultilayerPerceptronClassificationModel(JavaModel, JavaPredictionModel, JavaMLWritable,
JavaMLReadable):
"""
.. note:: Experimental

View file

@ -88,6 +88,8 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
True
>>> model.intercept == model2.intercept
True
>>> model.numFeatures
1
.. versionadded:: 1.4.0
"""
@ -126,7 +128,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
return LinearRegressionModel(java_model)
class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
class LinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by :class:`LinearRegression`.
@ -654,6 +656,8 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
3
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> model.numFeatures
1
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
@ -719,7 +723,7 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
@inherit_doc
class DecisionTreeModel(JavaModel):
class DecisionTreeModel(JavaModel, JavaPredictionModel):
"""
Abstraction for Decision Tree models.
@ -843,6 +847,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
>>> model.numFeatures
1
>>> model.trees
[DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
>>> model.getNumTrees
@ -909,7 +915,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
return RandomForestRegressionModel(java_model)
class RandomForestRegressionModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable):
class RandomForestRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
JavaMLReadable):
"""
Model fitted by :class:`RandomForestRegressor`.
@ -958,6 +965,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
>>> model = gbt.fit(df)
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> model.numFeatures
1
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
True
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@ -1047,7 +1056,7 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
return self.getOrDefault(self.lossType)
class GBTRegressionModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable):
class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by :class:`GBTRegressor`.
@ -1307,6 +1316,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
True
>>> model.coefficients
DenseVector([1.5..., -1.0...])
>>> model.numFeatures
2
>>> abs(model.intercept - 1.5) < 0.001
True
>>> glr_path = temp_path + "/glr"
@ -1412,7 +1423,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
return self.getOrDefault(self.link)
class GeneralizedLinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable,
JavaMLReadable):
"""
.. note:: Experimental

View file

@ -238,3 +238,19 @@ class JavaMLReadable(MLReadable):
def read(cls):
"""Returns an MLReader instance for this class."""
return JavaMLReader(cls)
@inherit_doc
class JavaPredictionModel():
"""
(Private) Java Model for prediction tasks (regression and classification).
To be mixed in with class:`pyspark.ml.JavaModel`
"""
@property
@since("2.1.0")
def numFeatures(self):
"""
Returns the number of features the model was trained on. If unknown, returns -1
"""
return self._call_java("numFeatures")