[SPARK-26754][PYTHON] Add hasTrainingSummary to replace duplicate code in PySpark

## What changes were proposed in this pull request?

Python version of https://github.com/apache/spark/pull/17654

## How was this patch tested?

Existing Python unit test

Closes #23676 from huaxingao/spark26754.

Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
This commit is contained in:
Huaxin Gao 2019-02-01 17:29:58 -06:00 committed by Sean Owen
parent 03a928cbec
commit 5bb9647e10
4 changed files with 44 additions and 68 deletions

View file

@ -483,7 +483,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
return self.getOrDefault(self.upperBoundsOnIntercepts)
class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable,
HasTrainingSummary):
"""
Model fitted by LogisticRegression.
@ -532,24 +533,16 @@ class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable
trained on the training set. An exception is thrown if `trainingSummary is None`.
"""
if self.hasSummary:
java_lrt_summary = self._call_java("summary")
if self.numClasses <= 2:
return BinaryLogisticRegressionTrainingSummary(java_lrt_summary)
return BinaryLogisticRegressionTrainingSummary(super(LogisticRegressionModel,
self).summary)
else:
return LogisticRegressionTrainingSummary(java_lrt_summary)
return LogisticRegressionTrainingSummary(super(LogisticRegressionModel,
self).summary)
else:
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)
@property
@since("2.0.0")
def hasSummary(self):
"""
Indicates whether a training summary exists for this model
instance.
"""
return self._call_java("hasSummary")
@since("2.0.0")
def evaluate(self, dataset):
"""

View file

@ -97,7 +97,7 @@ class ClusteringSummary(JavaWrapper):
return self._call_java("numIter")
class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable, HasTrainingSummary):
"""
Model fitted by GaussianMixture.
@ -124,15 +124,6 @@ class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
return self._call_java("gaussiansDF")
@property
@since("2.1.0")
def hasSummary(self):
"""
Indicates whether a training summary exists for this model
instance.
"""
return self._call_java("hasSummary")
@property
@since("2.1.0")
def summary(self):
@ -141,7 +132,7 @@ class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
training set. An exception is thrown if no summary exists.
"""
if self.hasSummary:
return GaussianMixtureSummary(self._call_java("summary"))
return GaussianMixtureSummary(super(GaussianMixtureModel, self).summary)
else:
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)
@ -323,7 +314,7 @@ class KMeansSummary(ClusteringSummary):
return self._call_java("trainingCost")
class KMeansModel(JavaModel, GeneralJavaMLWritable, JavaMLReadable):
class KMeansModel(JavaModel, GeneralJavaMLWritable, JavaMLReadable, HasTrainingSummary):
"""
Model fitted by KMeans.
@ -335,14 +326,6 @@ class KMeansModel(JavaModel, GeneralJavaMLWritable, JavaMLReadable):
"""Get the cluster centers, represented as a list of NumPy arrays."""
return [c.toArray() for c in self._call_java("clusterCenters")]
@property
@since("2.1.0")
def hasSummary(self):
"""
Indicates whether a training summary exists for this model instance.
"""
return self._call_java("hasSummary")
@property
@since("2.1.0")
def summary(self):
@ -351,7 +334,7 @@ class KMeansModel(JavaModel, GeneralJavaMLWritable, JavaMLReadable):
training set. An exception is thrown if no summary exists.
"""
if self.hasSummary:
return KMeansSummary(self._call_java("summary"))
return KMeansSummary(super(KMeansModel, self).summary)
else:
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)
@ -507,7 +490,7 @@ class KMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol
return self.getOrDefault(self.distanceMeasure)
class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable, HasTrainingSummary):
"""
Model fitted by BisectingKMeans.
@ -534,14 +517,6 @@ class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
"dataset in the summary.", DeprecationWarning)
return self._call_java("computeCost", dataset)
@property
@since("2.1.0")
def hasSummary(self):
"""
Indicates whether a training summary exists for this model instance.
"""
return self._call_java("hasSummary")
@property
@since("2.1.0")
def summary(self):
@ -550,7 +525,7 @@ class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
training set. An exception is thrown if no summary exists.
"""
if self.hasSummary:
return BisectingKMeansSummary(self._call_java("summary"))
return BisectingKMeansSummary(super(BisectingKMeansModel, self).summary)
else:
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)

View file

@ -161,7 +161,8 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
return self.getOrDefault(self.epsilon)
class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritable, JavaMLReadable):
class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritable, JavaMLReadable,
HasTrainingSummary):
"""
Model fitted by :class:`LinearRegression`.
@ -201,21 +202,11 @@ class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritabl
`trainingSummary is None`.
"""
if self.hasSummary:
java_lrt_summary = self._call_java("summary")
return LinearRegressionTrainingSummary(java_lrt_summary)
return LinearRegressionTrainingSummary(super(LinearRegressionModel, self).summary)
else:
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)
@property
@since("2.0.0")
def hasSummary(self):
"""
Indicates whether a training summary exists for this model
instance.
"""
return self._call_java("hasSummary")
@since("2.0.0")
def evaluate(self, dataset):
"""
@ -1648,7 +1639,7 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable,
JavaMLReadable):
JavaMLReadable, HasTrainingSummary):
"""
.. note:: Experimental
@ -1682,21 +1673,12 @@ class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWri
`trainingSummary is None`.
"""
if self.hasSummary:
java_glrt_summary = self._call_java("summary")
return GeneralizedLinearRegressionTrainingSummary(java_glrt_summary)
return GeneralizedLinearRegressionTrainingSummary(
super(GeneralizedLinearRegressionModel, self).summary)
else:
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)
@property
@since("2.0.0")
def hasSummary(self):
"""
Indicates whether a training summary exists for this model
instance.
"""
return self._call_java("hasSummary")
@since("2.0.0")
def evaluate(self, dataset):
"""

View file

@ -611,3 +611,29 @@ class DefaultParamsReader(MLReader):
py_type = DefaultParamsReader.__get_class(pythonClassName)
instance = py_type.load(path)
return instance
@inherit_doc
class HasTrainingSummary(object):
"""
Base class for models that provides Training summary.
.. versionadded:: 3.0.0
"""
@property
@since("2.1.0")
def hasSummary(self):
"""
Indicates whether a training summary exists for this model
instance.
"""
return self._call_java("hasSummary")
@property
@since("2.1.0")
def summary(self):
"""
Gets summary of the model trained on the training set. An exception is thrown if
no summary exists.
"""
return (self._call_java("summary"))