[SPARK-7015] [MLLIB] [WIP] Multiclass to Binary Reduction: One Against All
Initial cut of one-against-all reduction. The test code is scaffolding and not fully implemented; this WIP is posted to gather early feedback.

Author: Ram Sriharsha <rsriharsha@hw11853.local>

Closes #5830 from harsha2010/reduction and squashes the following commits:

5f4b495 [Ram Sriharsha] Fix Test
386e98b [Ram Sriharsha] Style fix
49b4a17 [Ram Sriharsha] Simplify the test
02279cc [Ram Sriharsha] Output Label Metadata in Prediction Col
bc78032 [Ram Sriharsha] Code Review Updates
8ce4845 [Ram Sriharsha] Merge with Master
2a807be [Ram Sriharsha] Merge branch 'master' into reduction
e21bfcc [Ram Sriharsha] Style Fix
5614f23 [Ram Sriharsha] Style Fix
c75583a [Ram Sriharsha] Cleanup
7a5f136 [Ram Sriharsha] Fix TODOs
804826b [Ram Sriharsha] Merge with Master
1448a5f [Ram Sriharsha] Style Fix
6e47807 [Ram Sriharsha] Style Fix
d63e46b [Ram Sriharsha] Incorporate Code Review Feedback
ced68b5 [Ram Sriharsha] Refactor OneVsAll to implement Predictor
78fa82a [Ram Sriharsha] extra line
0dfa1fb [Ram Sriharsha] Fix inexhaustive match cases that may arise from UnresolvedAttribute
a59a4f4 [Ram Sriharsha] @Experimental
4167234 [Ram Sriharsha] Merge branch 'master' into reduction
868a4fd [Ram Sriharsha] @Experimental
041d905 [Ram Sriharsha] Code Review Fixes
df188d8 [Ram Sriharsha] Style fix
612ec48 [Ram Sriharsha] Style Fix
6ef43d3 [Ram Sriharsha] Prefer Unresolved Attribute to Option: Java APIs are cleaner
6bf6bff [Ram Sriharsha] Update OneHotEncoder to new API
e29cb89 [Ram Sriharsha] Merge branch 'master' into reduction
1c7fa44 [Ram Sriharsha] Fix Tests
ca83672 [Ram Sriharsha] Incorporate Code Review Feedback + Rename to OneVsRestClassifier
221beeed [Ram Sriharsha] Upgrade to use Copy method for cloning Base Classifiers
26f1ddb [Ram Sriharsha] Merge with SPARK-5956 API changes
9738744 [Ram Sriharsha] Merge branch 'master' into reduction
1a3e375 [Ram Sriharsha] More efficient Implementation: Use withColumn to generate label column dynamically
32e0189 [Ram Sriharsha] Restrict reduction to Margin Based Classifiers
ff272da [Ram Sriharsha] Style fix
28771f5 [Ram Sriharsha] Add Tests for Multiclass to Binary Reduction
b60f874 [Ram Sriharsha] Fix Style issues in Test
3191cdf [Ram Sriharsha] Remove this test, accidental commit
23f056c [Ram Sriharsha] Fix Headers for test
1b5e929 [Ram Sriharsha] Fix Style issues and add Header
8752863 [Ram Sriharsha] [SPARK-7015][MLLib][WIP] Multiclass to Binary Reduction: One Against All
Parent: 5438f49ccf
Commit: 595a67589a
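For orientation, here is a minimal usage sketch of the API this patch introduces, modeled on the test suites in the diff below. The SparkContext `sc` and the RDD[LabeledPoint] `data` (with labels 0.0 through k - 1) are stand-ins assumed to exist; the column names are the defaults produced by createDataFrame on LabeledPoint.

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.reduction.OneVsRest
import org.apache.spark.sql.SQLContext

// `sc` is an existing SparkContext and `data` an RDD[LabeledPoint] whose labels
// run from 0.0 to k - 1; both are stand-ins for this sketch.
val sqlContext = new SQLContext(sc)
val dataset = sqlContext.createDataFrame(data)

val ova = new OneVsRest()
ova.setClassifier(new LogisticRegression())    // the base binary classifier

val ovaModel = ova.fit(dataset)                // fits one binary model per class
val predicted = ovaModel.transform(dataset)    // adds a "prediction" column carrying label metadata
predicted.select("label", "prediction").show()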
@@ -113,7 +113,8 @@ abstract class Predictor[
   *
   * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
   */
-  protected def featuresDataType: DataType = new VectorUDT
+  @DeveloperApi
+  private[ml] def featuresDataType: DataType = new VectorUDT

  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema, fitting = true, featuresDataType)
@@ -123,6 +123,7 @@ class AttributeGroup private (
          nominalMetadata += nominal.toMetadataImpl(withType = false)
        case binary: BinaryAttribute =>
          binaryMetadata += binary.toMetadataImpl(withType = false)
+        case UnresolvedAttribute =>
      }
      val attrBldr = new MetadataBuilder
      if (numericMetadata.nonEmpty) {
@@ -43,6 +43,12 @@ object AttributeType {
    Binary
  }

+  /** Unresolved type. */
+  val Unresolved: AttributeType = {
+    case object Unresolved extends AttributeType("unresolved")
+    Unresolved
+  }
+
  /**
   * Gets the [[AttributeType]] object from its name.
   * @param name attribute type name: "numeric", "nominal", or "binary"
@@ -54,6 +60,8 @@ object AttributeType {
      Nominal
    } else if (name == Binary.name) {
      Binary
+    } else if (name == Unresolved.name) {
+      Unresolved
    } else {
      throw new IllegalArgumentException(s"Cannot recognize type $name.")
    }
@@ -125,7 +125,13 @@ private[attribute] trait AttributeFactory {
   */
  def fromStructField(field: StructField): Attribute = {
    require(field.dataType == DoubleType)
-    fromMetadata(field.metadata.getMetadata(AttributeKeys.ML_ATTR)).withName(field.name)
+    val metadata = field.metadata
+    val mlAttr = AttributeKeys.ML_ATTR
+    if (metadata.contains(mlAttr)) {
+      fromMetadata(metadata.getMetadata(mlAttr)).withName(field.name)
+    } else {
+      UnresolvedAttribute
+    }
  }
}
@@ -535,3 +541,32 @@ object BinaryAttribute extends AttributeFactory {
    new BinaryAttribute(name, index, values)
  }
}
+
+/**
+ * An unresolved attribute.
+ */
+object UnresolvedAttribute extends Attribute {
+
+  override def attrType: AttributeType = AttributeType.Unresolved
+
+  override def withIndex(index: Int): Attribute = this
+
+  override def isNumeric: Boolean = false
+
+  override def withoutIndex: Attribute = this
+
+  override def isNominal: Boolean = false
+
+  override def name: Option[String] = None
+
+  override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = {
+    Metadata.empty
+  }
+
+  override def withoutName: Attribute = this
+
+  override def index: Option[Int] = None
+
+  override def withName(name: String): Attribute = this
+
+}
@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, BinaryAttribute, NominalAttribute, NumericAttribute}
+import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.{IntParam, ParamValidators, Params}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
@@ -375,6 +375,8 @@ class VectorIndexerModel private[ml] (
          }
        case (origAttr: Attribute, featAttr: NumericAttribute) =>
          origAttr.withIndex(featAttr.index.get)
+        case (origAttr: Attribute, _) =>
+          origAttr
      }
    } else {
      partialFeatureAttributes
@@ -0,0 +1,211 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.reduction

import java.util.UUID

import scala.language.existentials

import org.apache.spark.annotation.{AlphaComponent, Experimental}
import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.classification.{ClassificationModel, Classifier}
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel

/**
 * Params for [[OneVsRest]].
 */
private[ml] trait OneVsRestParams extends PredictorParams {

  type ClassifierType = Classifier[F, E, M] forSome {
    type F
    type M <: ClassificationModel[F, M]
    type E <: Classifier[F, E, M]
  }

  /**
   * param for the base binary classifier that we reduce multiclass classification into.
   * @group param
   */
  val classifier: Param[ClassifierType] =
    new Param(this, "classifier", "base binary classifier ")

  /** @group getParam */
  def getClassifier: ClassifierType = $(classifier)

}

/**
 * Model produced by [[OneVsRest]].
 * Stores the models resulting from training k different classifiers:
 * one for each class.
 * Each example is scored against all k models and the model with highest score
 * is picked to label the example.
 * TODO: API may need to change when we introduce a ClassificationModel trait as the public API
 * @param parent
 * @param labelMetadata Metadata of label column if it exists, or Nominal attribute
 *                      representing the number of classes in training dataset otherwise.
 * @param models the binary classification models for reduction.
 *               The i-th model is produced by testing the i-th class vs the rest.
 */
@AlphaComponent
class OneVsRestModel(
    override val parent: OneVsRest,
    labelMetadata: Metadata,
    val models: Array[_ <: ClassificationModel[_,_]])
  extends Model[OneVsRestModel] with OneVsRestParams {

  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
  }

  override def transform(dataset: DataFrame): DataFrame = {
    // Check schema
    transformSchema(dataset.schema, logging = true)

    // determine the input columns: these need to be passed through
    val origCols = dataset.schema.map(f => col(f.name))

    // add an accumulator column to store predictions of all the models
    val accColName = "mbc$acc" + UUID.randomUUID().toString
    val init: () => Map[Int, Double] = () => {Map()}
    val mapType = MapType(IntegerType, DoubleType, false)
    val newDataset = dataset.withColumn(accColName, callUDF(init, mapType))

    // persist if underlying dataset is not persistent.
    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
    if (handlePersistence) {
      newDataset.persist(StorageLevel.MEMORY_AND_DISK)
    }

    // update the accumulator column with the result of prediction of models
    val aggregatedDataset = models.zipWithIndex.foldLeft[DataFrame](newDataset) {
      case (df, (model, index)) => {
        val rawPredictionCol = model.getRawPredictionCol
        val columns = origCols ++ List(col(rawPredictionCol), col(accColName))

        // add temporary column to store intermediate scores and update
        val tmpColName = "mbc$tmp" + UUID.randomUUID().toString
        val update: (Map[Int, Double], Vector) => Map[Int, Double] =
          (predictions: Map[Int, Double], prediction: Vector) => {
            predictions + ((index, prediction(1)))
          }
        val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol))
        val transformedDataset = model.transform(df).select(columns:_*)
        val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf)
        val newColumns = origCols ++ List(col(tmpColName))

        // switch out the intermediate column with the accumulator column
        updatedDataset.select(newColumns:_*).withColumnRenamed(tmpColName, accColName)
      }
    }

    if (handlePersistence) {
      newDataset.unpersist()
    }

    // output the index of the classifier with highest confidence as prediction
    val label: Map[Int, Double] => Double = (predictions: Map[Int, Double]) => {
      predictions.maxBy(_._2)._1.toDouble
    }

    // output label and label metadata as prediction
    val labelUdf = callUDF(label, DoubleType, col(accColName))
    aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
  }
}

/**
 * :: Experimental ::
 *
 * Reduction of Multiclass Classification to Binary Classification.
 * Performs reduction using one against all strategy.
 * For a multiclass classification with k classes, train k models (one per class).
 * Each example is scored against all k models and the model with highest score
 * is picked to label the example.
 */
@Experimental
final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams {

  /** @group setParam */
  // TODO: Find a better way to do this. Existential Types don't work with Java API so cast needed.
  def setClassifier(value: Classifier[_,_,_]): this.type = {
    set(classifier, value.asInstanceOf[ClassifierType])
  }

  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
  }

  override def fit(dataset: DataFrame): OneVsRestModel = {
    // determine number of classes either from metadata if provided, or via computation.
    val labelSchema = dataset.schema($(labelCol))
    val computeNumClasses: () => Int = () => {
      val Row(maxLabelIndex: Double) = dataset.agg(max($(labelCol))).head()
      // classes are assumed to be numbered from 0,...,maxLabelIndex
      maxLabelIndex.toInt + 1
    }
    val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)

    val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))

    // persist if underlying dataset is not persistent.
    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
    if (handlePersistence) {
      multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
    }

    // create k columns, one for each binary classifier.
    val models = Range(0, numClasses).par.map { index =>

      val label: Double => Double = (label: Double) => {
        if (label.toInt == index) 1.0 else 0.0
      }

      // generate new label metadata for the binary problem.
      // TODO: use when ... otherwise after SPARK-7321 is merged
      val labelUDF = callUDF(label, DoubleType, col($(labelCol)))
      val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
      val labelColName = "mc2b$" + index
      val labelUDFWithNewMeta = labelUDF.as(labelColName, newLabelMeta)
      val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
      val classifier = getClassifier
      classifier.fit(trainingDataset, classifier.labelCol -> labelColName)
    }.toArray[ClassificationModel[_,_]]

    if (handlePersistence) {
      multiclassLabeled.unpersist()
    }

    // extract label metadata from label column if present, or create a nominal attribute
    // to output the number of labels
    val labelAttribute = Attribute.fromStructField(labelSchema) match {
      case _: NumericAttribute | UnresolvedAttribute => {
        NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
      }
      case attr: Attribute => attr
    }
    copyValues(new OneVsRestModel(this, labelAttribute.toMetadata(), models))
  }
}
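The DataFrame mechanics above (one accumulator map column plus a temporary score column per model) implement the scheme the scaladoc describes: relabel every example as class k vs rest, fit one binary model per class, then predict the class whose model reports the highest score. A minimal sketch of the same idea on in-memory collections, where the `train` argument stands in for any binary learner and is not part of this patch:

object OneVsRestSketch {
  type Features = Array[Double]
  type Scorer = Features => Double  // higher score = more confident "positive"

  // Fit one binary scorer per class by relabeling: 1.0 for class k, 0.0 for the rest.
  def fit(
      data: Seq[(Double, Features)],
      numClasses: Int,
      train: Seq[(Double, Features)] => Scorer): IndexedSeq[Scorer] = {
    (0 until numClasses).map { k =>
      val binary = data.map { case (label, x) => (if (label.toInt == k) 1.0 else 0.0, x) }
      train(binary)
    }
  }

  // Predict the index of the per-class scorer with the highest confidence.
  def predict(scorers: IndexedSeq[Scorer], x: Features): Double =
    scorers.zipWithIndex.maxBy { case (scorer, _) => scorer(x) }._2.toDouble
}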
@@ -20,8 +20,7 @@ package org.apache.spark.ml.util

import scala.collection.immutable.HashMap

import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, BinaryAttribute, NominalAttribute,
-  NumericAttribute}
+import org.apache.spark.ml.attribute._
import org.apache.spark.sql.types.StructField

@@ -39,9 +38,9 @@ object MetadataUtils {
   */
  def getNumClasses(labelSchema: StructField): Option[Int] = {
    Attribute.fromStructField(labelSchema) match {
-      case numAttr: NumericAttribute => None
      case binAttr: BinaryAttribute => Some(2)
      case nomAttr: NominalAttribute => nomAttr.getNumValues
+      case _: NumericAttribute | UnresolvedAttribute => None
    }
  }

@@ -65,7 +64,7 @@ object MetadataUtils {
            Iterator()
          } else {
            attr match {
-              case numAttr: NumericAttribute => Iterator()
+              case _: NumericAttribute | UnresolvedAttribute => Iterator()
              case binAttr: BinaryAttribute => Iterator(idx -> 2)
              case nomAttr: NominalAttribute =>
                nomAttr.getNumValues match {
@@ -0,0 +1,85 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.reduction;

import java.io.Serializable;
import java.util.List;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import static scala.collection.JavaConversions.seqAsJavaList;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.LogisticRegression;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;

public class JavaOneVsRestSuite implements Serializable {

  private transient JavaSparkContext jsc;
  private transient SQLContext jsql;
  private transient DataFrame dataset;
  private transient JavaRDD<LabeledPoint> datasetRDD;

  @Before
  public void setUp() {
    jsc = new JavaSparkContext("local", "JavaLOneVsRestSuite");
    jsql = new SQLContext(jsc);
    int nPoints = 3;

    /**
     * The following weights and xMean/xVariance are computed from iris dataset with lambda = 0.2.
     * As a result, we are actually drawing samples from probability distribution of built model.
     */
    double[] weights = {
        -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
        -0.16624, -0.84355, -0.048509, -0.301789, 4.170682 };

    double[] xMean = {5.843, 3.057, 3.758, 1.199};
    double[] xVariance = {0.6856, 0.1899, 3.116, 0.581};
    List<LabeledPoint> points = seqAsJavaList(generateMultinomialLogisticInput(
        weights, xMean, xVariance, true, nPoints, 42));
    datasetRDD = jsc.parallelize(points, 2);
    dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
  }

  @After
  public void tearDown() {
    jsc.stop();
    jsc = null;
  }

  @Test
  public void oneVsRestDefaultParams() {
    OneVsRest ova = new OneVsRest();
    ova.setClassifier(new LogisticRegression());
    Assert.assertEquals(ova.getLabelCol() , "label");
    Assert.assertEquals(ova.getPredictionCol() , "prediction");
    OneVsRestModel ovaModel = ova.fit(dataset);
    DataFrame predictions = ovaModel.transform(dataset).select("label", "prediction");
    predictions.collectAsList();
    Assert.assertEquals(ovaModel.getLabelCol(), "label");
    Assert.assertEquals(ovaModel.getPredictionCol() , "prediction");
  }
}
@@ -19,7 +19,7 @@ package org.apache.spark.ml.attribute

import org.scalatest.FunSuite

-import org.apache.spark.sql.types.{DoubleType, MetadataBuilder, Metadata}
+import org.apache.spark.sql.types._

class AttributeSuite extends FunSuite {

@@ -209,4 +209,12 @@ class AttributeSuite extends FunSuite {
    intercept[IllegalArgumentException](attr.withName(""))
    intercept[IllegalArgumentException](attr.withIndex(-1))
  }
+
+  test("attribute from struct field") {
+    val metadata = NumericAttribute.defaultAttr.withName("label").toMetadata()
+    val fldWithoutMeta = new StructField("x", DoubleType, false, Metadata.empty)
+    assert(Attribute.fromStructField(fldWithoutMeta) == UnresolvedAttribute)
+    val fldWithMeta = new StructField("x", DoubleType, false, metadata)
+    assert(Attribute.fromStructField(fldWithMeta).isNumeric)
+  }
}
@@ -0,0 +1,113 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.reduction

import org.scalatest.FunSuite

import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegression}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}

class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {

  @transient var sqlContext: SQLContext = _
  @transient var dataset: DataFrame = _
  @transient var rdd: RDD[LabeledPoint] = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sqlContext = new SQLContext(sc)
    val nPoints = 1000

    /**
     * The following weights and xMean/xVariance are computed from iris dataset with lambda = 0.2.
     * As a result, we are actually drawing samples from probability distribution of built model.
     */
    val weights = Array(
      -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
      -0.16624, -0.84355, -0.048509, -0.301789, 4.170682)

    val xMean = Array(5.843, 3.057, 3.758, 1.199)
    val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
    rdd = sc.parallelize(generateMultinomialLogisticInput(
      weights, xMean, xVariance, true, nPoints, 42), 2)
    dataset = sqlContext.createDataFrame(rdd)
  }

  test("one-vs-rest: default params") {
    val numClasses = 3
    val ova = new OneVsRest()
    ova.setClassifier(new LogisticRegression)
    assert(ova.getLabelCol === "label")
    assert(ova.getPredictionCol === "prediction")
    val ovaModel = ova.fit(dataset)
    assert(ovaModel.models.size === numClasses)
    val transformedDataset = ovaModel.transform(dataset)

    // check for label metadata in prediction col
    val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol)
    assert(MetadataUtils.getNumClasses(predictionColSchema) === Some(3))

    val ovaResults = transformedDataset
      .select("prediction", "label")
      .map(row => (row.getDouble(0), row.getDouble(1)))

    val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(numClasses)
    lr.optimizer.setRegParam(0.1).setNumIterations(100)

    val model = lr.run(rdd)
    val results = model.predict(rdd.map(_.features)).zip(rdd.map(_.label))
    // determine the #confusion matrix in each class.
    // bound how much error we allow compared to multinomial logistic regression.
    val expectedMetrics = new MulticlassMetrics(results)
    val ovaMetrics = new MulticlassMetrics(ovaResults)
    assert(expectedMetrics.confusionMatrix ~== ovaMetrics.confusionMatrix absTol 400)
  }

  test("one-vs-rest: pass label metadata correctly during train") {
    val numClasses = 3
    val ova = new OneVsRest()
    ova.setClassifier(new MockLogisticRegression)

    val labelMetadata = NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
    val labelWithMetadata = dataset("label").as("label", labelMetadata.toMetadata())
    val features = dataset("features").as("features")
    val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features)
    ova.fit(datasetWithLabelMetadata)
  }
}

private class MockLogisticRegression extends LogisticRegression {

  setMaxIter(1)

  override protected def train(dataset: DataFrame): LogisticRegressionModel = {
    val labelSchema = dataset.schema($(labelCol))
    // check for label attribute propagation.
    assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2))
    super.train(dataset)
  }
}