[SPARK-6091] [MLLIB] Add MulticlassMetrics in PySpark/MLlib

https://issues.apache.org/jira/browse/SPARK-6091

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #6011 from yanboliang/spark-6091 and squashes the following commits:

bb3e4ba [Yanbo Liang] trigger jenkins
53c045d [Yanbo Liang] keep compatibility for python 2.6
972d5ac [Yanbo Liang] Add MulticlassMetrics in PySpark/MLlib
Authored by Yanbo Liang on 2015-05-10 00:57:14 -07:00; committed by Xiangrui Meng.
parent b13162b364
commit bf7e81a51c
2 changed files with 137 additions and 0 deletions

mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala

@@ -23,6 +23,7 @@ import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.{Matrices, Matrix}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame

/**
 * ::Experimental::
@@ -33,6 +34,13 @@ import org.apache.spark.rdd.RDD
@Experimental
class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {

  /**
   * An auxiliary constructor taking a DataFrame.
   * @param predictionAndLabels a DataFrame with two double columns: prediction and label
   */
  private[mllib] def this(predictionAndLabels: DataFrame) =
    this(predictionAndLabels.map(r => (r.getDouble(0), r.getDouble(1))))

  private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue()
  private lazy val labelCount: Long = labelCountByClass.values.sum
  private lazy val tpByClass: Map[Double, Int] = predictionAndLabels
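
These lazy vals amount to plain counting: labelCountByClass is the number of examples per true label, labelCount is the total, and in the full source tpByClass tallies the pairs whose prediction equals the label. A minimal plain-Python sketch of the same bookkeeping, with a hypothetical pairs list standing in for the RDD:

from collections import Counter

# Hypothetical stand-in for the predictionAndLabels RDD.
pairs = [(0.0, 0.0), (0.0, 1.0), (1.0, 1.0)]

# Mirrors predictionAndLabels.values.countByValue(): examples per true label.
labelCountByClass = Counter(label for _, label in pairs)
# Mirrors labelCountByClass.values.sum: total number of examples.
labelCount = sum(labelCountByClass.values())
# True positives per label: pairs whose prediction equals the label.
tpByClass = Counter(label for pred, label in pairs if pred == label)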

python/pyspark/mllib/evaluation.py

@@ -141,6 +141,135 @@ class RegressionMetrics(JavaModelWrapper):
        return self.call("r2")


class MulticlassMetrics(JavaModelWrapper):
    """
    Evaluator for multiclass classification.

    >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
    ...     (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
    >>> metrics = MulticlassMetrics(predictionAndLabels)
    >>> metrics.falsePositiveRate(0.0)
    0.2...
    >>> metrics.precision(1.0)
    0.75...
    >>> metrics.recall(2.0)
    1.0...
    >>> metrics.fMeasure(0.0, 2.0)
    0.52...
    >>> metrics.precision()
    0.66...
    >>> metrics.recall()
    0.66...
    >>> metrics.weightedFalsePositiveRate
    0.19...
    >>> metrics.weightedPrecision
    0.68...
    >>> metrics.weightedRecall
    0.66...
    >>> metrics.weightedFMeasure()
    0.66...
    >>> metrics.weightedFMeasure(2.0)
    0.65...
    """

    def __init__(self, predictionAndLabels):
        """
        :param predictionAndLabels: an RDD of (prediction, label) pairs.
        """
        sc = predictionAndLabels.ctx
        sql_ctx = SQLContext(sc)
        # Convert the pair RDD into a two-column DataFrame so it can be handed
        # to the DataFrame-based auxiliary constructor on the Scala side.
        df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([
            StructField("prediction", DoubleType(), nullable=False),
            StructField("label", DoubleType(), nullable=False)]))
        java_class = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics
        java_model = java_class(df._jdf)
        super(MulticlassMetrics, self).__init__(java_model)

    def truePositiveRate(self, label):
        """
        Returns true positive rate for a given label (category).
        """
        return self.call("truePositiveRate", float(label))

    def falsePositiveRate(self, label):
        """
        Returns false positive rate for a given label (category).
        """
        return self.call("falsePositiveRate", float(label))

    def precision(self, label=None):
        """
        Returns overall precision, or precision for a given label (category)
        if specified.
        """
        if label is None:
            return self.call("precision")
        else:
            return self.call("precision", float(label))

    def recall(self, label=None):
        """
        Returns overall recall, or recall for a given label (category)
        if specified.
        """
        if label is None:
            return self.call("recall")
        else:
            return self.call("recall", float(label))

    def fMeasure(self, label=None, beta=None):
        """
        Returns overall f-measure, or f-measure for a given label (category)
        if specified.
        """
        if beta is None:
            if label is None:
                return self.call("fMeasure")
            else:
                return self.call("fMeasure", float(label))
        else:
            if label is None:
                raise Exception("If the beta parameter is specified, label cannot be None")
            else:
                return self.call("fMeasure", float(label), float(beta))

    @property
    def weightedTruePositiveRate(self):
        """
        Returns weighted true positive rate.
        (equals precision, recall and f-measure)
        """
        return self.call("weightedTruePositiveRate")

    @property
    def weightedFalsePositiveRate(self):
        """
        Returns weighted false positive rate.
        """
        return self.call("weightedFalsePositiveRate")

    @property
    def weightedRecall(self):
        """
        Returns weighted averaged recall.
        (equals precision, recall and f-measure)
        """
        return self.call("weightedRecall")

    @property
    def weightedPrecision(self):
        """
        Returns weighted averaged precision.
        """
        return self.call("weightedPrecision")

    def weightedFMeasure(self, beta=None):
        """
        Returns weighted averaged f-measure.
        """
        if beta is None:
            return self.call("weightedFMeasure")
        else:
            return self.call("weightedFMeasure", float(beta))


def _test():
    import doctest
    from pyspark import SparkContext
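
As a sanity check, the expected doctest values above can be reproduced without Spark. The following plain-Python sketch (illustrative only, not part of the commit) recomputes a few of them from the same nine (prediction, label) pairs:

# Illustrative cross-check of the doctest values; not part of the commit.
pairs = [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0),
         (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]

def precision(label):
    # Fraction of examples predicted as `label` that truly have that label.
    predicted = [l for p, l in pairs if p == label]
    return predicted.count(label) / float(len(predicted))

def recall(label):
    # Fraction of examples with true label `label` that were predicted correctly.
    actual = [p for p, l in pairs if l == label]
    return actual.count(label) / float(len(actual))

def false_positive_rate(label):
    # False positives over all examples whose true label is not `label`.
    fp = sum(1 for p, l in pairs if p == label and l != label)
    negatives = sum(1 for p, l in pairs if l != label)
    return fp / float(negatives)

def f_measure(label, beta):
    # Standard F-beta: (1 + beta^2) * P * R / (beta^2 * P + R).
    p, r, b2 = precision(label), recall(label), beta ** 2
    return (1.0 + b2) * p * r / (b2 * p + r)

print(precision(1.0))            # 0.75     == metrics.precision(1.0)
print(recall(2.0))               # 1.0      == metrics.recall(2.0)
print(false_positive_rate(0.0))  # 0.2      == metrics.falsePositiveRate(0.0)
print(f_measure(0.0, 2.0))       # 0.526... == metrics.fMeasure(0.0, 2.0)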