[SPARK-1225, 1241] [MLLIB] Add AreaUnderCurve and BinaryClassificationMetrics
This PR implements a generic version of `AreaUnderCurve` using the `RDD.sliding` implementation from https://github.com/apache/spark/pull/136 . It also contains refactoring of https://github.com/apache/spark/pull/160 for binary classification evaluation.

Author: Xiangrui Meng <meng@databricks.com>

Closes #364 from mengxr/auc and squashes the following commits:

a05941d [Xiangrui Meng] replace TP/FP/TN/FN by their full names
3f42e98 [Xiangrui Meng] add (0, 0), (1, 1) to roc, and (0, 1) to pr
fb4b6d2 [Xiangrui Meng] rename Evaluator to Metrics and add more metrics
b1b7dab [Xiangrui Meng] fix code styles
9dc3518 [Xiangrui Meng] add tests for BinaryClassificationEvaluator
ca31da5 [Xiangrui Meng] remove PredictionAndResponse
3d71525 [Xiangrui Meng] move binary evalution classes to evaluation.binary
8f78958 [Xiangrui Meng] add PredictionAndResponse
dda82d5 [Xiangrui Meng] add confusion matrix
aa7e278 [Xiangrui Meng] add initial version of binary classification evaluator
221ebce [Xiangrui Meng] add a new test to sliding
a920865 [Xiangrui Meng] Merge branch 'sliding' into auc
a9b250a [Xiangrui Meng] move sliding to mllib
cab9a52 [Xiangrui Meng] use last for the last element
db6cb30 [Xiangrui Meng] remove unnecessary toSeq
9916202 [Xiangrui Meng] change RDD.sliding return type to RDD[Seq[T]]
284d991 [Xiangrui Meng] change SlidedRDD to SlidingRDD
c1c6c22 [Xiangrui Meng] add AreaUnderCurve
65461b2 [Xiangrui Meng] Merge branch 'sliding' into auc
5ee6001 [Xiangrui Meng] add TODO
d2a600d [Xiangrui Meng] add sliding to rdd
This commit is contained in:
parent 98225a6eff
commit f5ace8da34
mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala
@@ -0,0 +1,62 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.evaluation

import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.rdd.RDDFunctions._

/**
 * Computes the area under the curve (AUC) using the trapezoidal rule.
 */
private[evaluation] object AreaUnderCurve {

  /**
   * Uses the trapezoidal rule to compute the area under the line connecting the two input points.
   * @param points two 2D points stored in a Seq
   */
  private def trapezoid(points: Seq[(Double, Double)]): Double = {
    require(points.length == 2)
    val x = points.head
    val y = points.last
    (y._1 - x._1) * (y._2 + x._2) / 2.0
  }

  /**
   * Returns the area under the given curve.
   *
   * @param curve an RDD of ordered 2D points stored in pairs representing a curve
   */
  def of(curve: RDD[(Double, Double)]): Double = {
    curve.sliding(2).aggregate(0.0)(
      seqOp = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points),
      combOp = _ + _
    )
  }

  /**
   * Returns the area under the given curve.
   *
   * @param curve an iterator over ordered 2D points stored in pairs representing a curve
   */
  def of(curve: Iterable[(Double, Double)]): Double = {
    curve.toIterator.sliding(2).withPartial(false).aggregate(0.0)(
      seqop = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points),
      combop = _ + _
    )
  }
}
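For illustration, a minimal sketch of how the two overloads line up (not part of the diff; assumes a live SparkContext named sc):

// Trapezoidal rule by hand: (1-0)*(0+1)/2 + (2-1)*(1+3)/2 = 0.5 + 2.0 = 2.5.
val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0))
val localAuc = AreaUnderCurve.of(curve)                   // Iterable overload
val rddAuc = AreaUnderCurve.of(sc.parallelize(curve, 2))  // RDD overload
assert(localAuc == 2.5 && rddAuc == 2.5)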
mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala
@@ -0,0 +1,57 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.evaluation.binary

/**
 * Trait for a binary classification evaluation metric computer.
 */
private[evaluation] trait BinaryClassificationMetricComputer extends Serializable {
  def apply(c: BinaryConfusionMatrix): Double
}

/** Precision. */
private[evaluation] object Precision extends BinaryClassificationMetricComputer {
  override def apply(c: BinaryConfusionMatrix): Double =
    c.numTruePositives.toDouble / (c.numTruePositives + c.numFalsePositives)
}

/** False positive rate. */
private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer {
  override def apply(c: BinaryConfusionMatrix): Double =
    c.numFalsePositives.toDouble / c.numNegatives
}

/** Recall. */
private[evaluation] object Recall extends BinaryClassificationMetricComputer {
  override def apply(c: BinaryConfusionMatrix): Double =
    c.numTruePositives.toDouble / c.numPositives
}

/**
 * F-Measure.
 * @param beta the beta constant in F-Measure
 * @see http://en.wikipedia.org/wiki/F1_score
 */
private[evaluation] case class FMeasure(beta: Double) extends BinaryClassificationMetricComputer {
  private val beta2 = beta * beta
  override def apply(c: BinaryConfusionMatrix): Double = {
    val precision = Precision(c)
    val recall = Recall(c)
    (1.0 + beta2) * (precision * recall) / (beta2 * precision + recall)
  }
}
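As a quick sanity check of how these computers compose, a sketch (not part of the diff; since BinaryConfusionMatrix is private[evaluation], this would live inside the evaluation.binary package, e.g. in a test):

// Hypothetical matrix: 3 true positives, 1 false positive, 1 false negative, 5 true negatives.
val c = new BinaryConfusionMatrix {
  override def numTruePositives: Long = 3L
  override def numFalsePositives: Long = 1L
  override def numFalseNegatives: Long = 1L
  override def numTrueNegatives: Long = 5L
}
assert(Precision(c) == 0.75)      // 3 / (3 + 1)
assert(Recall(c) == 0.75)         // 3 / 4, since numPositives = TP + FN = 4
assert(FMeasure(1.0)(c) == 0.75)  // harmonic mean of equal precision and recall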
mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala
@@ -0,0 +1,204 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.evaluation.binary

import org.apache.spark.rdd.{UnionRDD, RDD}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.evaluation.AreaUnderCurve
import org.apache.spark.Logging

/**
 * Implementation of [[org.apache.spark.mllib.evaluation.binary.BinaryConfusionMatrix]].
 *
 * @param count label counter for labels with scores greater than or equal to the current score
 * @param totalCount label counter for all labels
 */
private case class BinaryConfusionMatrixImpl(
    count: LabelCounter,
    totalCount: LabelCounter) extends BinaryConfusionMatrix with Serializable {

  /** number of true positives */
  override def numTruePositives: Long = count.numPositives

  /** number of false positives */
  override def numFalsePositives: Long = count.numNegatives

  /** number of false negatives */
  override def numFalseNegatives: Long = totalCount.numPositives - count.numPositives

  /** number of true negatives */
  override def numTrueNegatives: Long = totalCount.numNegatives - count.numNegatives

  /** number of positives */
  override def numPositives: Long = totalCount.numPositives

  /** number of negatives */
  override def numNegatives: Long = totalCount.numNegatives
}

/**
 * Evaluator for binary classification.
 *
 * @param scoreAndLabels an RDD of (score, label) pairs.
 */
class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)])
  extends Serializable with Logging {

  private lazy val (
    cumulativeCounts: RDD[(Double, LabelCounter)],
    confusions: RDD[(Double, BinaryConfusionMatrix)]) = {
    // Create a bin for each distinct score value, count positives and negatives within each bin,
    // and then sort by score values in descending order.
    val counts = scoreAndLabels.combineByKey(
      createCombiner = (label: Double) => new LabelCounter(0L, 0L) += label,
      mergeValue = (c: LabelCounter, label: Double) => c += label,
      mergeCombiners = (c1: LabelCounter, c2: LabelCounter) => c1 += c2
    ).sortByKey(ascending = false)
    val agg = counts.values.mapPartitions({ iter =>
      val agg = new LabelCounter()
      iter.foreach(agg += _)
      Iterator(agg)
    }, preservesPartitioning = true).collect()
    val partitionwiseCumulativeCounts =
      agg.scanLeft(new LabelCounter())((agg: LabelCounter, c: LabelCounter) => agg.clone() += c)
    val totalCount = partitionwiseCumulativeCounts.last
    logInfo(s"Total counts: $totalCount")
    val cumulativeCounts = counts.mapPartitionsWithIndex(
      (index: Int, iter: Iterator[(Double, LabelCounter)]) => {
        val cumCount = partitionwiseCumulativeCounts(index)
        iter.map { case (score, c) =>
          cumCount += c
          (score, cumCount.clone())
        }
      }, preservesPartitioning = true)
    cumulativeCounts.persist()
    val confusions = cumulativeCounts.map { case (score, cumCount) =>
      (score, BinaryConfusionMatrixImpl(cumCount, totalCount).asInstanceOf[BinaryConfusionMatrix])
    }
    (cumulativeCounts, confusions)
  }

  /** Unpersist intermediate RDDs used in the computation. */
  def unpersist() {
    cumulativeCounts.unpersist()
  }

  /** Returns thresholds in descending order. */
  def thresholds(): RDD[Double] = cumulativeCounts.map(_._1)

  /**
   * Returns the receiver operating characteristic (ROC) curve,
   * which is an RDD of (false positive rate, true positive rate)
   * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
   * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic
   */
  def roc(): RDD[(Double, Double)] = {
    val rocCurve = createCurve(FalsePositiveRate, Recall)
    val sc = confusions.context
    val first = sc.makeRDD(Seq((0.0, 0.0)), 1)
    val last = sc.makeRDD(Seq((1.0, 1.0)), 1)
    new UnionRDD[(Double, Double)](sc, Seq(first, rocCurve, last))
  }

  /**
   * Computes the area under the receiver operating characteristic (ROC) curve.
   */
  def areaUnderROC(): Double = AreaUnderCurve.of(roc())

  /**
   * Returns the precision-recall curve, which is an RDD of (recall, precision),
   * NOT (precision, recall), with (0.0, 1.0) prepended to it.
   * @see http://en.wikipedia.org/wiki/Precision_and_recall
   */
  def pr(): RDD[(Double, Double)] = {
    val prCurve = createCurve(Recall, Precision)
    val sc = confusions.context
    val first = sc.makeRDD(Seq((0.0, 1.0)), 1)
    first.union(prCurve)
  }

  /**
   * Computes the area under the precision-recall curve.
   */
  def areaUnderPR(): Double = AreaUnderCurve.of(pr())

  /**
   * Returns the (threshold, F-Measure) curve.
   * @param beta the beta factor in F-Measure computation.
   * @return an RDD of (threshold, F-Measure) pairs.
   * @see http://en.wikipedia.org/wiki/F1_score
   */
  def fMeasureByThreshold(beta: Double): RDD[(Double, Double)] = createCurve(FMeasure(beta))

  /** Returns the (threshold, F-Measure) curve with beta = 1.0. */
  def fMeasureByThreshold(): RDD[(Double, Double)] = fMeasureByThreshold(1.0)

  /** Returns the (threshold, precision) curve. */
  def precisionByThreshold(): RDD[(Double, Double)] = createCurve(Precision)

  /** Returns the (threshold, recall) curve. */
  def recallByThreshold(): RDD[(Double, Double)] = createCurve(Recall)

  /** Creates a curve of (threshold, metric). */
  private def createCurve(y: BinaryClassificationMetricComputer): RDD[(Double, Double)] = {
    confusions.map { case (s, c) =>
      (s, y(c))
    }
  }

  /** Creates a curve of (metricX, metricY). */
  private def createCurve(
      x: BinaryClassificationMetricComputer,
      y: BinaryClassificationMetricComputer): RDD[(Double, Double)] = {
    confusions.map { case (_, c) =>
      (x(c), y(c))
    }
  }
}

/**
 * A counter for positives and negatives.
 *
 * @param numPositives number of positive labels
 * @param numNegatives number of negative labels
 */
private class LabelCounter(
    var numPositives: Long = 0L,
    var numNegatives: Long = 0L) extends Serializable {

  /** Processes a label. */
  def +=(label: Double): LabelCounter = {
    // Though we assume 1.0 for positive and 0.0 for negative, the following check will handle
    // -1.0 for negative as well.
    if (label > 0.5) numPositives += 1L else numNegatives += 1L
    this
  }

  /** Merges another counter. */
  def +=(other: LabelCounter): LabelCounter = {
    numPositives += other.numPositives
    numNegatives += other.numNegatives
    this
  }

  override def clone: LabelCounter = {
    new LabelCounter(numPositives, numNegatives)
  }

  override def toString: String = s"{numPos: $numPositives, numNeg: $numNegatives}"
}
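For context, a hedged end-to-end sketch of the public API above (not part of the diff; assumes a live SparkContext named sc and illustrative scores):

val scoreAndLabels = sc.parallelize(
  Seq((0.9, 1.0), (0.7, 0.0), (0.4, 1.0), (0.1, 0.0)), 2)
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
println(s"AUC-ROC = ${metrics.areaUnderROC()}")
println(s"AUC-PR  = ${metrics.areaUnderPR()}")
val rocPoints = metrics.roc().collect()  // (FPR, TPR) pairs, endpoints included
metrics.unpersist()  // release the cached cumulative counts when done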
mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala
@@ -0,0 +1,41 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.evaluation.binary

/**
 * Trait for a binary confusion matrix.
 */
private[evaluation] trait BinaryConfusionMatrix {
  /** number of true positives */
  def numTruePositives: Long

  /** number of false positives */
  def numFalsePositives: Long

  /** number of false negatives */
  def numFalseNegatives: Long

  /** number of true negatives */
  def numTrueNegatives: Long

  /** number of positives */
  def numPositives: Long = numTruePositives + numFalseNegatives

  /** number of negatives */
  def numNegatives: Long = numFalsePositives + numTrueNegatives
}
mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
@@ -0,0 +1,53 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.rdd

import scala.reflect.ClassTag

import org.apache.spark.rdd.RDD

/**
 * Machine learning specific RDD functions.
 */
private[mllib]
class RDDFunctions[T: ClassTag](self: RDD[T]) {

  /**
   * Returns an RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding
   * window over them. The ordering is first based on the partition index and then the ordering of
   * items within each partition. This is similar to sliding in Scala collections, except that it
   * becomes an empty RDD if the window size is greater than the total number of items. It needs to
   * trigger a Spark job if the parent RDD has more than one partition and the window size is
   * greater than 1.
   */
  def sliding(windowSize: Int): RDD[Seq[T]] = {
    require(windowSize > 0, s"Sliding window size must be positive, but got $windowSize.")
    if (windowSize == 1) {
      self.map(Seq(_))
    } else {
      new SlidingRDD[T](self, windowSize)
    }
  }
}

private[mllib]
object RDDFunctions {

  /** Implicit conversion from an RDD to RDDFunctions. */
  implicit def fromRDD[T: ClassTag](rdd: RDD[T]): RDDFunctions[T] = new RDDFunctions[T](rdd)
}
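A minimal sketch of the implicit syntax this enables (not part of the diff; assumes a live SparkContext named sc):

import org.apache.spark.mllib.rdd.RDDFunctions._
val windows = sc.parallelize(1 to 5, 2).sliding(2).collect()
// windows: Array(Seq(1, 2), Seq(2, 3), Seq(3, 4), Seq(4, 5)), regardless of partitioning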
mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
@@ -0,0 +1,104 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.rdd

import scala.collection.mutable
import scala.reflect.ClassTag

import org.apache.spark.{TaskContext, Partition}
import org.apache.spark.rdd.RDD

private[mllib]
class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T])
  extends Partition with Serializable {
  override val index: Int = idx
}

/**
 * Represents an RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding
 * window over them. The ordering is first based on the partition index and then the ordering of
 * items within each partition. This is similar to sliding in Scala collections, except that it
 * becomes an empty RDD if the window size is greater than the total number of items. It needs to
 * trigger a Spark job if the parent RDD has more than one partition. To make this operation
 * efficient, the number of items per partition should be larger than the window size and the
 * window size should be small, e.g., 2.
 *
 * @param parent the parent RDD
 * @param windowSize the window size, must be greater than 1
 *
 * @see [[org.apache.spark.mllib.rdd.RDDFunctions#sliding]]
 */
private[mllib]
class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int)
  extends RDD[Seq[T]](parent) {

  require(windowSize > 1, s"Window size must be greater than 1, but got $windowSize.")

  override def compute(split: Partition, context: TaskContext): Iterator[Seq[T]] = {
    val part = split.asInstanceOf[SlidingRDDPartition[T]]
    (firstParent[T].iterator(part.prev, context) ++ part.tail)
      .sliding(windowSize)
      .withPartial(false)
  }

  override def getPreferredLocations(split: Partition): Seq[String] =
    firstParent[T].preferredLocations(split.asInstanceOf[SlidingRDDPartition[T]].prev)

  override def getPartitions: Array[Partition] = {
    val parentPartitions = parent.partitions
    val n = parentPartitions.size
    if (n == 0) {
      Array.empty
    } else if (n == 1) {
      Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty))
    } else {
      val n1 = n - 1
      val w1 = windowSize - 1
      // Get the first w1 items of each partition, starting from the second partition.
      val nextHeads =
        parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n, true)
      val partitions = mutable.ArrayBuffer[SlidingRDDPartition[T]]()
      var i = 0
      var partitionIndex = 0
      while (i < n1) {
        var j = i
        val tail = mutable.ListBuffer[T]()
        // Keep appending to the current tail until a full head of size w1 has been appended.
        while (j < n1 && nextHeads(j).size < w1) {
          tail ++= nextHeads(j)
          j += 1
        }
        if (j < n1) {
          tail ++= nextHeads(j)
          j += 1
        }
        partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail)
        partitionIndex += 1
        // Skip appended heads.
        i = j
      }
      // If the head of the last partition has size w1, we also need to add this partition.
      if (nextHeads.last.size == w1) {
        partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(n1), Seq.empty)
      }
      partitions.toArray
    }
  }

  // TODO: Override methods such as aggregate, which only requires one Spark job.
}
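To see how getPartitions stitches windows across partition boundaries, a hypothetical walk-through (not part of the diff; since SlidingRDD is private[mllib], such a check would live in a test under org.apache.spark.mllib.rdd, and the partition layout below is illustrative):

// Suppose parent partitions hold [1, 2] | [3] | [4, 5] and windowSize = 2, so w1 = 1.
// Then nextHeads = [[3], [4]]: partition 0 gets tail [3], partition 1 gets tail [4],
// and the last partition gets an empty tail, so no window breaks at a boundary.
val slid = new SlidingRDD[Int](sc.parallelize(1 to 5, 3), 2)
assert(slid.collect().map(_.toList).toList ==
  List(List(1, 2), List(2, 3), List(3, 4), List(4, 5)))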
mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
@@ -0,0 +1,46 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.evaluation

import org.scalatest.FunSuite

import org.apache.spark.mllib.util.LocalSparkContext

class AreaUnderCurveSuite extends FunSuite with LocalSparkContext {
  test("auc computation") {
    val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))
    val auc = 4.0
    assert(AreaUnderCurve.of(curve) === auc)
    val rddCurve = sc.parallelize(curve, 2)
    assert(AreaUnderCurve.of(rddCurve) === auc)
  }

  test("auc of an empty curve") {
    val curve = Seq.empty[(Double, Double)]
    assert(AreaUnderCurve.of(curve) === 0.0)
    val rddCurve = sc.parallelize(curve, 2)
    assert(AreaUnderCurve.of(rddCurve) === 0.0)
  }

  test("auc of a curve with a single point") {
    val curve = Seq((1.0, 1.0))
    assert(AreaUnderCurve.of(curve) === 0.0)
    val rddCurve = sc.parallelize(curve, 2)
    assert(AreaUnderCurve.of(rddCurve) === 0.0)
  }
}
mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala
@@ -0,0 +1,55 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.evaluation.binary

import org.scalatest.FunSuite

import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.evaluation.AreaUnderCurve

class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
  test("binary evaluation metrics") {
    val scoreAndLabels = sc.parallelize(
      Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2)
    val metrics = new BinaryClassificationMetrics(scoreAndLabels)
    val threshold = Seq(0.8, 0.6, 0.4, 0.1)
    val numTruePositives = Seq(1, 3, 3, 4)
    val numFalsePositives = Seq(0, 1, 2, 3)
    val numPositives = 4
    val numNegatives = 3
    val precision = numTruePositives.zip(numFalsePositives).map { case (t, f) =>
      t.toDouble / (t + f)
    }
    val recall = numTruePositives.map(t => t.toDouble / numPositives)
    val fpr = numFalsePositives.map(f => f.toDouble / numNegatives)
    val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0))
    val pr = recall.zip(precision)
    val prCurve = Seq((0.0, 1.0)) ++ pr
    val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r) }
    val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r) }
    assert(metrics.thresholds().collect().toSeq === threshold)
    assert(metrics.roc().collect().toSeq === rocCurve)
    assert(metrics.areaUnderROC() === AreaUnderCurve.of(rocCurve))
    assert(metrics.pr().collect().toSeq === prCurve)
    assert(metrics.areaUnderPR() === AreaUnderCurve.of(prCurve))
    assert(metrics.fMeasureByThreshold().collect().toSeq === threshold.zip(f1))
    assert(metrics.fMeasureByThreshold(2.0).collect().toSeq === threshold.zip(f2))
    assert(metrics.precisionByThreshold().collect().toSeq === threshold.zip(precision))
    assert(metrics.recallByThreshold().collect().toSeq === threshold.zip(recall))
  }
}
mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
@@ -0,0 +1,49 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.rdd

import org.scalatest.FunSuite

import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.rdd.RDDFunctions._

class RDDFunctionsSuite extends FunSuite with LocalSparkContext {

  test("sliding") {
    val data = 0 until 6
    for (numPartitions <- 1 to 8) {
      val rdd = sc.parallelize(data, numPartitions)
      for (windowSize <- 1 to 6) {
        val sliding = rdd.sliding(windowSize).collect().map(_.toList).toList
        val expected = data.sliding(windowSize).map(_.toList).toList
        assert(sliding === expected)
      }
      assert(rdd.sliding(7).collect().isEmpty,
        "Should return an empty RDD if the window size is greater than the number of items.")
    }
  }

  test("sliding with empty partitions") {
    val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7))
    val rdd = sc.parallelize(data, data.length).flatMap(s => s)
    assert(rdd.partitions.size === data.length)
    val sliding = rdd.sliding(3)
    val expected = data.flatMap(x => x).sliding(3).toList
    assert(sliding.collect().toList === expected)
  }
}