From 5bb9647e1019ea7eb17af7d2057fdacb7f4c560b Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 1 Feb 2019 17:29:58 -0600 Subject: [PATCH] [SPARK-26754][PYTHON] Add hasTrainingSummary to replace duplicate code in PySpark ## What changes were proposed in this pull request? Python version of https://github.com/apache/spark/pull/17654 ## How was this patch tested? Existing Python unit test Closes #23676 from huaxingao/spark26754. Authored-by: Huaxin Gao Signed-off-by: Sean Owen --- python/pyspark/ml/classification.py | 19 +++++---------- python/pyspark/ml/clustering.py | 37 +++++------------------------ python/pyspark/ml/regression.py | 30 +++++------------------ python/pyspark/ml/util.py | 26 ++++++++++++++++++++ 4 files changed, 44 insertions(+), 68 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 89b927814c..134b9e0055 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -483,7 +483,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti return self.getOrDefault(self.upperBoundsOnIntercepts) -class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable): +class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable, + HasTrainingSummary): """ Model fitted by LogisticRegression. @@ -532,24 +533,16 @@ class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable trained on the training set. An exception is thrown if `trainingSummary is None`. """ if self.hasSummary: - java_lrt_summary = self._call_java("summary") if self.numClasses <= 2: - return BinaryLogisticRegressionTrainingSummary(java_lrt_summary) + return BinaryLogisticRegressionTrainingSummary(super(LogisticRegressionModel, + self).summary) else: - return LogisticRegressionTrainingSummary(java_lrt_summary) + return LogisticRegressionTrainingSummary(super(LogisticRegressionModel, + self).summary) else: raise RuntimeError("No training summary available for this %s" % self.__class__.__name__) - @property - @since("2.0.0") - def hasSummary(self): - """ - Indicates whether a training summary exists for this model - instance. - """ - return self._call_java("hasSummary") - @since("2.0.0") def evaluate(self, dataset): """ diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index b9c6bdf521..864e2a3e09 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -97,7 +97,7 @@ class ClusteringSummary(JavaWrapper): return self._call_java("numIter") -class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable): +class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable, HasTrainingSummary): """ Model fitted by GaussianMixture. @@ -124,15 +124,6 @@ class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable): """ return self._call_java("gaussiansDF") - @property - @since("2.1.0") - def hasSummary(self): - """ - Indicates whether a training summary exists for this model - instance. - """ - return self._call_java("hasSummary") - @property @since("2.1.0") def summary(self): @@ -141,7 +132,7 @@ class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable): training set. An exception is thrown if no summary exists. """ if self.hasSummary: - return GaussianMixtureSummary(self._call_java("summary")) + return GaussianMixtureSummary(super(GaussianMixtureModel, self).summary) else: raise RuntimeError("No training summary available for this %s" % self.__class__.__name__) @@ -323,7 +314,7 @@ class KMeansSummary(ClusteringSummary): return self._call_java("trainingCost") -class KMeansModel(JavaModel, GeneralJavaMLWritable, JavaMLReadable): +class KMeansModel(JavaModel, GeneralJavaMLWritable, JavaMLReadable, HasTrainingSummary): """ Model fitted by KMeans. @@ -335,14 +326,6 @@ class KMeansModel(JavaModel, GeneralJavaMLWritable, JavaMLReadable): """Get the cluster centers, represented as a list of NumPy arrays.""" return [c.toArray() for c in self._call_java("clusterCenters")] - @property - @since("2.1.0") - def hasSummary(self): - """ - Indicates whether a training summary exists for this model instance. - """ - return self._call_java("hasSummary") - @property @since("2.1.0") def summary(self): @@ -351,7 +334,7 @@ class KMeansModel(JavaModel, GeneralJavaMLWritable, JavaMLReadable): training set. An exception is thrown if no summary exists. """ if self.hasSummary: - return KMeansSummary(self._call_java("summary")) + return KMeansSummary(super(KMeansModel, self).summary) else: raise RuntimeError("No training summary available for this %s" % self.__class__.__name__) @@ -507,7 +490,7 @@ class KMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol return self.getOrDefault(self.distanceMeasure) -class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): +class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable, HasTrainingSummary): """ Model fitted by BisectingKMeans. @@ -534,14 +517,6 @@ class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): "dataset in the summary.", DeprecationWarning) return self._call_java("computeCost", dataset) - @property - @since("2.1.0") - def hasSummary(self): - """ - Indicates whether a training summary exists for this model instance. - """ - return self._call_java("hasSummary") - @property @since("2.1.0") def summary(self): @@ -550,7 +525,7 @@ class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): training set. An exception is thrown if no summary exists. """ if self.hasSummary: - return BisectingKMeansSummary(self._call_java("summary")) + return BisectingKMeansSummary(super(BisectingKMeansModel, self).summary) else: raise RuntimeError("No training summary available for this %s" % self.__class__.__name__) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 9e1f8f88ca..7841de9c3d 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -161,7 +161,8 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction return self.getOrDefault(self.epsilon) -class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritable, JavaMLReadable): +class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritable, JavaMLReadable, + HasTrainingSummary): """ Model fitted by :class:`LinearRegression`. @@ -201,21 +202,11 @@ class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritabl `trainingSummary is None`. """ if self.hasSummary: - java_lrt_summary = self._call_java("summary") - return LinearRegressionTrainingSummary(java_lrt_summary) + return LinearRegressionTrainingSummary(super(LinearRegressionModel, self).summary) else: raise RuntimeError("No training summary available for this %s" % self.__class__.__name__) - @property - @since("2.0.0") - def hasSummary(self): - """ - Indicates whether a training summary exists for this model - instance. - """ - return self._call_java("hasSummary") - @since("2.0.0") def evaluate(self, dataset): """ @@ -1648,7 +1639,7 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, - JavaMLReadable): + JavaMLReadable, HasTrainingSummary): """ .. note:: Experimental @@ -1682,21 +1673,12 @@ class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWri `trainingSummary is None`. """ if self.hasSummary: - java_glrt_summary = self._call_java("summary") - return GeneralizedLinearRegressionTrainingSummary(java_glrt_summary) + return GeneralizedLinearRegressionTrainingSummary( + super(GeneralizedLinearRegressionModel, self).summary) else: raise RuntimeError("No training summary available for this %s" % self.__class__.__name__) - @property - @since("2.0.0") - def hasSummary(self): - """ - Indicates whether a training summary exists for this model - instance. - """ - return self._call_java("hasSummary") - @since("2.0.0") def evaluate(self, dataset): """ diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index e846834761..e184e1ac77 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -611,3 +611,29 @@ class DefaultParamsReader(MLReader): py_type = DefaultParamsReader.__get_class(pythonClassName) instance = py_type.load(path) return instance + + +@inherit_doc +class HasTrainingSummary(object): + """ + Base class for models that provides Training summary. + .. versionadded:: 3.0.0 + """ + + @property + @since("2.1.0") + def hasSummary(self): + """ + Indicates whether a training summary exists for this model + instance. + """ + return self._call_java("hasSummary") + + @property + @since("2.1.0") + def summary(self): + """ + Gets summary of the model trained on the training set. An exception is thrown if + no summary exists. + """ + return (self._call_java("summary"))