[SPARK-9312][ML] Add RawPrediction, numClasses, and numFeatures for OneVsRestModel
Add RawPrediction as an output column; add numClasses and numFeatures to OneVsRestModel. ## What changes were proposed in this pull request? - Add two vals, numClasses and numFeatures, to OneVsRestModel so that it can inherit from Classifier in the future - Add a rawPrediction output column in transform; the prediction label is calculated from the rawPrediction, as in raw2prediction ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG <lu.wang@databricks.com> Closes #21044 from ludatabricks/SPARK-9312.
This commit is contained in:
parent
083cf22356
commit
5003736ad6
|
@ -32,7 +32,7 @@ import org.apache.spark.SparkContext
|
|||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.ml._
|
||||
import org.apache.spark.ml.attribute._
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
|
||||
import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol}
|
||||
import org.apache.spark.ml.util._
|
||||
|
@ -55,7 +55,7 @@ private[ml] trait ClassifierTypeTrait {
|
|||
/**
|
||||
* Params for [[OneVsRest]].
|
||||
*/
|
||||
private[ml] trait OneVsRestParams extends PredictorParams
|
||||
private[ml] trait OneVsRestParams extends ClassifierParams
|
||||
with ClassifierTypeTrait with HasWeightCol {
|
||||
|
||||
/**
|
||||
|
@ -138,6 +138,14 @@ final class OneVsRestModel private[ml] (
|
|||
@Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
|
||||
extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {
|
||||
|
||||
require(models.nonEmpty, "OneVsRestModel requires at least one model for one class")
|
||||
|
||||
@Since("2.4.0")
|
||||
val numClasses: Int = models.length
|
||||
|
||||
@Since("2.4.0")
|
||||
val numFeatures: Int = models.head.numFeatures
|
||||
|
||||
/** @group setParam */
|
||||
@Since("2.1.0")
|
||||
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
|
||||
|
@ -146,6 +154,10 @@ final class OneVsRestModel private[ml] (
|
|||
@Since("2.1.0")
|
||||
def setPredictionCol(value: String): this.type = set(predictionCol, value)
|
||||
|
||||
/** @group setParam */
|
||||
@Since("2.4.0")
|
||||
def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
|
||||
|
||||
@Since("1.4.0")
|
||||
override def transformSchema(schema: StructType): StructType = {
|
||||
validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
|
||||
|
@ -181,6 +193,7 @@ final class OneVsRestModel private[ml] (
|
|||
val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) =>
|
||||
predictions + ((index, prediction(1)))
|
||||
}
|
||||
|
||||
model.setFeaturesCol($(featuresCol))
|
||||
val transformedDataset = model.transform(df).select(columns: _*)
|
||||
val updatedDataset = transformedDataset
|
||||
|
@ -195,15 +208,34 @@ final class OneVsRestModel private[ml] (
|
|||
newDataset.unpersist()
|
||||
}
|
||||
|
||||
// output the index of the classifier with highest confidence as prediction
|
||||
val labelUDF = udf { (predictions: Map[Int, Double]) =>
|
||||
predictions.maxBy(_._2)._1.toDouble
|
||||
}
|
||||
if (getRawPredictionCol != "") {
|
||||
val numClass = models.length
|
||||
|
||||
// output label and label metadata as prediction
|
||||
aggregatedDataset
|
||||
.withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata)
|
||||
.drop(accColName)
|
||||
// output the RawPrediction as vector
|
||||
val rawPredictionUDF = udf { (predictions: Map[Int, Double]) =>
|
||||
val predArray = Array.fill[Double](numClass)(0.0)
|
||||
predictions.foreach { case (idx, value) => predArray(idx) = value }
|
||||
Vectors.dense(predArray)
|
||||
}
|
||||
|
||||
// output the index of the classifier with highest confidence as prediction
|
||||
val labelUDF = udf { (rawPredictions: Vector) => rawPredictions.argmax.toDouble }
|
||||
|
||||
// output confidence as raw prediction, label and label metadata as prediction
|
||||
aggregatedDataset
|
||||
.withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName)))
|
||||
.withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata)
|
||||
.drop(accColName)
|
||||
} else {
|
||||
// output the index of the classifier with highest confidence as prediction
|
||||
val labelUDF = udf { (predictions: Map[Int, Double]) =>
|
||||
predictions.maxBy(_._2)._1.toDouble
|
||||
}
|
||||
// output label and label metadata as prediction
|
||||
aggregatedDataset
|
||||
.withColumn(getPredictionCol, labelUDF(col(accColName)), labelMetadata)
|
||||
.drop(accColName)
|
||||
}
|
||||
}
|
||||
|
||||
@Since("1.4.1")
|
||||
|
@ -297,6 +329,10 @@ final class OneVsRest @Since("1.4.0") (
|
|||
@Since("1.5.0")
|
||||
def setPredictionCol(value: String): this.type = set(predictionCol, value)
|
||||
|
||||
/** @group setParam */
|
||||
@Since("2.4.0")
|
||||
def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
|
||||
|
||||
/**
|
||||
* The implementation of parallel one vs. rest runs the classification for
|
||||
* each class in a separate thread.
|
||||
|
|
|
@ -72,11 +72,12 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
|
|||
.setClassifier(new LogisticRegression)
|
||||
assert(ova.getLabelCol === "label")
|
||||
assert(ova.getPredictionCol === "prediction")
|
||||
assert(ova.getRawPredictionCol === "rawPrediction")
|
||||
val ovaModel = ova.fit(dataset)
|
||||
|
||||
MLTestingUtils.checkCopyAndUids(ova, ovaModel)
|
||||
|
||||
assert(ovaModel.models.length === numClasses)
|
||||
assert(ovaModel.numClasses === numClasses)
|
||||
val transformedDataset = ovaModel.transform(dataset)
|
||||
|
||||
// check for label metadata in prediction col
|
||||
|
@ -179,6 +180,7 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
|
|||
val dataset2 = dataset.select(col("label").as("y"), col("features").as("fea"))
|
||||
ovaModel.setFeaturesCol("fea")
|
||||
ovaModel.setPredictionCol("pred")
|
||||
ovaModel.setRawPredictionCol("")
|
||||
val transformedDataset = ovaModel.transform(dataset2)
|
||||
val outputFields = transformedDataset.schema.fieldNames.toSet
|
||||
assert(outputFields === Set("y", "fea", "pred"))
|
||||
|
@ -190,7 +192,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
|
|||
val ovr = new OneVsRest()
|
||||
.setClassifier(logReg)
|
||||
val output = ovr.fit(dataset).transform(dataset)
|
||||
assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
|
||||
assert(output.schema.fieldNames.toSet
|
||||
=== Set("label", "features", "prediction", "rawPrediction"))
|
||||
}
|
||||
|
||||
test("SPARK-21306: OneVsRest should support setWeightCol") {
|
||||
|
|
Loading…
Reference in a new issue