[SPARK-26981][MLLIB] Add 'Recall_at_k' metric to RankingMetrics
## What changes were proposed in this pull request?

Add 'Recall_at_k' metric to RankingMetrics.

## How was this patch tested?

Add test to RankingMetricsSuite.

Closes #23881 from masa3141/SPARK-26981.

Authored-by: masa3141 <masahiro@kazama.tv>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
parent 9b55722161
commit 5fa4ba0cfb
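Before the diff, a minimal sketch of how the new metric sits next to the existing ones. This is illustrative only: it assumes a live `SparkContext` named `sc` and made-up data, not anything from the patch itself:

```scala
import org.apache.spark.mllib.evaluation.RankingMetrics

// One query: the top-5 ranking contains 2 of the 4 ground-truth items.
val predictionAndLabels = sc.parallelize(Seq(
  (Array(1, 6, 2, 7, 8), Array(1, 2, 3, 4))
))
val metrics = new RankingMetrics(predictionAndLabels)

println(metrics.precisionAt(5)) // 2 hits / k         = 2/5 = 0.4
println(metrics.recallAt(5))    // 2 hits / |truth|   = 2/4 = 0.5
```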
```diff
@@ -99,11 +99,12 @@ public class JavaRankingMetricsExample {
     // Instantiate the metrics object
     RankingMetrics<Integer> metrics = RankingMetrics.of(relevantDocs);

-    // Precision and NDCG at k
+    // Precision, NDCG and Recall at k
     Integer[] kVector = {1, 3, 5};
     for (Integer k : kVector) {
       System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k));
       System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k));
+      System.out.format("Recall at %d = %f\n", k, metrics.recallAt(k));
     }

     // Mean average precision
```
```diff
@@ -89,6 +89,11 @@ object RankingMetricsExample {
       println(s"NDCG at $k = ${metrics.ndcgAt(k)}")
     }

+    // Recall at K
+    Array(1, 3, 5).foreach { k =>
+      println(s"Recall at $k = ${metrics.recallAt(k)}")
+    }
+
     // Get predictions for each data point
     val allPredictions = model.predict(ratings.map(r => (r.user, r.product))).map(r => ((r.user,
       r.product), r.rating))
```
```diff
@@ -59,23 +59,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
   def precisionAt(k: Int): Double = {
     require(k > 0, "ranking position k should be positive")
     predictionAndLabels.map { case (pred, lab) =>
-      val labSet = lab.toSet
-
-      if (labSet.nonEmpty) {
-        val n = math.min(pred.length, k)
-        var i = 0
-        var cnt = 0
-        while (i < n) {
-          if (labSet.contains(pred(i))) {
-            cnt += 1
-          }
-          i += 1
-        }
-        cnt.toDouble / k
-      } else {
-        logWarning("Empty ground truth set, check input data")
-        0.0
-      }
+      countRelevantItemRatio(pred, lab, k, k)
     }.mean()
   }

@@ -157,6 +141,63 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
     }.mean()
   }

+  /**
+   * Compute the average recall of all the queries, truncated at ranking position k.
+   *
+   * If for a query, the ranking algorithm returns n results, the recall value will be
+   * computed as #(relevant items retrieved) / #(ground truth set). This formula
+   * also applies when the size of the ground truth set is less than k.
+   *
+   * If a query has an empty ground truth set, zero will be used as recall together with
+   * a log warning.
+   *
+   * See the following paper for detail:
+   *
+   * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
+   *
+   * @param k the position to compute the truncated recall, must be positive
+   * @return the average recall at the first k ranking positions
+   */
+  @Since("3.0.0")
+  def recallAt(k: Int): Double = {
+    require(k > 0, "ranking position k should be positive")
+    predictionAndLabels.map { case (pred, lab) =>
+      countRelevantItemRatio(pred, lab, k, lab.toSet.size)
+    }.mean()
+  }
+
+  /**
+   * Returns the relevant item ratio computed as #(relevant items retrieved) / denominator.
+   * If a query has an empty ground truth set, the value will be zero and a log
+   * warning is generated.
+   *
+   * @param pred predicted ranking
+   * @param lab ground truth
+   * @param k use the top k predicted ranking, must be positive
+   * @param denominator the denominator of ratio
+   * @return relevant item ratio at the first k ranking positions
+   */
+  private def countRelevantItemRatio(pred: Array[T],
+      lab: Array[T],
+      k: Int,
+      denominator: Int): Double = {
+    val labSet = lab.toSet
+    if (labSet.nonEmpty) {
+      val n = math.min(pred.length, k)
+      var i = 0
+      var cnt = 0
+      while (i < n) {
+        if (labSet.contains(pred(i))) {
+          cnt += 1
+        }
+        i += 1
+      }
+      cnt.toDouble / denominator
+    } else {
+      logWarning("Empty ground truth set, check input data")
+      0.0
+    }
+  }
 }

 object RankingMetrics {
```
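The refactor above reduces both precision@k and recall@k to the same counting loop, differing only in the denominator: `k` for precision, the ground-truth size for recall. A dependency-free sketch of that relationship (the helper name mirrors the patch, but this standalone version is mine, not the patched code):

```scala
// Standalone model of countRelevantItemRatio: count hits among the
// top-k predictions, then divide by a metric-specific denominator.
def relevantRatio(pred: Seq[Int], lab: Set[Int], k: Int, denominator: Int): Double =
  if (lab.isEmpty) 0.0 // the patched code also logs a warning here
  else pred.take(k).count(lab.contains).toDouble / denominator

val pred = Seq(1, 6, 2, 7, 8)
val lab = Set(1, 2, 3, 4)
println(relevantRatio(pred, lab, 5, 5))        // precision@5 = 2/5 = 0.4
println(relevantRatio(pred, lab, 5, lab.size)) // recall@5    = 2/4 = 0.5
```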
```diff
@@ -23,7 +23,7 @@ import org.apache.spark.mllib.util.TestingUtils._

 class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {

-  test("Ranking metrics: MAP, NDCG") {
+  test("Ranking metrics: MAP, NDCG, Recall") {
     val predictionAndLabels = sc.parallelize(
       Seq(
         (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5)),

@@ -49,9 +49,17 @@ class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(metrics.ndcgAt(5) ~== 0.328788 absTol eps)
     assert(metrics.ndcgAt(10) ~== 0.487913 absTol eps)
     assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps)
+
+    assert(metrics.recallAt(1) ~== 1.0/15 absTol eps)
+    assert(metrics.recallAt(2) ~== 8.0/45 absTol eps)
+    assert(metrics.recallAt(3) ~== 11.0/45 absTol eps)
+    assert(metrics.recallAt(4) ~== 11.0/45 absTol eps)
+    assert(metrics.recallAt(5) ~== 16.0/45 absTol eps)
+    assert(metrics.recallAt(10) ~== 2.0/3 absTol eps)
+    assert(metrics.recallAt(15) ~== 2.0/3 absTol eps)
   }

-  test("MAP, NDCG with few predictions (SPARK-14886)") {
+  test("MAP, NDCG, Recall with few predictions (SPARK-14886)") {
     val predictionAndLabels = sc.parallelize(
       Seq(
         (Array(1, 6, 2), Array(1, 2, 3, 4, 5)),

@@ -64,6 +72,8 @@ class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(metrics.precisionAt(2) ~== 0.25 absTol eps)
     assert(metrics.ndcgAt(1) ~== 0.5 absTol eps)
     assert(metrics.ndcgAt(2) ~== 0.30657 absTol eps)
+    assert(metrics.recallAt(1) ~== 0.1 absTol eps)
+    assert(metrics.recallAt(2) ~== 0.1 absTol eps)
   }

 }
```
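For anyone checking the expected values: `recallAt` averages per-query recalls, and the fractions above are consistent with the suite's three queries (only the first is visible in this hunk), with an empty ground-truth set contributing 0. For example, `recallAt(1) ~== 1.0/15` because only the first query (ground-truth size 5) has a relevant item at rank 1: (1/5 + 0 + 0) / 3 = 1/15. Likewise `recallAt(15) ~== 2.0/3` once both non-empty queries are fully recovered: (1 + 1 + 0) / 3.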
```diff
@@ -373,6 +373,12 @@ class RankingMetrics(JavaModelWrapper):
     0.33...
     >>> metrics.ndcgAt(10)
     0.48...
+    >>> metrics.recallAt(1)
+    0.06...
+    >>> metrics.recallAt(5)
+    0.35...
+    >>> metrics.recallAt(15)
+    0.66...

     .. versionadded:: 1.4.0
     """

@@ -422,6 +428,20 @@ class RankingMetrics(JavaModelWrapper):
         """
         return self.call("ndcgAt", int(k))

+    @since('3.0.0')
+    def recallAt(self, k):
+        """
+        Compute the average recall of all the queries, truncated at ranking position k.
+
+        If for a query, the ranking algorithm returns n results, the recall value
+        will be computed as #(relevant items retrieved) / #(ground truth set).
+        This formula also applies when the size of the ground truth set is less than k.
+
+        If a query has an empty ground truth set, zero will be used as recall together
+        with a log warning.
+        """
+        return self.call("recallAt", int(k))
+

 class MultilabelMetrics(JavaModelWrapper):
     """
```