[SPARK-26754][PYTHON] Add hasTrainingSummary to replace duplicate code in PySpark
## What changes were proposed in this pull request? Python version of https://github.com/apache/spark/pull/17654 ## How was this patch tested? Existing Python unit test Closes #23676 from huaxingao/spark26754. Authored-by: Huaxin Gao <huaxing@us.ibm.com> Signed-off-by: Sean Owen <sean.owen@databricks.com>
This commit is contained in:
parent
03a928cbec
commit
5bb9647e10
|
@ -483,7 +483,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
|
|||
return self.getOrDefault(self.upperBoundsOnIntercepts)
|
||||
|
||||
|
||||
class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
|
||||
class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable,
|
||||
HasTrainingSummary):
|
||||
"""
|
||||
Model fitted by LogisticRegression.
|
||||
|
||||
|
@ -532,24 +533,16 @@ class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable
|
|||
trained on the training set. An exception is thrown if `trainingSummary is None`.
|
||||
"""
|
||||
if self.hasSummary:
|
||||
java_lrt_summary = self._call_java("summary")
|
||||
if self.numClasses <= 2:
|
||||
return BinaryLogisticRegressionTrainingSummary(java_lrt_summary)
|
||||
return BinaryLogisticRegressionTrainingSummary(super(LogisticRegressionModel,
|
||||
self).summary)
|
||||
else:
|
||||
return LogisticRegressionTrainingSummary(java_lrt_summary)
|
||||
return LogisticRegressionTrainingSummary(super(LogisticRegressionModel,
|
||||
self).summary)
|
||||
else:
|
||||
raise RuntimeError("No training summary available for this %s" %
|
||||
self.__class__.__name__)
|
||||
|
||||
@property
|
||||
@since("2.0.0")
|
||||
def hasSummary(self):
|
||||
"""
|
||||
Indicates whether a training summary exists for this model
|
||||
instance.
|
||||
"""
|
||||
return self._call_java("hasSummary")
|
||||
|
||||
@since("2.0.0")
|
||||
def evaluate(self, dataset):
|
||||
"""
|
||||
|
|
|
@ -97,7 +97,7 @@ class ClusteringSummary(JavaWrapper):
|
|||
return self._call_java("numIter")
|
||||
|
||||
|
||||
class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
|
||||
class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable, HasTrainingSummary):
|
||||
"""
|
||||
Model fitted by GaussianMixture.
|
||||
|
||||
|
@ -124,15 +124,6 @@ class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
|
|||
"""
|
||||
return self._call_java("gaussiansDF")
|
||||
|
||||
@property
|
||||
@since("2.1.0")
|
||||
def hasSummary(self):
|
||||
"""
|
||||
Indicates whether a training summary exists for this model
|
||||
instance.
|
||||
"""
|
||||
return self._call_java("hasSummary")
|
||||
|
||||
@property
|
||||
@since("2.1.0")
|
||||
def summary(self):
|
||||
|
@ -141,7 +132,7 @@ class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
|
|||
training set. An exception is thrown if no summary exists.
|
||||
"""
|
||||
if self.hasSummary:
|
||||
return GaussianMixtureSummary(self._call_java("summary"))
|
||||
return GaussianMixtureSummary(super(GaussianMixtureModel, self).summary)
|
||||
else:
|
||||
raise RuntimeError("No training summary available for this %s" %
|
||||
self.__class__.__name__)
|
||||
|
@ -323,7 +314,7 @@ class KMeansSummary(ClusteringSummary):
|
|||
return self._call_java("trainingCost")
|
||||
|
||||
|
||||
class KMeansModel(JavaModel, GeneralJavaMLWritable, JavaMLReadable):
|
||||
class KMeansModel(JavaModel, GeneralJavaMLWritable, JavaMLReadable, HasTrainingSummary):
|
||||
"""
|
||||
Model fitted by KMeans.
|
||||
|
||||
|
@ -335,14 +326,6 @@ class KMeansModel(JavaModel, GeneralJavaMLWritable, JavaMLReadable):
|
|||
"""Get the cluster centers, represented as a list of NumPy arrays."""
|
||||
return [c.toArray() for c in self._call_java("clusterCenters")]
|
||||
|
||||
@property
|
||||
@since("2.1.0")
|
||||
def hasSummary(self):
|
||||
"""
|
||||
Indicates whether a training summary exists for this model instance.
|
||||
"""
|
||||
return self._call_java("hasSummary")
|
||||
|
||||
@property
|
||||
@since("2.1.0")
|
||||
def summary(self):
|
||||
|
@ -351,7 +334,7 @@ class KMeansModel(JavaModel, GeneralJavaMLWritable, JavaMLReadable):
|
|||
training set. An exception is thrown if no summary exists.
|
||||
"""
|
||||
if self.hasSummary:
|
||||
return KMeansSummary(self._call_java("summary"))
|
||||
return KMeansSummary(super(KMeansModel, self).summary)
|
||||
else:
|
||||
raise RuntimeError("No training summary available for this %s" %
|
||||
self.__class__.__name__)
|
||||
|
@ -507,7 +490,7 @@ class KMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol
|
|||
return self.getOrDefault(self.distanceMeasure)
|
||||
|
||||
|
||||
class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
|
||||
class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable, HasTrainingSummary):
|
||||
"""
|
||||
Model fitted by BisectingKMeans.
|
||||
|
||||
|
@ -534,14 +517,6 @@ class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
|
|||
"dataset in the summary.", DeprecationWarning)
|
||||
return self._call_java("computeCost", dataset)
|
||||
|
||||
@property
|
||||
@since("2.1.0")
|
||||
def hasSummary(self):
|
||||
"""
|
||||
Indicates whether a training summary exists for this model instance.
|
||||
"""
|
||||
return self._call_java("hasSummary")
|
||||
|
||||
@property
|
||||
@since("2.1.0")
|
||||
def summary(self):
|
||||
|
@ -550,7 +525,7 @@ class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
|
|||
training set. An exception is thrown if no summary exists.
|
||||
"""
|
||||
if self.hasSummary:
|
||||
return BisectingKMeansSummary(self._call_java("summary"))
|
||||
return BisectingKMeansSummary(super(BisectingKMeansModel, self).summary)
|
||||
else:
|
||||
raise RuntimeError("No training summary available for this %s" %
|
||||
self.__class__.__name__)
|
||||
|
|
|
@ -161,7 +161,8 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
|
|||
return self.getOrDefault(self.epsilon)
|
||||
|
||||
|
||||
class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritable, JavaMLReadable):
|
||||
class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritable, JavaMLReadable,
|
||||
HasTrainingSummary):
|
||||
"""
|
||||
Model fitted by :class:`LinearRegression`.
|
||||
|
||||
|
@ -201,21 +202,11 @@ class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritabl
|
|||
`trainingSummary is None`.
|
||||
"""
|
||||
if self.hasSummary:
|
||||
java_lrt_summary = self._call_java("summary")
|
||||
return LinearRegressionTrainingSummary(java_lrt_summary)
|
||||
return LinearRegressionTrainingSummary(super(LinearRegressionModel, self).summary)
|
||||
else:
|
||||
raise RuntimeError("No training summary available for this %s" %
|
||||
self.__class__.__name__)
|
||||
|
||||
@property
|
||||
@since("2.0.0")
|
||||
def hasSummary(self):
|
||||
"""
|
||||
Indicates whether a training summary exists for this model
|
||||
instance.
|
||||
"""
|
||||
return self._call_java("hasSummary")
|
||||
|
||||
@since("2.0.0")
|
||||
def evaluate(self, dataset):
|
||||
"""
|
||||
|
@ -1648,7 +1639,7 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
|
|||
|
||||
|
||||
class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable,
|
||||
JavaMLReadable):
|
||||
JavaMLReadable, HasTrainingSummary):
|
||||
"""
|
||||
.. note:: Experimental
|
||||
|
||||
|
@ -1682,21 +1673,12 @@ class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWri
|
|||
`trainingSummary is None`.
|
||||
"""
|
||||
if self.hasSummary:
|
||||
java_glrt_summary = self._call_java("summary")
|
||||
return GeneralizedLinearRegressionTrainingSummary(java_glrt_summary)
|
||||
return GeneralizedLinearRegressionTrainingSummary(
|
||||
super(GeneralizedLinearRegressionModel, self).summary)
|
||||
else:
|
||||
raise RuntimeError("No training summary available for this %s" %
|
||||
self.__class__.__name__)
|
||||
|
||||
@property
|
||||
@since("2.0.0")
|
||||
def hasSummary(self):
|
||||
"""
|
||||
Indicates whether a training summary exists for this model
|
||||
instance.
|
||||
"""
|
||||
return self._call_java("hasSummary")
|
||||
|
||||
@since("2.0.0")
|
||||
def evaluate(self, dataset):
|
||||
"""
|
||||
|
|
|
@ -611,3 +611,29 @@ class DefaultParamsReader(MLReader):
|
|||
py_type = DefaultParamsReader.__get_class(pythonClassName)
|
||||
instance = py_type.load(path)
|
||||
return instance
|
||||
|
||||
|
||||
@inherit_doc
|
||||
class HasTrainingSummary(object):
|
||||
"""
|
||||
Base class for models that provides Training summary.
|
||||
.. versionadded:: 3.0.0
|
||||
"""
|
||||
|
||||
@property
|
||||
@since("2.1.0")
|
||||
def hasSummary(self):
|
||||
"""
|
||||
Indicates whether a training summary exists for this model
|
||||
instance.
|
||||
"""
|
||||
return self._call_java("hasSummary")
|
||||
|
||||
@property
|
||||
@since("2.1.0")
|
||||
def summary(self):
|
||||
"""
|
||||
Gets summary of the model trained on the training set. An exception is thrown if
|
||||
no summary exists.
|
||||
"""
|
||||
return (self._call_java("summary"))
|
||||
|
|
Loading…
Reference in a new issue