[SPARK-6094] [MLLIB] Add MultilabelMetrics in PySpark/MLlib
Add MultilabelMetrics in PySpark/MLlib
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #6276 from yanboliang/spark-6094 and squashes the following commits:
b8e3343 [Yanbo Liang] Add MultilabelMetrics in PySpark/MLlib
(cherry picked from commit 98a46f9dff
)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
This commit is contained in:
parent
996e2d4b38
commit
606ae3e10e
|
@ -19,6 +19,7 @@ package org.apache.spark.mllib.evaluation
|
|||
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.SparkContext._
|
||||
import org.apache.spark.sql.DataFrame
|
||||
|
||||
/**
|
||||
* Evaluator for multilabel classification.
|
||||
|
@ -27,6 +28,13 @@ import org.apache.spark.SparkContext._
|
|||
*/
|
||||
class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) {
|
||||
|
||||
/**
|
||||
* An auxiliary constructor taking a DataFrame.
|
||||
* @param predictionAndLabels a DataFrame with two double array columns: prediction and label
|
||||
*/
|
||||
private[mllib] def this(predictionAndLabels: DataFrame) =
|
||||
this(predictionAndLabels.map(r => (r.getSeq[Double](0).toArray, r.getSeq[Double](1).toArray)))
|
||||
|
||||
private lazy val numDocs: Long = predictionAndLabels.count()
|
||||
|
||||
private lazy val numLabels: Long = predictionAndLabels.flatMap { case (_, labels) =>
|
||||
|
|
|
@ -343,6 +343,123 @@ class RankingMetrics(JavaModelWrapper):
|
|||
return self.call("ndcgAt", int(k))
|
||||
|
||||
|
||||
class MultilabelMetrics(JavaModelWrapper):
|
||||
"""
|
||||
Evaluator for multilabel classification.
|
||||
|
||||
>>> predictionAndLabels = sc.parallelize([([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]),
|
||||
... ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]),
|
||||
... ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])])
|
||||
>>> metrics = MultilabelMetrics(predictionAndLabels)
|
||||
>>> metrics.precision(0.0)
|
||||
1.0
|
||||
>>> metrics.recall(1.0)
|
||||
0.66...
|
||||
>>> metrics.f1Measure(2.0)
|
||||
0.5
|
||||
>>> metrics.precision()
|
||||
0.66...
|
||||
>>> metrics.recall()
|
||||
0.64...
|
||||
>>> metrics.f1Measure()
|
||||
0.63...
|
||||
>>> metrics.microPrecision
|
||||
0.72...
|
||||
>>> metrics.microRecall
|
||||
0.66...
|
||||
>>> metrics.microF1Measure
|
||||
0.69...
|
||||
>>> metrics.hammingLoss
|
||||
0.33...
|
||||
>>> metrics.subsetAccuracy
|
||||
0.28...
|
||||
>>> metrics.accuracy
|
||||
0.54...
|
||||
"""
|
||||
|
||||
def __init__(self, predictionAndLabels):
|
||||
sc = predictionAndLabels.ctx
|
||||
sql_ctx = SQLContext(sc)
|
||||
df = sql_ctx.createDataFrame(predictionAndLabels,
|
||||
schema=sql_ctx._inferSchema(predictionAndLabels))
|
||||
java_class = sc._jvm.org.apache.spark.mllib.evaluation.MultilabelMetrics
|
||||
java_model = java_class(df._jdf)
|
||||
super(MultilabelMetrics, self).__init__(java_model)
|
||||
|
||||
def precision(self, label=None):
|
||||
"""
|
||||
Returns precision or precision for a given label (category) if specified.
|
||||
"""
|
||||
if label is None:
|
||||
return self.call("precision")
|
||||
else:
|
||||
return self.call("precision", float(label))
|
||||
|
||||
def recall(self, label=None):
|
||||
"""
|
||||
Returns recall or recall for a given label (category) if specified.
|
||||
"""
|
||||
if label is None:
|
||||
return self.call("recall")
|
||||
else:
|
||||
return self.call("recall", float(label))
|
||||
|
||||
def f1Measure(self, label=None):
|
||||
"""
|
||||
Returns f1Measure or f1Measure for a given label (category) if specified.
|
||||
"""
|
||||
if label is None:
|
||||
return self.call("f1Measure")
|
||||
else:
|
||||
return self.call("f1Measure", float(label))
|
||||
|
||||
@property
|
||||
def microPrecision(self):
|
||||
"""
|
||||
Returns micro-averaged label-based precision.
|
||||
(equals to micro-averaged document-based precision)
|
||||
"""
|
||||
return self.call("microPrecision")
|
||||
|
||||
@property
|
||||
def microRecall(self):
|
||||
"""
|
||||
Returns micro-averaged label-based recall.
|
||||
(equals to micro-averaged document-based recall)
|
||||
"""
|
||||
return self.call("microRecall")
|
||||
|
||||
@property
|
||||
def microF1Measure(self):
|
||||
"""
|
||||
Returns micro-averaged label-based f1-measure.
|
||||
(equals to micro-averaged document-based f1-measure)
|
||||
"""
|
||||
return self.call("microF1Measure")
|
||||
|
||||
@property
|
||||
def hammingLoss(self):
|
||||
"""
|
||||
Returns Hamming-loss.
|
||||
"""
|
||||
return self.call("hammingLoss")
|
||||
|
||||
@property
|
||||
def subsetAccuracy(self):
|
||||
"""
|
||||
Returns subset accuracy.
|
||||
(for equal sets of labels)
|
||||
"""
|
||||
return self.call("subsetAccuracy")
|
||||
|
||||
@property
|
||||
def accuracy(self):
|
||||
"""
|
||||
Returns accuracy.
|
||||
"""
|
||||
return self.call("accuracy")
|
||||
|
||||
|
||||
def _test():
|
||||
import doctest
|
||||
from pyspark import SparkContext
|
||||
|
|
Loading…
Reference in a new issue