[SPARK-6092] [MLLIB] Add RankingMetrics in PySpark/MLlib

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #6044 from yanboliang/spark-6092 and squashes the following commits:

726a9b1 [Yanbo Liang] add newRankingMetrics
33f649c [Yanbo Liang] Add RankingMetrics in PySpark/MLlib

(cherry picked from commit 042dda3c5c)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
Yanbo Liang 2015-05-11 09:14:20 -07:00 committed by Xiangrui Meng
parent da1be15cc6
commit 017f9fa674
2 changed files with 86 additions and 2 deletions

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

@@ -32,6 +32,7 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
import org.apache.spark.mllib.evaluation.RankingMetrics
import org.apache.spark.mllib.feature._
import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel}
import org.apache.spark.mllib.linalg._
@@ -50,6 +51,7 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTree
import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -923,6 +925,14 @@ private[python] class PythonMLLibAPI extends Serializable {
      RG.gammaVectorRDD(jsc.sc, shape, scale, numRows, numCols, parts, s)
  }

  /**
   * Java stub for the constructor of Python mllib RankingMetrics
   */
  def newRankingMetrics(predictionAndLabels: DataFrame): RankingMetrics[Any] = {
    new RankingMetrics(predictionAndLabels.map(
      r => (r.getSeq(0).toArray[Any], r.getSeq(1).toArray[Any])))
  }
}
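
A note on the shape of the handoff: the stub receives a DataFrame whose rows carry two sequence columns, read back on the JVM side via r.getSeq(0) and r.getSeq(1). A minimal Python-side sketch of the structure it expects (the app name and column names here are illustrative, not part of the patch):

from pyspark import SparkContext
from pyspark.sql import SQLContext

sc = SparkContext("local", "ranking-metrics-sketch")  # hypothetical context
sql_ctx = SQLContext(sc)
pairs = sc.parallelize([([1, 2, 3], [1, 2]), ([4, 5], [5])])
# Each tuple becomes a Row with two array-typed columns, which the Scala stub
# converts into (Array[Any], Array[Any]) pairs for RankingMetrics.
df = sql_ctx.createDataFrame(pairs, ["predicted", "labels"])
print(df.dtypes)  # two array columns, one per side of each pair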

python/pyspark/mllib/evaluation.py

@@ -15,9 +15,12 @@
# limitations under the License.
#
-from pyspark.mllib.common import JavaModelWrapper
+from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
 from pyspark.sql import SQLContext
-from pyspark.sql.types import StructField, StructType, DoubleType
+from pyspark.sql.types import StructField, StructType, DoubleType, IntegerType, ArrayType

__all__ = ['BinaryClassificationMetrics', 'RegressionMetrics',
           'MulticlassMetrics', 'RankingMetrics']

class BinaryClassificationMetrics(JavaModelWrapper):
@@ -270,6 +273,77 @@ class MulticlassMetrics(JavaModelWrapper):
return self.call("weightedFMeasure", beta)
class RankingMetrics(JavaModelWrapper):
"""
Evaluator for ranking algorithms.
>>> predictionAndLabels = sc.parallelize([
... ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]),
... ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]),
... ([1, 2, 3, 4, 5], [])])
>>> metrics = RankingMetrics(predictionAndLabels)
>>> metrics.precisionAt(1)
0.33...
>>> metrics.precisionAt(5)
0.26...
>>> metrics.precisionAt(15)
0.17...
>>> metrics.meanAveragePrecision
0.35...
>>> metrics.ndcgAt(3)
0.33...
>>> metrics.ndcgAt(10)
0.48...
"""
    def __init__(self, predictionAndLabels):
        """
        :param predictionAndLabels: an RDD of (predicted ranking, ground truth set) pairs.
        """
        sc = predictionAndLabels.ctx
        sql_ctx = SQLContext(sc)
        df = sql_ctx.createDataFrame(predictionAndLabels,
                                     schema=sql_ctx._inferSchema(predictionAndLabels))
        java_model = callMLlibFunc("newRankingMetrics", df._jdf)
        super(RankingMetrics, self).__init__(java_model)
    def precisionAt(self, k):
        """
        Compute the average precision of all the queries, truncated at ranking position k.

        If the ranking algorithm returns n (n < k) results for a query, the precision
        value will be computed as #(relevant items retrieved) / k. This formula also
        applies when the size of the ground truth set is less than k.

        If a query has an empty ground truth set, zero will be used as precision together
        with a log warning.
        """
        return self.call("precisionAt", int(k))
    @property
    def meanAveragePrecision(self):
        """
        Returns the mean average precision (MAP) of all the queries.

        If a query has an empty ground truth set, the average precision will be zero
        and a log warning is generated.
        """
        return self.call("meanAveragePrecision")
    def ndcgAt(self, k):
        """
        Compute the average NDCG value of all the queries, truncated at ranking position k.

        The discounted cumulative gain at position k is computed as
            DCG(k) = sum_{i=1}^{k} (2^{relevance of i-th item} - 1) / log(i + 1),
        and the NDCG is obtained by dividing it by the DCG computed on the ground truth
        set, i.e. the ideal DCG. In the current implementation, the relevance value is
        binary.

        If a query has an empty ground truth set, zero will be used as NDCG together
        with a log warning.
        """
        return self.call("ndcgAt", int(k))

def _test():
    import doctest
    from pyspark import SparkContext