[SPARK-6091] [MLLIB] Add MulticlassMetrics in PySpark/MLlib
https://issues.apache.org/jira/browse/SPARK-6091

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #6011 from yanboliang/spark-6091 and squashes the following commits:

bb3e4ba [Yanbo Liang] trigger jenkins
53c045d [Yanbo Liang] keep compatibility for python 2.6
972d5ac [Yanbo Liang] Add MulticlassMetrics in PySpark/MLlib
commit bf7e81a51c
parent b13162b364
mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala:

```diff
@@ -23,6 +23,7 @@ import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg.{Matrices, Matrix}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
 
 /**
  * ::Experimental::
@@ -33,6 +34,13 @@ import org.apache.spark.rdd.RDD
 @Experimental
 class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
 
+  /**
+   * An auxiliary constructor taking a DataFrame.
+   * @param predictionAndLabels a DataFrame with two double columns: prediction and label
+   */
+  private[mllib] def this(predictionAndLabels: DataFrame) =
+    this(predictionAndLabels.map(r => (r.getDouble(0), r.getDouble(1))))
+
   private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue()
   private lazy val labelCount: Long = labelCountByClass.values.sum
   private lazy val tpByClass: Map[Double, Int] = predictionAndLabels
```
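The context lines above end at the per-class counters that every metric is derived from. As a plain-Python sketch (not part of the patch) of the bookkeeping behind `labelCountByClass` and `tpByClass`, run on the doctest data from the Python hunk below:

```python
from collections import Counter

# (prediction, label) pairs -- the doctest data from the Python hunk below.
pairs = [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0),
         (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]

# labelCountByClass: occurrences of each class among the true labels.
label_count_by_class = Counter(label for _, label in pairs)
# tpByClass: true positives per class, i.e. pairs where prediction == label.
tp_by_class = Counter(label for pred, label in pairs if pred == label)
# Predictions per class (true positives + false positives for that class).
pred_count_by_class = Counter(pred for pred, _ in pairs)

# precision(1.0) = tp / times-predicted = 3 / 4, matching 0.75 in the doctest.
print(tp_by_class[1.0] / float(pred_count_by_class[1.0]))  # 0.75
```

Every metric exposed by the class is a ratio of such counters, which is why they can be cheap `lazy val`s on the Scala side.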
python/pyspark/mllib/evaluation.py:

```diff
@@ -141,6 +141,135 @@ class RegressionMetrics(JavaModelWrapper):
         return self.call("r2")
 
 
+class MulticlassMetrics(JavaModelWrapper):
+    """
+    Evaluator for multiclass classification.
+
+    >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
+    ...     (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
+    >>> metrics = MulticlassMetrics(predictionAndLabels)
+    >>> metrics.falsePositiveRate(0.0)
+    0.2...
+    >>> metrics.precision(1.0)
+    0.75...
+    >>> metrics.recall(2.0)
+    1.0...
+    >>> metrics.fMeasure(0.0, 2.0)
+    0.52...
+    >>> metrics.precision()
+    0.66...
+    >>> metrics.recall()
+    0.66...
+    >>> metrics.weightedFalsePositiveRate
+    0.19...
+    >>> metrics.weightedPrecision
+    0.68...
+    >>> metrics.weightedRecall
+    0.66...
+    >>> metrics.weightedFMeasure()
+    0.66...
+    >>> metrics.weightedFMeasure(2.0)
+    0.65...
+    """
+
+    def __init__(self, predictionAndLabels):
+        """
+        :param predictionAndLabels: an RDD of (prediction, label) pairs.
+        """
+        sc = predictionAndLabels.ctx
+        sql_ctx = SQLContext(sc)
+        df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([
+            StructField("prediction", DoubleType(), nullable=False),
+            StructField("label", DoubleType(), nullable=False)]))
+        java_class = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics
+        java_model = java_class(df._jdf)
+        super(MulticlassMetrics, self).__init__(java_model)
+
+    def truePositiveRate(self, label):
+        """
+        Returns true positive rate for a given label (category).
+        """
+        return self.call("truePositiveRate", label)
+
+    def falsePositiveRate(self, label):
+        """
+        Returns false positive rate for a given label (category).
+        """
+        return self.call("falsePositiveRate", label)
+
+    def precision(self, label=None):
+        """
+        Returns overall precision, or precision for a given label (category)
+        if one is specified.
+        """
+        if label is None:
+            return self.call("precision")
+        else:
+            return self.call("precision", float(label))
+
+    def recall(self, label=None):
+        """
+        Returns overall recall, or recall for a given label (category)
+        if one is specified.
+        """
+        if label is None:
+            return self.call("recall")
+        else:
+            return self.call("recall", float(label))
+
+    def fMeasure(self, label=None, beta=None):
+        """
+        Returns overall f-measure, or f-measure for a given label (category)
+        if one is specified.
+        """
+        if beta is None:
+            if label is None:
+                return self.call("fMeasure")
+            else:
+                return self.call("fMeasure", label)
+        else:
+            if label is None:
+                raise Exception("If the beta parameter is specified, label can not be None")
+            else:
+                return self.call("fMeasure", label, beta)
+
+    @property
+    def weightedTruePositiveRate(self):
+        """
+        Returns weighted true positive rate.
+        (equals weighted recall)
+        """
+        return self.call("weightedTruePositiveRate")
+
+    @property
+    def weightedFalsePositiveRate(self):
+        """
+        Returns weighted false positive rate.
+        """
+        return self.call("weightedFalsePositiveRate")
+
+    @property
+    def weightedRecall(self):
+        """
+        Returns weighted averaged recall.
+        (equals overall precision, recall and f-measure)
+        """
+        return self.call("weightedRecall")
+
+    @property
+    def weightedPrecision(self):
+        """
+        Returns weighted averaged precision.
+        """
+        return self.call("weightedPrecision")
+
+    def weightedFMeasure(self, beta=None):
+        """
+        Returns weighted averaged f-measure.
+        """
+        if beta is None:
+            return self.call("weightedFMeasure")
+        else:
+            return self.call("weightedFMeasure", beta)
+
+
 def _test():
     import doctest
     from pyspark import SparkContext
```
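With the wrapper in place, the new API can be exercised end-to-end outside the doctests. A minimal sketch, assuming a local build with this patch applied (the app name is an arbitrary placeholder); it reuses the doctest data and cross-checks `fMeasure` against the F-beta formula by hand:

```python
from pyspark import SparkContext
from pyspark.mllib.evaluation import MulticlassMetrics

sc = SparkContext(appName="MulticlassMetricsExample")  # hypothetical app name

# The doctest data: (prediction, label) pairs over three classes.
predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
                                      (1.0, 0.0), (1.0, 1.0), (1.0, 1.0),
                                      (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
metrics = MulticlassMetrics(predictionAndLabels)

# fMeasure(label, beta) is the per-class F-beta score. For class 0.0:
# precision = 2/3 (predicted 0.0 three times, twice correctly) and
# recall = 2/4 (four true 0.0 labels, two recovered), so with beta = 2.0
# F_beta = (1 + beta^2) * p * r / (beta^2 * p + r) ~= 0.526.
p, r, beta = metrics.precision(0.0), metrics.recall(0.0), 2.0
f_manual = (1 + beta ** 2) * p * r / (beta ** 2 * p + r)
print(round(f_manual, 3), round(metrics.fMeasure(0.0, beta), 3))  # 0.526 0.526

sc.stop()
```

Routing the pairs through a two-column DataFrame, rather than pickling an RDD of tuples across the boundary, is exactly what the auxiliary Scala constructor above enables: the Python side only has to ship `df._jdf` over Py4J.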