[SPARK-15113][PYSPARK][ML] Add missing num features num classes
## What changes were proposed in this pull request? Add missing `numFeatures` and `numClasses` to the wrapped Java models in PySpark ML pipelines. Also tag `DecisionTreeClassificationModel` as Expiremental to match Scala doc. ## How was this patch tested? Extended doctests Author: Holden Karau <holden@us.ibm.com> Closes #12889 from holdenk/SPARK-15113-add-missing-numFeatures-numClasses.
This commit is contained in:
parent
bd9655063b
commit
b264cbb16f
|
@ -788,6 +788,8 @@ class GeneralizedLinearRegressionModel private[ml] (
|
|||
@Since("2.0.0")
|
||||
override def write: MLWriter =
|
||||
new GeneralizedLinearRegressionModel.GeneralizedLinearRegressionModelWriter(this)
|
||||
|
||||
override val numFeatures: Int = coefficients.size
|
||||
}
|
||||
|
||||
@Since("2.0.0")
|
||||
|
|
|
@ -43,6 +43,23 @@ __all__ = ['LogisticRegression', 'LogisticRegressionModel',
|
|||
'OneVsRest', 'OneVsRestModel']
|
||||
|
||||
|
||||
@inherit_doc
|
||||
class JavaClassificationModel(JavaPredictionModel):
|
||||
"""
|
||||
(Private) Java Model produced by a ``Classifier``.
|
||||
Classes are indexed {0, 1, ..., numClasses - 1}.
|
||||
To be mixed in with class:`pyspark.ml.JavaModel`
|
||||
"""
|
||||
|
||||
@property
|
||||
@since("2.1.0")
|
||||
def numClasses(self):
|
||||
"""
|
||||
Number of classes (values which the label can take).
|
||||
"""
|
||||
return self._call_java("numClasses")
|
||||
|
||||
|
||||
@inherit_doc
|
||||
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
|
||||
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
|
||||
|
@ -212,7 +229,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
|
|||
" threshold (%g) and thresholds (equivalent to %g)" % (t2, t))
|
||||
|
||||
|
||||
class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
|
||||
class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
|
||||
"""
|
||||
Model fitted by LogisticRegression.
|
||||
|
||||
|
@ -522,6 +539,10 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
|
|||
1
|
||||
>>> model.featureImportances
|
||||
SparseVector(1, {0: 1.0})
|
||||
>>> model.numFeatures
|
||||
1
|
||||
>>> model.numClasses
|
||||
2
|
||||
>>> print(model.toDebugString)
|
||||
DecisionTreeClassificationModel (uid=...) of depth 1 with 3 nodes...
|
||||
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
|
||||
|
@ -595,7 +616,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
|
|||
|
||||
|
||||
@inherit_doc
|
||||
class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable):
|
||||
class DecisionTreeClassificationModel(DecisionTreeModel, JavaClassificationModel, JavaMLWritable,
|
||||
JavaMLReadable):
|
||||
"""
|
||||
Model fitted by DecisionTreeClassifier.
|
||||
|
||||
|
@ -722,7 +744,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
|
|||
return RandomForestClassificationModel(java_model)
|
||||
|
||||
|
||||
class RandomForestClassificationModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable):
|
||||
class RandomForestClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable,
|
||||
JavaMLReadable):
|
||||
"""
|
||||
Model fitted by RandomForestClassifier.
|
||||
|
||||
|
@ -873,7 +896,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
|
|||
return self.getOrDefault(self.lossType)
|
||||
|
||||
|
||||
class GBTClassificationModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable):
|
||||
class GBTClassificationModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
|
||||
JavaMLReadable):
|
||||
"""
|
||||
Model fitted by GBTClassifier.
|
||||
|
||||
|
@ -1027,7 +1051,7 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
|
|||
return self.getOrDefault(self.modelType)
|
||||
|
||||
|
||||
class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable):
|
||||
class NaiveBayesModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
|
||||
"""
|
||||
Model fitted by NaiveBayes.
|
||||
|
||||
|
@ -1226,7 +1250,8 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
|
|||
return self.getOrDefault(self.initialWeights)
|
||||
|
||||
|
||||
class MultilayerPerceptronClassificationModel(JavaModel, JavaMLWritable, JavaMLReadable):
|
||||
class MultilayerPerceptronClassificationModel(JavaModel, JavaPredictionModel, JavaMLWritable,
|
||||
JavaMLReadable):
|
||||
"""
|
||||
.. note:: Experimental
|
||||
|
||||
|
|
|
@ -88,6 +88,8 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
|
|||
True
|
||||
>>> model.intercept == model2.intercept
|
||||
True
|
||||
>>> model.numFeatures
|
||||
1
|
||||
|
||||
.. versionadded:: 1.4.0
|
||||
"""
|
||||
|
@ -126,7 +128,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
|
|||
return LinearRegressionModel(java_model)
|
||||
|
||||
|
||||
class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
|
||||
class LinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable):
|
||||
"""
|
||||
Model fitted by :class:`LinearRegression`.
|
||||
|
||||
|
@ -654,6 +656,8 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
|
|||
3
|
||||
>>> model.featureImportances
|
||||
SparseVector(1, {0: 1.0})
|
||||
>>> model.numFeatures
|
||||
1
|
||||
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
|
||||
>>> model.transform(test0).head().prediction
|
||||
0.0
|
||||
|
@ -719,7 +723,7 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
|
|||
|
||||
|
||||
@inherit_doc
|
||||
class DecisionTreeModel(JavaModel):
|
||||
class DecisionTreeModel(JavaModel, JavaPredictionModel):
|
||||
"""
|
||||
Abstraction for Decision Tree models.
|
||||
|
||||
|
@ -843,6 +847,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
|
|||
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
|
||||
>>> model.transform(test0).head().prediction
|
||||
0.0
|
||||
>>> model.numFeatures
|
||||
1
|
||||
>>> model.trees
|
||||
[DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
|
||||
>>> model.getNumTrees
|
||||
|
@ -909,7 +915,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
|
|||
return RandomForestRegressionModel(java_model)
|
||||
|
||||
|
||||
class RandomForestRegressionModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable):
|
||||
class RandomForestRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
|
||||
JavaMLReadable):
|
||||
"""
|
||||
Model fitted by :class:`RandomForestRegressor`.
|
||||
|
||||
|
@ -958,6 +965,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
|
|||
>>> model = gbt.fit(df)
|
||||
>>> model.featureImportances
|
||||
SparseVector(1, {0: 1.0})
|
||||
>>> model.numFeatures
|
||||
1
|
||||
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
|
||||
True
|
||||
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
|
||||
|
@ -1047,7 +1056,7 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
|
|||
return self.getOrDefault(self.lossType)
|
||||
|
||||
|
||||
class GBTRegressionModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable):
|
||||
class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable):
|
||||
"""
|
||||
Model fitted by :class:`GBTRegressor`.
|
||||
|
||||
|
@ -1307,6 +1316,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
|
|||
True
|
||||
>>> model.coefficients
|
||||
DenseVector([1.5..., -1.0...])
|
||||
>>> model.numFeatures
|
||||
2
|
||||
>>> abs(model.intercept - 1.5) < 0.001
|
||||
True
|
||||
>>> glr_path = temp_path + "/glr"
|
||||
|
@ -1412,7 +1423,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
|
|||
return self.getOrDefault(self.link)
|
||||
|
||||
|
||||
class GeneralizedLinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
|
||||
class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable,
|
||||
JavaMLReadable):
|
||||
"""
|
||||
.. note:: Experimental
|
||||
|
||||
|
|
|
@ -238,3 +238,19 @@ class JavaMLReadable(MLReadable):
|
|||
def read(cls):
|
||||
"""Returns an MLReader instance for this class."""
|
||||
return JavaMLReader(cls)
|
||||
|
||||
|
||||
@inherit_doc
|
||||
class JavaPredictionModel():
|
||||
"""
|
||||
(Private) Java Model for prediction tasks (regression and classification).
|
||||
To be mixed in with class:`pyspark.ml.JavaModel`
|
||||
"""
|
||||
|
||||
@property
|
||||
@since("2.1.0")
|
||||
def numFeatures(self):
|
||||
"""
|
||||
Returns the number of features the model was trained on. If unknown, returns -1
|
||||
"""
|
||||
return self._call_java("numFeatures")
|
||||
|
|
Loading…
Reference in a new issue