[SPARK-9312][ML] Add RawPrediction, numClasses, and numFeatures for OneVsRestModel

add RawPrediction as output column
add numClasses and numFeatures to OneVsRestModel

## What changes were proposed in this pull request?

- Add two vals, numClasses and numFeatures, to OneVsRestModel so that it can inherit from Classifier in the future

- Add a rawPrediction output column in transform; the prediction label is calculated from the rawPrediction, like raw2prediction

## How was this patch tested?

(Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests)
(If this patch involves UI changes, please attach a screenshot; otherwise, remove this)
Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Lu WANG <lu.wang@databricks.com>

Closes #21044 from ludatabricks/SPARK-9312.
This commit is contained in:
Lu WANG 2018-04-16 11:27:30 -05:00 committed by Joseph K. Bradley
parent 083cf22356
commit 5003736ad6
2 changed files with 51 additions and 12 deletions

View file

@ -32,7 +32,7 @@ import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol}
import org.apache.spark.ml.util._
@ -55,7 +55,7 @@ private[ml] trait ClassifierTypeTrait {
/**
* Params for [[OneVsRest]].
*/
private[ml] trait OneVsRestParams extends PredictorParams
private[ml] trait OneVsRestParams extends ClassifierParams
with ClassifierTypeTrait with HasWeightCol {
/**
@ -138,6 +138,14 @@ final class OneVsRestModel private[ml] (
@Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {
require(models.nonEmpty, "OneVsRestModel requires at least one model for one class")
/** Number of classes, taken as the number of per-class models in `models`. */
@Since("2.4.0")
val numClasses: Int = models.length
/** Number of features, as reported by the first model in `models`. */
@Since("2.4.0")
val numFeatures: Int = models.head.numFeatures
/** @group setParam */
@Since("2.1.0")
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
@ -146,6 +154,10 @@ final class OneVsRestModel private[ml] (
@Since("2.1.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)
/**
 * Sets the name of the raw prediction output column.
 * When the name is the empty string, `transform` adds only the prediction column
 * and no raw prediction column.
 *
 * @group setParam
 */
@Since("2.4.0")
def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
@ -181,6 +193,7 @@ final class OneVsRestModel private[ml] (
val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) =>
predictions + ((index, prediction(1)))
}
model.setFeaturesCol($(featuresCol))
val transformedDataset = model.transform(df).select(columns: _*)
val updatedDataset = transformedDataset
@ -195,15 +208,34 @@ final class OneVsRestModel private[ml] (
newDataset.unpersist()
}
// output the index of the classifier with highest confidence as prediction
val labelUDF = udf { (predictions: Map[Int, Double]) =>
predictions.maxBy(_._2)._1.toDouble
}
if (getRawPredictionCol != "") {
val numClass = models.length
// output label and label metadata as prediction
aggregatedDataset
.withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata)
.drop(accColName)
// output the RawPrediction as vector
val rawPredictionUDF = udf { (predictions: Map[Int, Double]) =>
val predArray = Array.fill[Double](numClass)(0.0)
predictions.foreach { case (idx, value) => predArray(idx) = value }
Vectors.dense(predArray)
}
// output the index of the classifier with highest confidence as prediction
val labelUDF = udf { (rawPredictions: Vector) => rawPredictions.argmax.toDouble }
// output confidence as raw prediction, label and label metadata as prediction
aggregatedDataset
.withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName)))
.withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata)
.drop(accColName)
} else {
// output the index of the classifier with highest confidence as prediction
val labelUDF = udf { (predictions: Map[Int, Double]) =>
predictions.maxBy(_._2)._1.toDouble
}
// output label and label metadata as prediction
aggregatedDataset
.withColumn(getPredictionCol, labelUDF(col(accColName)), labelMetadata)
.drop(accColName)
}
}
@Since("1.4.1")
@ -297,6 +329,10 @@ final class OneVsRest @Since("1.4.0") (
@Since("1.5.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)
/** @group setParam */
@Since("2.4.0")
def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
/**
* The implementation of parallel one vs. rest runs the classification for
* each class in a separate thread.

View file

@ -72,11 +72,12 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
.setClassifier(new LogisticRegression)
assert(ova.getLabelCol === "label")
assert(ova.getPredictionCol === "prediction")
assert(ova.getRawPredictionCol === "rawPrediction")
val ovaModel = ova.fit(dataset)
MLTestingUtils.checkCopyAndUids(ova, ovaModel)
assert(ovaModel.models.length === numClasses)
assert(ovaModel.numClasses === numClasses)
val transformedDataset = ovaModel.transform(dataset)
// check for label metadata in prediction col
@ -179,6 +180,7 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
val dataset2 = dataset.select(col("label").as("y"), col("features").as("fea"))
ovaModel.setFeaturesCol("fea")
ovaModel.setPredictionCol("pred")
ovaModel.setRawPredictionCol("")
val transformedDataset = ovaModel.transform(dataset2)
val outputFields = transformedDataset.schema.fieldNames.toSet
assert(outputFields === Set("y", "fea", "pred"))
@ -190,7 +192,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
val ovr = new OneVsRest()
.setClassifier(logReg)
val output = ovr.fit(dataset).transform(dataset)
assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
assert(output.schema.fieldNames.toSet
=== Set("label", "features", "prediction", "rawPrediction"))
}
test("SPARK-21306: OneVsRest should support setWeightCol") {