2015-05-05 14:45:37 -04:00
|
|
|
#
|
|
|
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
|
|
|
# contributor license agreements. See the NOTICE file distributed with
|
|
|
|
# this work for additional information regarding copyright ownership.
|
|
|
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
|
|
|
# (the "License"); you may not use this file except in compliance with
|
|
|
|
# the License. You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
#
|
|
|
|
|
2018-03-08 06:38:34 -05:00
|
|
|
import sys
|
2015-05-22 01:57:33 -04:00
|
|
|
from abc import abstractmethod, ABCMeta
|
|
|
|
|
2016-04-20 13:32:01 -04:00
|
|
|
from pyspark import since, keyword_only
|
2016-04-13 17:08:57 -04:00
|
|
|
from pyspark.ml.wrapper import JavaParams
|
2016-04-15 15:14:41 -04:00
|
|
|
from pyspark.ml.param import Param, Params, TypeConverters
|
2017-09-22 01:12:33 -04:00
|
|
|
from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol, \
|
2019-02-08 12:46:54 -05:00
|
|
|
HasFeaturesCol, HasWeightCol
|
2016-06-13 22:59:53 -04:00
|
|
|
from pyspark.ml.common import inherit_doc
|
2016-10-14 07:17:03 -04:00
|
|
|
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
|
2015-05-05 14:45:37 -04:00
|
|
|
|
2015-08-12 16:24:18 -04:00
|
|
|
# Public API of this module: the abstract base plus one concrete evaluator
# per supervised/unsupervised task type wrapped from Scala.
__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator',
           'MulticlassClassificationEvaluator', 'MultilabelClassificationEvaluator',
           'ClusteringEvaluator', 'RankingEvaluator']
|
2015-05-22 01:57:33 -04:00
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
class Evaluator(Params):
    """
    Base class for evaluators that compute metrics from predictions.

    .. versionadded:: 1.4.0
    """

    # Python 2-style ABC marker, consistent with the rest of this file.
    __metaclass__ = ABCMeta

    @abstractmethod
    def _evaluate(self, dataset):
        """
        Evaluates the output.

        :param dataset: a dataset that contains labels/observations and
                        predictions
        :return: metric
        """
        raise NotImplementedError()

    @since("1.4.0")
    def evaluate(self, dataset, params=None):
        """
        Evaluates the output with optional parameters.

        :param dataset: a dataset that contains labels/observations and
                        predictions
        :param params: an optional param map that overrides embedded
                       params
        :return: metric
        """
        if params is None:
            params = dict()
        # Guard clause: anything other than a param-map dict is a caller error.
        if not isinstance(params, dict):
            raise ValueError("Params must be a param map but got %s." % type(params))
        # A non-empty override map is applied to a copy so this instance's
        # embedded params stay untouched; an empty map evaluates in place.
        target = self.copy(params) if params else self
        return target._evaluate(dataset)

    @since("1.5.0")
    def isLargerBetter(self):
        """
        Indicates whether the metric returned by :py:meth:`evaluate` should be maximized
        (True, default) or minimized (False).
        A given evaluator may support multiple metrics which may be maximized or minimized.
        """
        return True
|
|
|
|
|
2015-05-22 01:57:33 -04:00
|
|
|
|
|
|
|
@inherit_doc
class JavaEvaluator(JavaParams, Evaluator):
    """
    Base class for :py:class:`Evaluator`s that wrap Java/Scala
    implementations.
    """

    __metaclass__ = ABCMeta

    def _evaluate(self, dataset):
        """
        Evaluates the output.
        :param dataset: a dataset that contains labels/observations and predictions.
        :return: evaluation metric
        """
        # Mirror the Python-side params onto the wrapped Java object before
        # delegating, so any params set since construction take effect.
        self._transfer_params_to_java()
        # Delegate to the JVM evaluator on the underlying Java DataFrame.
        return self._java_obj.evaluate(dataset._jdf)

    def isLargerBetter(self):
        # Sync params first: the metric direction can depend on metricName.
        self._transfer_params_to_java()
        return self._java_obj.isLargerBetter()
|
|
|
|
|
2015-05-05 14:45:37 -04:00
|
|
|
|
|
|
|
@inherit_doc
class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol, HasWeightCol,
                                    JavaMLReadable, JavaMLWritable):
    """
    Evaluator for binary classification, which expects two input columns: rawPrediction and label.
    The rawPrediction column can be of type double (binary 0/1 prediction, or probability of label
    1) or of type vector (length-2 vector of raw predictions, scores, or label probabilities).

    >>> from pyspark.ml.linalg import Vectors
    >>> scoreAndLabels = map(lambda x: (Vectors.dense([1.0 - x[0], x[0]]), x[1]),
    ...    [(0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)])
    >>> dataset = spark.createDataFrame(scoreAndLabels, ["raw", "label"])
    ...
    >>> evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw")
    >>> evaluator.evaluate(dataset)
    0.70...
    >>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"})
    0.83...
    >>> bce_path = temp_path + "/bce"
    >>> evaluator.save(bce_path)
    >>> evaluator2 = BinaryClassificationEvaluator.load(bce_path)
    >>> str(evaluator2.getRawPredictionCol())
    'raw'
    >>> scoreAndLabelsAndWeight = map(lambda x: (Vectors.dense([1.0 - x[0], x[0]]), x[1], x[2]),
    ...    [(0.1, 0.0, 1.0), (0.1, 1.0, 0.9), (0.4, 0.0, 0.7), (0.6, 0.0, 0.9),
    ...    (0.6, 1.0, 1.0), (0.6, 1.0, 0.3), (0.8, 1.0, 1.0)])
    >>> dataset = spark.createDataFrame(scoreAndLabelsAndWeight, ["raw", "label", "weight"])
    ...
    >>> evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw", weightCol="weight")
    >>> evaluator.evaluate(dataset)
    0.70...
    >>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"})
    0.82...

    .. versionadded:: 1.4.0
    """

    # Spark ML Param descriptor; the actual value lives in the param map and
    # defaults to "areaUnderROC" (set in __init__ via _setDefault).
    metricName = Param(Params._dummy(), "metricName",
                       "metric name in evaluation (areaUnderROC|areaUnderPR)",
                       typeConverter=TypeConverters.toString)

    @keyword_only
    def __init__(self, rawPredictionCol="rawPrediction", labelCol="label",
                 metricName="areaUnderROC", weightCol=None):
        """
        __init__(self, rawPredictionCol="rawPrediction", labelCol="label", \
                 metricName="areaUnderROC", weightCol=None)
        """
        super(BinaryClassificationEvaluator, self).__init__()
        # Instantiate the JVM-side evaluator, sharing this instance's uid.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator", self.uid)
        self._setDefault(metricName="areaUnderROC")
        # @keyword_only stashed the constructor kwargs in self._input_kwargs.
        kwargs = self._input_kwargs
        self._set(**kwargs)

    @since("1.4.0")
    def setMetricName(self, value):
        """
        Sets the value of :py:attr:`metricName`.
        """
        return self._set(metricName=value)

    @since("1.4.0")
    def getMetricName(self):
        """
        Gets the value of metricName or its default value.
        """
        return self.getOrDefault(self.metricName)

    @keyword_only
    @since("1.4.0")
    def setParams(self, rawPredictionCol="rawPrediction", labelCol="label",
                  metricName="areaUnderROC", weightCol=None):
        """
        setParams(self, rawPredictionCol="rawPrediction", labelCol="label", \
                  metricName="areaUnderROC", weightCol=None)
        Sets params for binary classification evaluator.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)
|
|
|
|
|
|
|
|
|
2015-05-24 13:36:02 -04:00
|
|
|
@inherit_doc
class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeightCol,
                          JavaMLReadable, JavaMLWritable):
    """
    Evaluator for Regression, which expects input columns prediction, label
    and an optional weight column.

    >>> scoreAndLabels = [(-28.98343821, -27.0), (20.21491975, 21.5),
    ...   (-25.98418959, -22.0), (30.69731842, 33.0), (74.69283752, 71.0)]
    >>> dataset = spark.createDataFrame(scoreAndLabels, ["raw", "label"])
    ...
    >>> evaluator = RegressionEvaluator(predictionCol="raw")
    >>> evaluator.evaluate(dataset)
    2.842...
    >>> evaluator.evaluate(dataset, {evaluator.metricName: "r2"})
    0.993...
    >>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"})
    2.649...
    >>> re_path = temp_path + "/re"
    >>> evaluator.save(re_path)
    >>> evaluator2 = RegressionEvaluator.load(re_path)
    >>> str(evaluator2.getPredictionCol())
    'raw'
    >>> scoreAndLabelsAndWeight = [(-28.98343821, -27.0, 1.0), (20.21491975, 21.5, 0.8),
    ...   (-25.98418959, -22.0, 1.0), (30.69731842, 33.0, 0.6), (74.69283752, 71.0, 0.2)]
    >>> dataset = spark.createDataFrame(scoreAndLabelsAndWeight, ["raw", "label", "weight"])
    ...
    >>> evaluator = RegressionEvaluator(predictionCol="raw", weightCol="weight")
    >>> evaluator.evaluate(dataset)
    2.740...

    .. versionadded:: 1.4.0
    """
    # Param descriptor; supported values are rmse (default), mse, r2, mae.
    metricName = Param(Params._dummy(), "metricName",
                       """metric name in evaluation - one of:
                       rmse - root mean squared error (default)
                       mse - mean squared error
                       r2 - r^2 metric
                       mae - mean absolute error.""",
                       typeConverter=TypeConverters.toString)

    @keyword_only
    def __init__(self, predictionCol="prediction", labelCol="label",
                 metricName="rmse", weightCol=None):
        """
        __init__(self, predictionCol="prediction", labelCol="label", \
                 metricName="rmse", weightCol=None)
        """
        super(RegressionEvaluator, self).__init__()
        # Instantiate the JVM-side evaluator, sharing this instance's uid.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid)
        self._setDefault(metricName="rmse")
        # @keyword_only stashed the constructor kwargs in self._input_kwargs.
        kwargs = self._input_kwargs
        self._set(**kwargs)

    @since("1.4.0")
    def setMetricName(self, value):
        """
        Sets the value of :py:attr:`metricName`.
        """
        return self._set(metricName=value)

    @since("1.4.0")
    def getMetricName(self):
        """
        Gets the value of metricName or its default value.
        """
        return self.getOrDefault(self.metricName)

    @keyword_only
    @since("1.4.0")
    def setParams(self, predictionCol="prediction", labelCol="label",
                  metricName="rmse", weightCol=None):
        """
        setParams(self, predictionCol="prediction", labelCol="label", \
                  metricName="rmse", weightCol=None)
        Sets params for regression evaluator.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)
|
|
|
|
|
2015-07-31 02:02:11 -04:00
|
|
|
|
|
|
|
@inherit_doc
class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeightCol,
                                        JavaMLReadable, JavaMLWritable):
    """
    Evaluator for Multiclass Classification, which expects two input
    columns: prediction and label.

    >>> scoreAndLabels = [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
    ...     (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]
    >>> dataset = spark.createDataFrame(scoreAndLabels, ["prediction", "label"])
    ...
    >>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction")
    >>> evaluator.evaluate(dataset)
    0.66...
    >>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
    0.66...
    >>> evaluator.evaluate(dataset, {evaluator.metricName: "truePositiveRateByLabel",
    ...     evaluator.metricLabel: 1.0})
    0.75...
    >>> mce_path = temp_path + "/mce"
    >>> evaluator.save(mce_path)
    >>> evaluator2 = MulticlassClassificationEvaluator.load(mce_path)
    >>> str(evaluator2.getPredictionCol())
    'prediction'
    >>> scoreAndLabelsAndWeight = [(0.0, 0.0, 1.0), (0.0, 1.0, 1.0), (0.0, 0.0, 1.0),
    ...     (1.0, 0.0, 1.0), (1.0, 1.0, 1.0), (1.0, 1.0, 1.0), (1.0, 1.0, 1.0),
    ...     (2.0, 2.0, 1.0), (2.0, 0.0, 1.0)]
    >>> dataset = spark.createDataFrame(scoreAndLabelsAndWeight, ["prediction", "label", "weight"])
    ...
    >>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction",
    ...     weightCol="weight")
    >>> evaluator.evaluate(dataset)
    0.66...
    >>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
    0.66...

    .. versionadded:: 1.5.0
    """
    # Param descriptor for the summary metric; default is "f1".
    metricName = Param(Params._dummy(), "metricName",
                       "metric name in evaluation "
                       "(f1|accuracy|weightedPrecision|weightedRecall|weightedTruePositiveRate|"
                       "weightedFalsePositiveRate|weightedFMeasure|truePositiveRateByLabel|"
                       "falsePositiveRateByLabel|precisionByLabel|recallByLabel|fMeasureByLabel)",
                       typeConverter=TypeConverters.toFloat if False else TypeConverters.toString)
    metricLabel = Param(Params._dummy(), "metricLabel",
                        "The class whose metric will be computed in truePositiveRateByLabel|"
                        "falsePositiveRateByLabel|precisionByLabel|recallByLabel|fMeasureByLabel."
                        " Must be >= 0. The default value is 0.",
                        typeConverter=TypeConverters.toFloat)
    beta = Param(Params._dummy(), "beta",
                 "The beta value used in weightedFMeasure|fMeasureByLabel."
                 " Must be > 0. The default value is 1.",
                 typeConverter=TypeConverters.toFloat)

    @keyword_only
    def __init__(self, predictionCol="prediction", labelCol="label",
                 metricName="f1", weightCol=None, metricLabel=0.0, beta=1.0):
        """
        __init__(self, predictionCol="prediction", labelCol="label", \
                 metricName="f1", weightCol=None, metricLabel=0.0, beta=1.0)
        """
        super(MulticlassClassificationEvaluator, self).__init__()
        # Instantiate the JVM-side evaluator, sharing this instance's uid.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid)
        self._setDefault(metricName="f1", metricLabel=0.0, beta=1.0)
        # @keyword_only stashed the constructor kwargs in self._input_kwargs.
        kwargs = self._input_kwargs
        self._set(**kwargs)

    @since("1.5.0")
    def setMetricName(self, value):
        """
        Sets the value of :py:attr:`metricName`.
        """
        return self._set(metricName=value)

    @since("1.5.0")
    def getMetricName(self):
        """
        Gets the value of metricName or its default value.
        """
        return self.getOrDefault(self.metricName)

    @since("3.0.0")
    def setMetricLabel(self, value):
        """
        Sets the value of :py:attr:`metricLabel`.
        """
        return self._set(metricLabel=value)

    @since("3.0.0")
    def getMetricLabel(self):
        """
        Gets the value of metricLabel or its default value.
        """
        return self.getOrDefault(self.metricLabel)

    @since("3.0.0")
    def setBeta(self, value):
        """
        Sets the value of :py:attr:`beta`.
        """
        return self._set(beta=value)

    @since("3.0.0")
    def getBeta(self):
        """
        Gets the value of beta or its default value.
        """
        return self.getOrDefault(self.beta)

    @keyword_only
    @since("1.5.0")
    def setParams(self, predictionCol="prediction", labelCol="label",
                  metricName="f1", weightCol=None, metricLabel=0.0, beta=1.0):
        """
        setParams(self, predictionCol="prediction", labelCol="label", \
                  metricName="f1", weightCol=None, metricLabel=0.0, beta=1.0)
        Sets params for multiclass classification evaluator.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)
|
|
|
|
|
2017-09-22 01:12:33 -04:00
|
|
|
|
2019-06-13 08:58:22 -04:00
|
|
|
@inherit_doc
class MultilabelClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
                                        JavaMLReadable, JavaMLWritable):
    """
    .. note:: Experimental

    Evaluator for Multilabel Classification, which expects two input
    columns: prediction and label.

    >>> scoreAndLabels = [([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]),
    ...     ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]),
    ...     ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])]
    >>> dataset = spark.createDataFrame(scoreAndLabels, ["prediction", "label"])
    ...
    >>> evaluator = MultilabelClassificationEvaluator(predictionCol="prediction")
    >>> evaluator.evaluate(dataset)
    0.63...
    >>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
    0.54...
    >>> mlce_path = temp_path + "/mlce"
    >>> evaluator.save(mlce_path)
    >>> evaluator2 = MultilabelClassificationEvaluator.load(mlce_path)
    >>> str(evaluator2.getPredictionCol())
    'prediction'

    .. versionadded:: 3.0.0
    """
    # Param descriptor for the summary metric; default is "f1Measure".
    metricName = Param(Params._dummy(), "metricName",
                       "metric name in evaluation "
                       "(subsetAccuracy|accuracy|hammingLoss|precision|recall|f1Measure|"
                       "precisionByLabel|recallByLabel|f1MeasureByLabel|microPrecision|"
                       "microRecall|microF1Measure)",
                       typeConverter=TypeConverters.toString)
    metricLabel = Param(Params._dummy(), "metricLabel",
                        "The class whose metric will be computed in precisionByLabel|"
                        "recallByLabel|f1MeasureByLabel. "
                        "Must be >= 0. The default value is 0.",
                        typeConverter=TypeConverters.toFloat)

    @keyword_only
    def __init__(self, predictionCol="prediction", labelCol="label",
                 metricName="f1Measure", metricLabel=0.0):
        """
        __init__(self, predictionCol="prediction", labelCol="label", \
                 metricName="f1Measure", metricLabel=0.0)
        """
        super(MultilabelClassificationEvaluator, self).__init__()
        # Instantiate the JVM-side evaluator, sharing this instance's uid.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.evaluation.MultilabelClassificationEvaluator", self.uid)
        self._setDefault(metricName="f1Measure", metricLabel=0.0)
        # @keyword_only stashed the constructor kwargs in self._input_kwargs.
        kwargs = self._input_kwargs
        self._set(**kwargs)

    @since("3.0.0")
    def setMetricName(self, value):
        """
        Sets the value of :py:attr:`metricName`.
        """
        return self._set(metricName=value)

    @since("3.0.0")
    def getMetricName(self):
        """
        Gets the value of metricName or its default value.
        """
        return self.getOrDefault(self.metricName)

    @since("3.0.0")
    def setMetricLabel(self, value):
        """
        Sets the value of :py:attr:`metricLabel`.
        """
        return self._set(metricLabel=value)

    @since("3.0.0")
    def getMetricLabel(self):
        """
        Gets the value of metricLabel or its default value.
        """
        return self.getOrDefault(self.metricLabel)

    @keyword_only
    @since("3.0.0")
    def setParams(self, predictionCol="prediction", labelCol="label",
                  metricName="f1Measure", metricLabel=0.0):
        """
        setParams(self, predictionCol="prediction", labelCol="label", \
                  metricName="f1Measure", metricLabel=0.0)
        Sets params for multilabel classification evaluator.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)
|
|
|
|
|
|
|
|
|
2017-09-22 01:12:33 -04:00
|
|
|
@inherit_doc
class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol,
                          JavaMLReadable, JavaMLWritable):
    """
    Evaluator for Clustering results, which expects two input
    columns: prediction and features. The metric computes the Silhouette
    measure using the squared Euclidean distance.

    The Silhouette is a measure for the validation of the consistency
    within clusters. It ranges between 1 and -1, where a value close to
    1 means that the points in a cluster are close to the other points
    in the same cluster and far from the points of the other clusters.

    >>> from pyspark.ml.linalg import Vectors
    >>> featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]),
    ...     [([0.0, 0.5], 0.0), ([0.5, 0.0], 0.0), ([10.0, 11.0], 1.0),
    ...     ([10.5, 11.5], 1.0), ([1.0, 1.0], 0.0), ([8.0, 6.0], 1.0)])
    >>> dataset = spark.createDataFrame(featureAndPredictions, ["features", "prediction"])
    ...
    >>> evaluator = ClusteringEvaluator(predictionCol="prediction")
    >>> evaluator.evaluate(dataset)
    0.9079...
    >>> ce_path = temp_path + "/ce"
    >>> evaluator.save(ce_path)
    >>> evaluator2 = ClusteringEvaluator.load(ce_path)
    >>> str(evaluator2.getPredictionCol())
    'prediction'

    .. versionadded:: 2.3.0
    """
    # Param descriptor; "silhouette" is currently the only supported metric.
    metricName = Param(Params._dummy(), "metricName",
                       "metric name in evaluation (silhouette)",
                       typeConverter=TypeConverters.toString)
    distanceMeasure = Param(Params._dummy(), "distanceMeasure", "The distance measure. " +
                            "Supported options: 'squaredEuclidean' and 'cosine'.",
                            typeConverter=TypeConverters.toString)

    @keyword_only
    def __init__(self, predictionCol="prediction", featuresCol="features",
                 metricName="silhouette", distanceMeasure="squaredEuclidean"):
        """
        __init__(self, predictionCol="prediction", featuresCol="features", \
                 metricName="silhouette", distanceMeasure="squaredEuclidean")
        """
        super(ClusteringEvaluator, self).__init__()
        # Instantiate the JVM-side evaluator, sharing this instance's uid.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.evaluation.ClusteringEvaluator", self.uid)
        self._setDefault(metricName="silhouette", distanceMeasure="squaredEuclidean")
        # @keyword_only stashed the constructor kwargs in self._input_kwargs.
        kwargs = self._input_kwargs
        self._set(**kwargs)

    @since("2.3.0")
    def setMetricName(self, value):
        """
        Sets the value of :py:attr:`metricName`.
        """
        return self._set(metricName=value)

    @since("2.3.0")
    def getMetricName(self):
        """
        Gets the value of metricName or its default value.
        """
        return self.getOrDefault(self.metricName)

    @keyword_only
    @since("2.3.0")
    def setParams(self, predictionCol="prediction", featuresCol="features",
                  metricName="silhouette", distanceMeasure="squaredEuclidean"):
        """
        setParams(self, predictionCol="prediction", featuresCol="features", \
                  metricName="silhouette", distanceMeasure="squaredEuclidean")
        Sets params for clustering evaluator.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    @since("2.4.0")
    def setDistanceMeasure(self, value):
        """
        Sets the value of :py:attr:`distanceMeasure`.
        """
        return self._set(distanceMeasure=value)

    @since("2.4.0")
    def getDistanceMeasure(self):
        """
        Gets the value of `distanceMeasure`
        """
        return self.getOrDefault(self.distanceMeasure)
|
|
|
|
|
|
|
|
|
2019-06-25 07:44:06 -04:00
|
|
|
@inherit_doc
class RankingEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
                       JavaMLReadable, JavaMLWritable):
    """
    .. note:: Experimental

    Evaluator for Ranking, which expects two input
    columns: prediction and label.

    >>> scoreAndLabels = [([1.0, 6.0, 2.0, 7.0, 8.0, 3.0, 9.0, 10.0, 4.0, 5.0],
    ...     [1.0, 2.0, 3.0, 4.0, 5.0]),
    ...     ([4.0, 1.0, 5.0, 6.0, 2.0, 7.0, 3.0, 8.0, 9.0, 10.0], [1.0, 2.0, 3.0]),
    ...     ([1.0, 2.0, 3.0, 4.0, 5.0], [])]
    >>> dataset = spark.createDataFrame(scoreAndLabels, ["prediction", "label"])
    ...
    >>> evaluator = RankingEvaluator(predictionCol="prediction")
    >>> evaluator.evaluate(dataset)
    0.35...
    >>> evaluator.evaluate(dataset, {evaluator.metricName: "precisionAtK", evaluator.k: 2})
    0.33...
    >>> ranke_path = temp_path + "/ranke"
    >>> evaluator.save(ranke_path)
    >>> evaluator2 = RankingEvaluator.load(ranke_path)
    >>> str(evaluator2.getPredictionCol())
    'prediction'

    .. versionadded:: 3.0.0
    """
    # Param descriptor for the summary metric; default is "meanAveragePrecision".
    metricName = Param(Params._dummy(), "metricName",
                       "metric name in evaluation "
                       "(meanAveragePrecision|meanAveragePrecisionAtK|"
                       "precisionAtK|ndcgAtK|recallAtK)",
                       typeConverter=TypeConverters.toString)
    # Cut-off position for the *AtK metrics; ignored by meanAveragePrecision.
    k = Param(Params._dummy(), "k",
              "The ranking position value used in meanAveragePrecisionAtK|precisionAtK|"
              "ndcgAtK|recallAtK. Must be > 0. The default value is 10.",
              typeConverter=TypeConverters.toInt)

    @keyword_only
    def __init__(self, predictionCol="prediction", labelCol="label",
                 metricName="meanAveragePrecision", k=10):
        """
        __init__(self, predictionCol="prediction", labelCol="label", \
                 metricName="meanAveragePrecision", k=10)
        """
        super(RankingEvaluator, self).__init__()
        # Instantiate the JVM-side evaluator, sharing this instance's uid.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.evaluation.RankingEvaluator", self.uid)
        self._setDefault(metricName="meanAveragePrecision", k=10)
        # @keyword_only stashed the constructor kwargs in self._input_kwargs.
        kwargs = self._input_kwargs
        self._set(**kwargs)

    @since("3.0.0")
    def setMetricName(self, value):
        """
        Sets the value of :py:attr:`metricName`.
        """
        return self._set(metricName=value)

    @since("3.0.0")
    def getMetricName(self):
        """
        Gets the value of metricName or its default value.
        """
        return self.getOrDefault(self.metricName)

    @since("3.0.0")
    def setK(self, value):
        """
        Sets the value of :py:attr:`k`.
        """
        return self._set(k=value)

    @since("3.0.0")
    def getK(self):
        """
        Gets the value of k or its default value.
        """
        return self.getOrDefault(self.k)

    @keyword_only
    @since("3.0.0")
    def setParams(self, predictionCol="prediction", labelCol="label",
                  metricName="meanAveragePrecision", k=10):
        """
        setParams(self, predictionCol="prediction", labelCol="label", \
                  metricName="meanAveragePrecision", k=10)
        Sets params for ranking evaluator.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)
|
|
|
|
|
|
|
|
|
2015-05-05 14:45:37 -04:00
|
|
|
if __name__ == "__main__":
    # Run this module's doctests against a throwaway local SparkSession,
    # cleaning up the temp directory used by the save/load examples.
    import doctest
    import tempfile
    import pyspark.ml.evaluation
    from pyspark.sql import SparkSession
    test_globals = pyspark.ml.evaluation.__dict__.copy()
    # The small batch size here ensures that we see multiple batches,
    # even in these small test examples:
    spark = (SparkSession.builder
             .master("local[2]")
             .appName("ml.evaluation tests")
             .getOrCreate())
    test_globals['spark'] = spark
    scratch_dir = tempfile.mkdtemp()
    test_globals['temp_path'] = scratch_dir
    try:
        (failure_count, test_count) = doctest.testmod(
            globs=test_globals, optionflags=doctest.ELLIPSIS)
        spark.stop()
    finally:
        from shutil import rmtree
        try:
            rmtree(scratch_dir)
        except OSError:
            pass
    if failure_count:
        sys.exit(-1)
|