From 5fa4ba0cfb126bfadee7451fe9a46cee3d60b67c Mon Sep 17 00:00:00 2001
From: masa3141
Date: Wed, 6 Mar 2019 08:28:53 -0600
Subject: [PATCH] [SPARK-26981][MLLIB] Add 'Recall_at_k' metric to RankingMetrics

## What changes were proposed in this pull request?

Add 'Recall_at_k' metric to RankingMetrics.

## How was this patch tested?

Add test to RankingMetricsSuite.
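The new `recallAt(k)` shares its counting logic with `precisionAt(k)`; the two
metrics differ only in the denominator, which is why both now delegate to the
shared `countRelevantItemRatio` helper in the diff below. A minimal sketch of
that computation (the data here is invented for illustration and is not part
of the patch):

```scala
// Count relevant items among the top k predictions, then divide by a
// metric-specific denominator: k for precision@k, |ground truth| for recall@k.
val pred = Array(1, 6, 2, 7, 8) // predicted ranking
val truth = Set(1, 2, 3, 4, 5)  // ground truth set
val k = 3
val hits = pred.take(k).count(truth.contains) // = 2 (items 1 and 2)
val precisionAtK = hits.toDouble / k          // 2.0 / 3: divide by k
val recallAtK = hits.toDouble / truth.size    // 2.0 / 5: divide by the ground truth size
```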
Closes #23881 from masa3141/SPARK-26981.

Authored-by: masa3141
Signed-off-by: Sean Owen
---
 .../mllib/JavaRankingMetricsExample.java      |  3 +-
 .../mllib/RankingMetricsExample.scala         |  5 ++
 .../mllib/evaluation/RankingMetrics.scala     | 75 ++++++++++++++-----
 .../evaluation/RankingMetricsSuite.scala      | 14 +++-
 python/pyspark/mllib/evaluation.py            | 20 +++++
 5 files changed, 97 insertions(+), 20 deletions(-)

diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java
index dc9970d885..414d3763dd 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java
@@ -99,11 +99,12 @@ public class JavaRankingMetricsExample {
     // Instantiate the metrics object
     RankingMetrics<Integer> metrics = RankingMetrics.of(relevantDocs);
 
-    // Precision and NDCG at k
+    // Precision, NDCG and Recall at k
     Integer[] kVector = {1, 3, 5};
     for (Integer k : kVector) {
       System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k));
       System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k));
+      System.out.format("Recall at %d = %f\n", k, metrics.recallAt(k));
     }
 
     // Mean average precision
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala
index d514891da7..34fbe0851d 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala
@@ -89,6 +89,11 @@ object RankingMetricsExample {
       println(s"NDCG at $k = ${metrics.ndcgAt(k)}")
     }
 
+    // Recall at K
+    Array(1, 3, 5).foreach { k =>
+      println(s"Recall at $k = ${metrics.recallAt(k)}")
+    }
+
     // Get predictions for each data point
     val allPredictions = model.predict(ratings.map(r => (r.user, r.product))).map(r =>
       ((r.user, r.product), r.rating)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
index 4935d11411..ff9663abcc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
@@ -59,23 +59,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
   def precisionAt(k: Int): Double = {
     require(k > 0, "ranking position k should be positive")
     predictionAndLabels.map { case (pred, lab) =>
-      val labSet = lab.toSet
-
-      if (labSet.nonEmpty) {
-        val n = math.min(pred.length, k)
-        var i = 0
-        var cnt = 0
-        while (i < n) {
-          if (labSet.contains(pred(i))) {
-            cnt += 1
-          }
-          i += 1
-        }
-        cnt.toDouble / k
-      } else {
-        logWarning("Empty ground truth set, check input data")
-        0.0
-      }
+      countRelevantItemRatio(pred, lab, k, k)
     }.mean()
   }
 
@@ -157,6 +141,63 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
     }.mean()
   }
 
+  /**
+   * Compute the average recall of all the queries, truncated at ranking position k.
+   *
+   * If for a query, the ranking algorithm returns n results, the recall value will be
+   * computed as #(relevant items retrieved) / #(ground truth set). This formula
+   * also applies when the size of the ground truth set is less than k.
+   *
+   * If a query has an empty ground truth set, zero will be used as recall together with
+   * a log warning.
+   *
+   * See the following paper for detail:
+   *
+   * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
+   *
+   * @param k the position to compute the truncated recall, must be positive
+   * @return the average recall at the first k ranking positions
+   */
+  @Since("3.0.0")
+  def recallAt(k: Int): Double = {
+    require(k > 0, "ranking position k should be positive")
+    predictionAndLabels.map { case (pred, lab) =>
+      countRelevantItemRatio(pred, lab, k, lab.toSet.size)
+    }.mean()
+  }
+
+  /**
+   * Returns the relevant item ratio computed as #(relevant items retrieved) / denominator.
+   * If a query has an empty ground truth set, the value will be zero and a log
+   * warning is generated.
+   *
+   * @param pred predicted ranking
+   * @param lab ground truth
+   * @param k use the top k predicted ranking, must be positive
+   * @param denominator the denominator of the ratio
+   * @return the relevant item ratio at the first k ranking positions
+   */
+  private def countRelevantItemRatio(pred: Array[T],
+                                     lab: Array[T],
+                                     k: Int,
+                                     denominator: Int): Double = {
+    val labSet = lab.toSet
+    if (labSet.nonEmpty) {
+      val n = math.min(pred.length, k)
+      var i = 0
+      var cnt = 0
+      while (i < n) {
+        if (labSet.contains(pred(i))) {
+          cnt += 1
+        }
+        i += 1
+      }
+      cnt.toDouble / denominator
+    } else {
+      logWarning("Empty ground truth set, check input data")
+      0.0
+    }
+  }
 }
 
 object RankingMetrics {
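The suite that follows checks `recallAt` against hand-computed averages. One
detail worth spelling out: a query with an empty ground truth set contributes
0.0 to the mean (with a warning) rather than being dropped, so the divisor is
always the total number of queries. A small sketch of that averaging, using
invented data rather than the suite's dataset:

```scala
// Mean recall@k over queries; an empty ground truth set yields 0.0 for that
// query but still counts toward the average. Invented data, illustrative only.
val queries = Seq(
  (Array(1, 6, 2), Set(1, 2, 3, 4, 5)), // top 2 = (1, 6), one hit => 1.0 / 5
  (Array(7, 8, 9), Set.empty[Int])      // empty ground truth => 0.0
)
val k = 2
val meanRecallAtK = queries.map { case (pred, truth) =>
  if (truth.isEmpty) 0.0
  else pred.take(k).count(truth.contains).toDouble / truth.size
}.sum / queries.size
// (1.0 / 5 + 0.0) / 2 = 0.1
```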
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
index f334be2c2b..1969098a51 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.mllib.util.TestingUtils._
 
 class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
 
-  test("Ranking metrics: MAP, NDCG") {
+  test("Ranking metrics: MAP, NDCG, Recall") {
     val predictionAndLabels = sc.parallelize(
       Seq(
         (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5)),
@@ -49,9 +49,17 @@ class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(metrics.ndcgAt(5) ~== 0.328788 absTol eps)
     assert(metrics.ndcgAt(10) ~== 0.487913 absTol eps)
     assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps)
+
+    assert(metrics.recallAt(1) ~== 1.0/15 absTol eps)
+    assert(metrics.recallAt(2) ~== 8.0/45 absTol eps)
+    assert(metrics.recallAt(3) ~== 11.0/45 absTol eps)
+    assert(metrics.recallAt(4) ~== 11.0/45 absTol eps)
+    assert(metrics.recallAt(5) ~== 16.0/45 absTol eps)
+    assert(metrics.recallAt(10) ~== 2.0/3 absTol eps)
+    assert(metrics.recallAt(15) ~== 2.0/3 absTol eps)
   }
 
-  test("MAP, NDCG with few predictions (SPARK-14886)") {
+  test("MAP, NDCG, Recall with few predictions (SPARK-14886)") {
     val predictionAndLabels = sc.parallelize(
       Seq(
         (Array(1, 6, 2), Array(1, 2, 3, 4, 5)),
@@ -64,6 +72,8 @@ class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(metrics.precisionAt(2) ~== 0.25 absTol eps)
     assert(metrics.ndcgAt(1) ~== 0.5 absTol eps)
     assert(metrics.ndcgAt(2) ~== 0.30657 absTol eps)
+    assert(metrics.recallAt(1) ~== 0.1 absTol eps)
+    assert(metrics.recallAt(2) ~== 0.1 absTol eps)
   }
 }
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index 5d8d20dcfc..171c62ce97 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -373,6 +373,12 @@ class RankingMetrics(JavaModelWrapper):
     0.33...
     >>> metrics.ndcgAt(10)
     0.48...
+    >>> metrics.recallAt(1)
+    0.06...
+    >>> metrics.recallAt(5)
+    0.35...
+    >>> metrics.recallAt(15)
+    0.66...
 
     .. versionadded:: 1.4.0
     """
@@ -422,6 +428,20 @@ class RankingMetrics(JavaModelWrapper):
         """
         return self.call("ndcgAt", int(k))
 
+    @since('3.0.0')
+    def recallAt(self, k):
+        """
+        Compute the average recall of all the queries, truncated at ranking position k.
+
+        If for a query, the ranking algorithm returns n results, the recall value
+        will be computed as #(relevant items retrieved) / #(ground truth set).
+        This formula also applies when the size of the ground truth set is less than k.
+
+        If a query has an empty ground truth set, zero will be used as recall together
+        with a log warning.
+        """
+        return self.call("recallAt", int(k))
+
 
 class MultilabelMetrics(JavaModelWrapper):
     """
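For reference, an end-to-end sketch of the new Scala API surface, assuming a
live `SparkContext` named `sc` is in scope (the data is hypothetical; it just
mirrors the shape of the examples patched above):

```scala
import org.apache.spark.mllib.evaluation.RankingMetrics

// Two queries: (predicted ranking, ground truth set). Hypothetical data.
val predictionAndLabels = sc.parallelize(Seq(
  (Array(1, 6, 2, 7, 8), Array(1, 2, 3, 4, 5)), // recall@3 = 2.0 / 5
  (Array(4, 1, 5), Array(1, 2, 3))              // recall@3 = 1.0 / 3
))
val metrics = new RankingMetrics(predictionAndLabels)
println(metrics.recallAt(3)) // (2.0 / 5 + 1.0 / 3) / 2 = 0.3666...
```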