[SPARK-6092] [MLLIB] Add RankingMetrics in PySpark/MLlib
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #6044 from yanboliang/spark-6092 and squashes the following commits:
726a9b1 [Yanbo Liang] add newRankingMetrics
33f649c [Yanbo Liang] Add RankingMetrics in PySpark/MLlib
(cherry picked from commit 042dda3c5c)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
parent da1be15cc6
commit 017f9fa674
@@ -32,6 +32,7 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.mllib.classification._
 import org.apache.spark.mllib.clustering._
+import org.apache.spark.mllib.evaluation.RankingMetrics
 import org.apache.spark.mllib.feature._
 import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel}
 import org.apache.spark.mllib.linalg._
@@ -50,6 +51,7 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTree
 import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
@@ -923,6 +925,14 @@ private[python] class PythonMLLibAPI extends Serializable {
     RG.gammaVectorRDD(jsc.sc, shape, scale, numRows, numCols, parts, s)
   }
 
+  /**
+   * Java stub for the constructor of Python mllib RankingMetrics
+   */
+  def newRankingMetrics(predictionAndLabels: DataFrame): RankingMetrics[Any] = {
+    new RankingMetrics(predictionAndLabels.map(
+      r => (r.getSeq(0).toArray[Any], r.getSeq(1).toArray[Any])))
+  }
+
 }
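For reference, the stub above is the JVM entry point behind the new Python class: it receives a two-column DataFrame and unpacks each row back into a (predicted ranking, ground truth) pair for the Scala RankingMetrics. A minimal sketch of the Python-side handoff it expects (illustrative only; the real wiring is in the evaluation.py diff below):

# Sketch of the DataFrame handed to newRankingMetrics (illustrative names;
# it mirrors the RankingMetrics.__init__ shown in the diff below).
from pyspark import SparkContext
from pyspark.sql import SQLContext

sc = SparkContext("local", "ranking-handoff-sketch")
predictionAndLabels = sc.parallelize([
    ([1.0, 2.0, 3.0], [1.0, 3.0]),
    ([4.0, 5.0], [5.0])])
sql_ctx = SQLContext(sc)
# Each (list, list) pair becomes a Row with two array columns; on the JVM
# side the stub reads them back with r.getSeq(0) and r.getSeq(1).
df = sql_ctx.createDataFrame(predictionAndLabels,
                             schema=sql_ctx._inferSchema(predictionAndLabels))
print(df.dtypes)  # e.g. [('_1', 'array<double>'), ('_2', 'array<double>')]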
@@ -15,9 +15,12 @@
 # limitations under the License.
 #
 
-from pyspark.mllib.common import JavaModelWrapper
+from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
 from pyspark.sql import SQLContext
-from pyspark.sql.types import StructField, StructType, DoubleType
+from pyspark.sql.types import StructField, StructType, DoubleType, IntegerType, ArrayType
 
 __all__ = ['BinaryClassificationMetrics', 'RegressionMetrics',
-           'MulticlassMetrics']
+           'MulticlassMetrics', 'RankingMetrics']
 
 
 class BinaryClassificationMetrics(JavaModelWrapper):
@@ -270,6 +273,77 @@ class MulticlassMetrics(JavaModelWrapper):
         return self.call("weightedFMeasure", beta)
 
 
+class RankingMetrics(JavaModelWrapper):
+    """
+    Evaluator for ranking algorithms.
+
+    >>> predictionAndLabels = sc.parallelize([
+    ...     ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]),
+    ...     ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]),
+    ...     ([1, 2, 3, 4, 5], [])])
+    >>> metrics = RankingMetrics(predictionAndLabels)
+    >>> metrics.precisionAt(1)
+    0.33...
+    >>> metrics.precisionAt(5)
+    0.26...
+    >>> metrics.precisionAt(15)
+    0.17...
+    >>> metrics.meanAveragePrecision
+    0.35...
+    >>> metrics.ndcgAt(3)
+    0.33...
+    >>> metrics.ndcgAt(10)
+    0.48...
+    """
+
+    def __init__(self, predictionAndLabels):
+        """
+        :param predictionAndLabels: an RDD of (predicted ranking, ground truth set) pairs.
+        """
+        sc = predictionAndLabels.ctx
+        sql_ctx = SQLContext(sc)
+        df = sql_ctx.createDataFrame(predictionAndLabels,
+                                     schema=sql_ctx._inferSchema(predictionAndLabels))
+        java_model = callMLlibFunc("newRankingMetrics", df._jdf)
+        super(RankingMetrics, self).__init__(java_model)
+
+    def precisionAt(self, k):
+        """
+        Compute the average precision of all the queries, truncated at ranking position k.
+
+        If for a query, the ranking algorithm returns n (n < k) results, the precision value
+        will be computed as #(relevant items retrieved) / k. This formula also applies when
+        the size of the ground truth set is less than k.
+
+        If a query has an empty ground truth set, zero will be used as precision together
+        with a log warning.
+        """
+        return self.call("precisionAt", int(k))
+
+    @property
+    def meanAveragePrecision(self):
+        """
+        Returns the mean average precision (MAP) of all the queries.
+        If a query has an empty ground truth set, the average precision will be zero and
+        a log warning is generated.
+        """
+        return self.call("meanAveragePrecision")
+
+    def ndcgAt(self, k):
+        """
+        Compute the average NDCG value of all the queries, truncated at ranking position k.
+        The discounted cumulative gain at position k is computed as
+        sum_{i=1}^k (2^{relevance of i-th item} - 1) / log(i + 1),
+        and the NDCG is obtained by dividing the DCG value by the DCG on the ground
+        truth set. In the current implementation, the relevance value is binary.
+
+        If a query has an empty ground truth set, zero will be used as NDCG together with
+        a log warning.
+        """
+        return self.call("ndcgAt", int(k))
+
+
 def _test():
     import doctest
     from pyspark import SparkContext
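The docstring formulas above are easy to check by hand. A plain-Python sketch of precisionAt and meanAveragePrecision as the docstrings state them (illustrative, not part of the commit); run on the doctest data, it reproduces the expected values:

# Plain-Python versions of the docstring formulas (illustrative only).
data = [([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]),
        ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]),
        ([1, 2, 3, 4, 5], [])]

def precision_at(pred, labels, k):
    # #(relevant items in the first k predictions) / k, even when the
    # ranking or the ground truth set has fewer than k items
    return sum(1 for p in pred[:k] if p in labels) / float(k)

def average_precision(pred, labels):
    if not labels:
        return 0.0  # empty ground truth set contributes zero
    hits, total = 0, 0.0
    for i, p in enumerate(pred):
        if p in labels:
            hits += 1
            total += hits / float(i + 1)  # precision at each hit position
    return total / len(labels)

mean = lambda xs: sum(xs) / float(len(xs))
print(mean([precision_at(p, l, 5) for p, l in data]))    # 0.266...
print(mean([average_precision(p, l) for p, l in data]))  # 0.355...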
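Likewise for ndcgAt, with binary relevance as the docstring says; the base-2 logarithm in the discount is an assumption here, chosen because it reproduces the doctest values:

# Binary-relevance NDCG@k matching the docstring formula; base-2 log
# discount is an assumption consistent with the doctest values above.
from math import log

def ndcg_at(pred, labels, k):
    if not labels:
        return 0.0  # empty ground truth set contributes zero
    # DCG@k over the predicted ranking: the discount for 1-based
    # position i is log2(i + 1), hence log(i + 2, 2) for 0-based i
    dcg = sum(1.0 / log(i + 2, 2) for i, p in enumerate(pred[:k]) if p in labels)
    # ideal DCG: the whole ground truth set ranked first
    idcg = sum(1.0 / log(i + 2, 2) for i in range(min(k, len(labels))))
    return dcg / idcg

data = [([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]),
        ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]),
        ([1, 2, 3, 4, 5], [])]
print(sum(ndcg_at(p, l, 10) for p, l in data) / len(data))  # 0.487...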