[SPARK-26981][MLLIB] Add 'Recall_at_k' metric to RankingMetrics

## What changes were proposed in this pull request?

Add a `Recall_at_k` metric to `RankingMetrics`, exposed as `recallAt(k)` alongside the existing `precisionAt(k)` and `ndcgAt(k)`, in both the Scala/Java and Python APIs.
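
For illustration, a minimal usage sketch of the new method (not part of the patch): it assumes a live `SparkContext` named `sc`, and the two `(predictions, labels)` pairs are made-up data. Per query, recall at k is #(relevant items among the top k predictions) / #(ground truth set), and `recallAt(k)` averages this over all queries:

```scala
import org.apache.spark.mllib.evaluation.RankingMetrics

// Each element pairs a predicted ranking with the query's ground-truth items.
val predictionAndLabels = sc.parallelize(Seq(
  (Array(1, 6, 2, 7, 8), Array(1, 2, 3, 4, 5)),
  (Array(4, 1, 5), Array(1, 2, 3))))

val metrics = new RankingMetrics(predictionAndLabels)
// Query 1: top 3 = (1, 6, 2), 2 hits out of 5 relevant items -> 2/5
// Query 2: top 3 = (4, 1, 5), 1 hit out of 3 relevant items -> 1/3
println(metrics.recallAt(3)) // (2/5 + 1/3) / 2 ≈ 0.3667
```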

## How was this patch tested?

Added tests for `recallAt` to `RankingMetricsSuite`.

Closes #23881 from masa3141/SPARK-26981.

Authored-by: masa3141 <masahiro@kazama.tv>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
Authored by masa3141 on 2019-03-06 08:28:53 -06:00; committed by Sean Owen.
Commit 5fa4ba0cfb (parent 9b55722161): 5 changed files with 97 additions and 20 deletions.


@@ -99,11 +99,12 @@ public class JavaRankingMetricsExample {
     // Instantiate the metrics object
     RankingMetrics<Integer> metrics = RankingMetrics.of(relevantDocs);
 
-    // Precision and NDCG at k
+    // Precision, NDCG and Recall at k
     Integer[] kVector = {1, 3, 5};
     for (Integer k : kVector) {
       System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k));
       System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k));
+      System.out.format("Recall at %d = %f\n", k, metrics.recallAt(k));
     }
 
     // Mean average precision


@@ -89,6 +89,11 @@ object RankingMetricsExample {
       println(s"NDCG at $k = ${metrics.ndcgAt(k)}")
     }
 
+    // Recall at K
+    Array(1, 3, 5).foreach { k =>
+      println(s"Recall at $k = ${metrics.recallAt(k)}")
+    }
+
     // Get predictions for each data point
     val allPredictions = model.predict(ratings.map(r => (r.user, r.product))).map(r => ((r.user,
       r.product), r.rating))


@@ -59,23 +59,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
   def precisionAt(k: Int): Double = {
     require(k > 0, "ranking position k should be positive")
     predictionAndLabels.map { case (pred, lab) =>
-      val labSet = lab.toSet
-
-      if (labSet.nonEmpty) {
-        val n = math.min(pred.length, k)
-        var i = 0
-        var cnt = 0
-        while (i < n) {
-          if (labSet.contains(pred(i))) {
-            cnt += 1
-          }
-          i += 1
-        }
-        cnt.toDouble / k
-      } else {
-        logWarning("Empty ground truth set, check input data")
-        0.0
-      }
+      countRelevantItemRatio(pred, lab, k, k)
     }.mean()
   }
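
Note the refactor here: the counting loop previously inlined in `precisionAt` moves into a shared private helper, `countRelevantItemRatio`, added in the next hunk. Precision@k passes `k` as the denominator, while the new recall@k passes the size of the ground-truth set, so the two metrics differ only in what the hit count is divided by.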
@@ -157,6 +141,63 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
     }.mean()
   }
 
+  /**
+   * Compute the average recall of all the queries, truncated at ranking position k.
+   *
+   * If for a query, the ranking algorithm returns n results, the recall value will be
+   * computed as #(relevant items retrieved) / #(ground truth set). This formula
+   * also applies when the size of the ground truth set is less than k.
+   *
+   * If a query has an empty ground truth set, zero will be used as recall together with
+   * a log warning.
+   *
+   * See the following paper for detail:
+   *
+   * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
+   *
+   * @param k the position to compute the truncated recall, must be positive
+   * @return the average recall at the first k ranking positions
+   */
+  @Since("3.0.0")
+  def recallAt(k: Int): Double = {
+    require(k > 0, "ranking position k should be positive")
+    predictionAndLabels.map { case (pred, lab) =>
+      countRelevantItemRatio(pred, lab, k, lab.toSet.size)
+    }.mean()
+  }
+
+  /**
+   * Returns the relevant item ratio computed as #(relevant items retrieved) / denominator.
+   * If a query has an empty ground truth set, the value will be zero and a log
+   * warning is generated.
+   *
+   * @param pred predicted ranking
+   * @param lab ground truth
+   * @param k use the top k predicted ranking, must be positive
+   * @param denominator the denominator of ratio
+   * @return relevant item ratio at the first k ranking positions
+   */
+  private def countRelevantItemRatio(pred: Array[T],
+                                     lab: Array[T],
+                                     k: Int,
+                                     denominator: Int): Double = {
+    val labSet = lab.toSet
+    if (labSet.nonEmpty) {
+      val n = math.min(pred.length, k)
+      var i = 0
+      var cnt = 0
+      while (i < n) {
+        if (labSet.contains(pred(i))) {
+          cnt += 1
+        }
+        i += 1
+      }
+      cnt.toDouble / denominator
+    } else {
+      logWarning("Empty ground truth set, check input data")
+      0.0
+    }
+  }
 }
 
 object RankingMetrics {
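
As a reference for the helper's semantics, here is a minimal non-Spark sketch of the same logic as `countRelevantItemRatio`, written functionally over plain collections with hypothetical data:

```scala
// Ratio of relevant items among the top-k predictions to `denominator`;
// an empty ground-truth set yields 0.0 (the Spark version also logs a warning).
def relevantItemRatio[T](pred: Seq[T], lab: Set[T], k: Int, denominator: Int): Double =
  if (lab.isEmpty) 0.0
  else pred.take(k).count(lab.contains).toDouble / denominator

val pred = Seq(1, 6, 2, 7, 8)
val lab = Set(1, 2, 3, 4, 5)
assert(relevantItemRatio(pred, lab, 3, 3) == 2.0 / 3)        // precision@3: divide by k
assert(relevantItemRatio(pred, lab, 3, lab.size) == 2.0 / 5) // recall@3: divide by |ground truth|
assert(relevantItemRatio(pred, Set.empty[Int], 3, 3) == 0.0) // empty ground truth
```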


@@ -23,7 +23,7 @@ import org.apache.spark.mllib.util.TestingUtils._
 
 class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
 
-  test("Ranking metrics: MAP, NDCG") {
+  test("Ranking metrics: MAP, NDCG, Recall") {
     val predictionAndLabels = sc.parallelize(
       Seq(
         (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5)),

@@ -49,9 +49,17 @@ class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(metrics.ndcgAt(5) ~== 0.328788 absTol eps)
     assert(metrics.ndcgAt(10) ~== 0.487913 absTol eps)
     assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps)
+    assert(metrics.recallAt(1) ~== 1.0/15 absTol eps)
+    assert(metrics.recallAt(2) ~== 8.0/45 absTol eps)
+    assert(metrics.recallAt(3) ~== 11.0/45 absTol eps)
+    assert(metrics.recallAt(4) ~== 11.0/45 absTol eps)
+    assert(metrics.recallAt(5) ~== 16.0/45 absTol eps)
+    assert(metrics.recallAt(10) ~== 2.0/3 absTol eps)
+    assert(metrics.recallAt(15) ~== 2.0/3 absTol eps)
   }
 
-  test("MAP, NDCG with few predictions (SPARK-14886)") {
+  test("MAP, NDCG, Recall with few predictions (SPARK-14886)") {
     val predictionAndLabels = sc.parallelize(
       Seq(
         (Array(1, 6, 2), Array(1, 2, 3, 4, 5)),

@@ -64,6 +72,8 @@ class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(metrics.precisionAt(2) ~== 0.25 absTol eps)
     assert(metrics.ndcgAt(1) ~== 0.5 absTol eps)
     assert(metrics.ndcgAt(2) ~== 0.30657 absTol eps)
+    assert(metrics.recallAt(1) ~== 0.1 absTol eps)
+    assert(metrics.recallAt(2) ~== 0.1 absTol eps)
   }
 }
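
A sanity check on the expected recall values: the first test's full query set (its second and third pairs are truncated in the excerpt above) has ground-truth sets {1, 2, 3, 4, 5}, {1, 2, 3}, and the empty set, and the empty-ground-truth query always contributes zero. So, for example, recallAt(1) = (1/5 + 0 + 0) / 3 = 1/15 and recallAt(2) = (1/5 + 1/3 + 0) / 3 = 8/45. In the SPARK-14886 test, the first query's top two predictions (1, 6) contain one of its five relevant items and the remaining query contributes zero, hence recallAt(1) = recallAt(2) = (1/5 + 0) / 2 = 0.1. A quick stand-alone check of the recallAt(2) value:

```scala
// Per-query recalls at k = 2 in the first test: 1/5, 1/3, and 0.0 (empty ground truth).
val perQueryRecallAt2 = Seq(1.0 / 5, 1.0 / 3, 0.0)
assert(math.abs(perQueryRecallAt2.sum / perQueryRecallAt2.size - 8.0 / 45) < 1e-9)
```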


@@ -373,6 +373,12 @@ class RankingMetrics(JavaModelWrapper):
     0.33...
     >>> metrics.ndcgAt(10)
     0.48...
+    >>> metrics.recallAt(1)
+    0.06...
+    >>> metrics.recallAt(5)
+    0.35...
+    >>> metrics.recallAt(15)
+    0.66...
 
     .. versionadded:: 1.4.0
     """
@@ -422,6 +428,20 @@ class RankingMetrics(JavaModelWrapper):
         """
         return self.call("ndcgAt", int(k))
 
+    @since('3.0.0')
+    def recallAt(self, k):
+        """
+        Compute the average recall of all the queries, truncated at ranking position k.
+
+        If for a query, the ranking algorithm returns n results, the recall value
+        will be computed as #(relevant items retrieved) / #(ground truth set).
+        This formula also applies when the size of the ground truth set is less than k.
+
+        If a query has an empty ground truth set, zero will be used as recall together
+        with a log warning.
+        """
+        return self.call("recallAt", int(k))
+
 
 class MultilabelMetrics(JavaModelWrapper):
     """