diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 8c30ad4b39..f4c4775965 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -32,6 +32,7 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.mllib.classification._
 import org.apache.spark.mllib.clustering._
+import org.apache.spark.mllib.evaluation.RankingMetrics
 import org.apache.spark.mllib.feature._
 import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel}
 import org.apache.spark.mllib.linalg._
@@ -50,6 +51,7 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTree
 import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
@@ -923,6 +925,14 @@ private[python] class PythonMLLibAPI extends Serializable {
     RG.gammaVectorRDD(jsc.sc, shape, scale, numRows, numCols, parts, s)
   }
 
+  /**
+   * Java stub for the constructor of Python mllib RankingMetrics
+   */
+  def newRankingMetrics(predictionAndLabels: DataFrame): RankingMetrics[Any] = {
+    new RankingMetrics(predictionAndLabels.map(
+      r => (r.getSeq(0).toArray[Any], r.getSeq(1).toArray[Any])))
+  }
+
 }
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index 36914597de..4c777f2180 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -15,9 +15,12 @@
 # limitations under the License.
 #
 
-from pyspark.mllib.common import JavaModelWrapper
+from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
 from pyspark.sql import SQLContext
-from pyspark.sql.types import StructField, StructType, DoubleType
+from pyspark.sql.types import StructField, StructType, DoubleType, IntegerType, ArrayType
+
+__all__ = ['BinaryClassificationMetrics', 'RegressionMetrics',
+           'MulticlassMetrics', 'RankingMetrics']
 
 
 class BinaryClassificationMetrics(JavaModelWrapper):
@@ -270,6 +273,77 @@ class MulticlassMetrics(JavaModelWrapper):
         return self.call("weightedFMeasure", beta)
 
 
+class RankingMetrics(JavaModelWrapper):
+    """
+    Evaluator for ranking algorithms.
+
+    >>> predictionAndLabels = sc.parallelize([
+    ...     ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]),
+    ...     ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]),
+    ...     ([1, 2, 3, 4, 5], [])])
+    >>> metrics = RankingMetrics(predictionAndLabels)
+    >>> metrics.precisionAt(1)
+    0.33...
+    >>> metrics.precisionAt(5)
+    0.26...
+    >>> metrics.precisionAt(15)
+    0.17...
+    >>> metrics.meanAveragePrecision
+    0.35...
+    >>> metrics.ndcgAt(3)
+    0.33...
+    >>> metrics.ndcgAt(10)
+    0.48...
+
+    """
+
+    def __init__(self, predictionAndLabels):
+        """
+        :param predictionAndLabels: an RDD of (predicted ranking, ground truth set) pairs.
+        """
+        sc = predictionAndLabels.ctx
+        sql_ctx = SQLContext(sc)
+        df = sql_ctx.createDataFrame(predictionAndLabels,
+                                     schema=sql_ctx._inferSchema(predictionAndLabels))
+        java_model = callMLlibFunc("newRankingMetrics", df._jdf)
+        super(RankingMetrics, self).__init__(java_model)
+
+    def precisionAt(self, k):
+        """
+        Compute the average precision of all the queries, truncated at ranking position k.
+
+        If the ranking algorithm returns n (n < k) results for a query, the precision value
+        will be computed as #(relevant items retrieved) / k. This formula also applies when
+        the size of the ground truth set is less than k.
+
+        If a query has an empty ground truth set, zero will be used as precision together
+        with a log warning.
+        """
+        return self.call("precisionAt", int(k))
+
+    @property
+    def meanAveragePrecision(self):
+        """
+        Returns the mean average precision (MAP) of all the queries.
+        If a query has an empty ground truth set, the average precision will be zero and
+        a log warning is generated.
+        """
+        return self.call("meanAveragePrecision")
+
+    def ndcgAt(self, k):
+        """
+        Compute the average NDCG value of all the queries, truncated at ranking position k.
+        The discounted cumulative gain at position k is computed as
+        sum_{i=1}^{k} (2^{relevance of i-th item} - 1) / log(i + 1),
+        and the NDCG is obtained by dividing the DCG value by the DCG value computed on
+        the ground truth set. In the current implementation, the relevance value is binary.
+
+        If a query has an empty ground truth set, zero will be used as NDCG together with
+        a log warning.
+        """
+        return self.call("ndcgAt", int(k))
+
+
 def _test():
     import doctest
     from pyspark import SparkContext
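
A minimal usage sketch of the Python API added by this patch, assuming an existing SparkContext named sc and reusing the sample data from the doctest above:

    from pyspark.mllib.evaluation import RankingMetrics

    # Each pair is (predicted ranking, ground truth set), as in the doctest.
    predictionAndLabels = sc.parallelize([
        ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]),
        ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]),
        ([1, 2, 3, 4, 5], [])])

    metrics = RankingMetrics(predictionAndLabels)
    print(metrics.precisionAt(5))        # precision truncated at rank 5
    print(metrics.meanAveragePrecision)  # MAP over all queries
    print(metrics.ndcgAt(10))            # NDCG truncated at rank 10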