[SPARK-7015] [MLLIB] [WIP] Multiclass to Binary Reduction: One Against All

Initial cut of one-against-all. The test code is scaffolding and not fully implemented.
This WIP is intended to gather early feedback.
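
For context, here is a minimal usage sketch of the new estimator, mirroring the suites added in this patch (the "dataset" DataFrame with "label" and "features" columns is an assumption):

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.reduction.OneVsRest

    // dataset: a DataFrame with "label" and "features" columns (assumed to exist)
    val ova = new OneVsRest().setClassifier(new LogisticRegression)
    val ovaModel = ova.fit(dataset)                // trains k binary models, one per class
    val predictions = ovaModel.transform(dataset)  // adds the "prediction" column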

Author: Ram Sriharsha <rsriharsha@hw11853.local>

Closes #5830 from harsha2010/reduction and squashes the following commits:

5f4b495 [Ram Sriharsha] Fix Test
386e98b [Ram Sriharsha] Style fix
49b4a17 [Ram Sriharsha] Simplify the test
02279cc [Ram Sriharsha] Output Label Metadata in Prediction Col
bc78032 [Ram Sriharsha] Code Review Updates
8ce4845 [Ram Sriharsha] Merge with Master
2a807be [Ram Sriharsha] Merge branch 'master' into reduction
e21bfcc [Ram Sriharsha] Style Fix
5614f23 [Ram Sriharsha] Style Fix
c75583a [Ram Sriharsha] Cleanup
7a5f136 [Ram Sriharsha] Fix TODOs
804826b [Ram Sriharsha] Merge with Master
1448a5f [Ram Sriharsha] Style Fix
6e47807 [Ram Sriharsha] Style Fix
d63e46b [Ram Sriharsha] Incorporate Code Review Feedback
ced68b5 [Ram Sriharsha] Refactor OneVsAll to implement Predictor
78fa82a [Ram Sriharsha] extra line
0dfa1fb [Ram Sriharsha] Fix inexhaustive match cases that may arise from UnresolvedAttribute
a59a4f4 [Ram Sriharsha] @Experimental
4167234 [Ram Sriharsha] Merge branch 'master' into reduction
868a4fd [Ram Sriharsha] @Experimental
041d905 [Ram Sriharsha] Code Review Fixes
df188d8 [Ram Sriharsha] Style fix
612ec48 [Ram Sriharsha] Style Fix
6ef43d3 [Ram Sriharsha] Prefer Unresolved Attribute to Option: Java APIs are cleaner
6bf6bff [Ram Sriharsha] Update OneHotEncoder to new API
e29cb89 [Ram Sriharsha] Merge branch 'master' into reduction
1c7fa44 [Ram Sriharsha] Fix Tests
ca83672 [Ram Sriharsha] Incorporate Code Review Feedback + Rename to OneVsRestClassifier
221beeed [Ram Sriharsha] Upgrade to use Copy method for cloning Base Classifiers
26f1ddb [Ram Sriharsha] Merge with SPARK-5956 API changes
9738744 [Ram Sriharsha] Merge branch 'master' into reduction
1a3e375 [Ram Sriharsha] More efficient Implementation: Use withColumn to generate label column dynamically
32e0189 [Ram Sriharsha] Restrict reduction to Margin Based Classifiers
ff272da [Ram Sriharsha] Style fix
28771f5 [Ram Sriharsha] Add Tests for Multiclass to Binary Reduction
b60f874 [Ram Sriharsha] Fix Style issues in Test
3191cdf [Ram Sriharsha] Remove this test, accidental commit
23f056c [Ram Sriharsha] Fix Headers for test
1b5e929 [Ram Sriharsha] Fix Style issues and add Header
8752863 [Ram Sriharsha] [SPARK-7015][MLLib][WIP] Multiclass to Binary Reduction: One Against All
Authored by Ram Sriharsha on 2015-05-12 13:35:12 -07:00; committed by Joseph K. Bradley
parent 5438f49ccf
commit 595a67589a
10 changed files with 471 additions and 8 deletions


@@ -113,7 +113,8 @@ abstract class Predictor[
    *
    * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
    */
-  protected def featuresDataType: DataType = new VectorUDT
+  @DeveloperApi
+  private[ml] def featuresDataType: DataType = new VectorUDT

   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema, fitting = true, featuresDataType)


@@ -123,6 +123,7 @@ class AttributeGroup private (
           nominalMetadata += nominal.toMetadataImpl(withType = false)
         case binary: BinaryAttribute =>
           binaryMetadata += binary.toMetadataImpl(withType = false)
+        case UnresolvedAttribute =>
       }
       val attrBldr = new MetadataBuilder
       if (numericMetadata.nonEmpty) {


@@ -43,6 +43,12 @@ object AttributeType {
     Binary
   }

+  /** Unresolved type. */
+  val Unresolved: AttributeType = {
+    case object Unresolved extends AttributeType("unresolved")
+    Unresolved
+  }
+
   /**
    * Gets the [[AttributeType]] object from its name.
    * @param name attribute type name: "numeric", "nominal", or "binary"
@@ -54,6 +60,8 @@ object AttributeType {
       Nominal
     } else if (name == Binary.name) {
       Binary
+    } else if (name == Unresolved.name) {
+      Unresolved
     } else {
       throw new IllegalArgumentException(s"Cannot recognize type $name.")
     }
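
A quick sanity check of the new type (a sketch; the name of the lookup method is not visible in this hunk, so AttributeType.fromName is an assumption):

    import org.apache.spark.ml.attribute.AttributeType

    assert(AttributeType.Unresolved.name == "unresolved")
    assert(AttributeType.fromName("unresolved") == AttributeType.Unresolved)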


@@ -125,7 +125,13 @@ private[attribute] trait AttributeFactory {
    */
   def fromStructField(field: StructField): Attribute = {
     require(field.dataType == DoubleType)
-    fromMetadata(field.metadata.getMetadata(AttributeKeys.ML_ATTR)).withName(field.name)
+    val metadata = field.metadata
+    val mlAttr = AttributeKeys.ML_ATTR
+    if (metadata.contains(mlAttr)) {
+      fromMetadata(metadata.getMetadata(mlAttr)).withName(field.name)
+    } else {
+      UnresolvedAttribute
+    }
   }
 }

@@ -535,3 +541,32 @@ object BinaryAttribute extends AttributeFactory {
     new BinaryAttribute(name, index, values)
   }
 }
+
+/**
+ * An unresolved attribute.
+ */
+object UnresolvedAttribute extends Attribute {
+
+  override def attrType: AttributeType = AttributeType.Unresolved
+
+  override def withIndex(index: Int): Attribute = this
+
+  override def isNumeric: Boolean = false
+
+  override def withoutIndex: Attribute = this
+
+  override def isNominal: Boolean = false
+
+  override def name: Option[String] = None
+
+  override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = {
+    Metadata.empty
+  }
+
+  override def withoutName: Attribute = this
+
+  override def index: Option[Int] = None
+
+  override def withName(name: String): Attribute = this
+}
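
To illustrate the new fromStructField behavior (a sketch along the lines of the AttributeSuite test added below):

    import org.apache.spark.ml.attribute.{Attribute, NumericAttribute, UnresolvedAttribute}
    import org.apache.spark.sql.types.{DoubleType, Metadata, StructField}

    // A Double field without ML metadata now resolves to UnresolvedAttribute instead of failing.
    val bare = StructField("x", DoubleType, false, Metadata.empty)
    assert(Attribute.fromStructField(bare) == UnresolvedAttribute)

    // A field carrying ML attribute metadata resolves as before.
    val meta = NumericAttribute.defaultAttr.withName("x").toMetadata()
    assert(Attribute.fromStructField(StructField("x", DoubleType, false, meta)).isNumeric)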


@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature

 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, BinaryAttribute, NominalAttribute, NumericAttribute}
+import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.param.{IntParam, ParamValidators, Params}
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
@@ -375,6 +375,8 @@ class VectorIndexerModel private[ml] (
             }
           case (origAttr: Attribute, featAttr: NumericAttribute) =>
             origAttr.withIndex(featAttr.index.get)
+          case (origAttr: Attribute, _) =>
+            origAttr
         }
       } else {
         partialFeatureAttributes


@@ -0,0 +1,211 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.reduction

import java.util.UUID

import scala.language.existentials

import org.apache.spark.annotation.{AlphaComponent, Experimental}
import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.classification.{ClassificationModel, Classifier}
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel

/**
 * Params for [[OneVsRest]].
 */
private[ml] trait OneVsRestParams extends PredictorParams {

  type ClassifierType = Classifier[F, E, M] forSome {
    type F
    type M <: ClassificationModel[F, M]
    type E <: Classifier[F, E, M]
  }

  /**
   * param for the base binary classifier that we reduce multiclass classification into.
   * @group param
   */
  val classifier: Param[ClassifierType] =
    new Param(this, "classifier", "base binary classifier")

  /** @group getParam */
  def getClassifier: ClassifierType = $(classifier)
}

/**
 * Model produced by [[OneVsRest]].
 * Stores the models resulting from training k binary classifiers: one for each class.
 * Each example is scored against all k models, and the model with the highest score
 * is picked to label the example.
 *
 * TODO: API may need to change when we introduce a ClassificationModel trait as the public API
 *
 * @param parent the [[OneVsRest]] estimator that produced this model
 * @param labelMetadata Metadata of the label column if it exists, or a Nominal attribute
 *                      representing the number of classes in the training dataset otherwise.
 * @param models the binary classification models for the reduction.
 *               The i-th model is produced by training the i-th class vs. the rest.
 */
@AlphaComponent
class OneVsRestModel(
    override val parent: OneVsRest,
    labelMetadata: Metadata,
    val models: Array[_ <: ClassificationModel[_, _]])
  extends Model[OneVsRestModel] with OneVsRestParams {

  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
  }

  override def transform(dataset: DataFrame): DataFrame = {
    // Check schema
    transformSchema(dataset.schema, logging = true)

    // determine the input columns: these need to be passed through
    val origCols = dataset.schema.map(f => col(f.name))

    // add an accumulator column to store predictions of all the models
    val accColName = "mbc$acc" + UUID.randomUUID().toString
    val init: () => Map[Int, Double] = () => { Map() }
    val mapType = MapType(IntegerType, DoubleType, false)
    val newDataset = dataset.withColumn(accColName, callUDF(init, mapType))

    // persist if underlying dataset is not persistent.
    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
    if (handlePersistence) {
      newDataset.persist(StorageLevel.MEMORY_AND_DISK)
    }

    // update the accumulator column with the result of prediction of models
    val aggregatedDataset = models.zipWithIndex.foldLeft[DataFrame](newDataset) {
      case (df, (model, index)) => {
        val rawPredictionCol = model.getRawPredictionCol
        val columns = origCols ++ List(col(rawPredictionCol), col(accColName))

        // add a temporary column to store intermediate scores and update
        val tmpColName = "mbc$tmp" + UUID.randomUUID().toString
        val update: (Map[Int, Double], Vector) => Map[Int, Double] =
          (predictions: Map[Int, Double], prediction: Vector) => {
            predictions + ((index, prediction(1)))
          }
        val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol))
        val transformedDataset = model.transform(df).select(columns: _*)
        val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf)
        val newColumns = origCols ++ List(col(tmpColName))

        // switch out the intermediate column with the accumulator column
        updatedDataset.select(newColumns: _*).withColumnRenamed(tmpColName, accColName)
      }
    }

    if (handlePersistence) {
      newDataset.unpersist()
    }

    // output the index of the classifier with the highest confidence as the prediction
    val label: Map[Int, Double] => Double = (predictions: Map[Int, Double]) => {
      predictions.maxBy(_._2)._1.toDouble
    }

    // output the label and label metadata as the prediction
    val labelUdf = callUDF(label, DoubleType, col(accColName))
    aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
  }
}

/**
 * :: Experimental ::
 *
 * Reduction of Multiclass Classification to Binary Classification.
 * Performs reduction using the one-against-all strategy.
 * For a multiclass classification with k classes, train k models (one per class).
 * Each example is scored against all k models, and the model with the highest score
 * is picked to label the example.
 */
@Experimental
final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams {

  /** @group setParam */
  // TODO: Find a better way to do this. Existential Types don't work with Java API so cast needed.
  def setClassifier(value: Classifier[_, _, _]): this.type = {
    set(classifier, value.asInstanceOf[ClassifierType])
  }

  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
  }

  override def fit(dataset: DataFrame): OneVsRestModel = {
    // determine the number of classes either from metadata if provided, or via computation.
    val labelSchema = dataset.schema($(labelCol))
    val computeNumClasses: () => Int = () => {
      val Row(maxLabelIndex: Double) = dataset.agg(max($(labelCol))).head()
      // classes are assumed to be numbered from 0,...,maxLabelIndex
      maxLabelIndex.toInt + 1
    }
    val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)

    val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))

    // persist if underlying dataset is not persistent.
    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
    if (handlePersistence) {
      multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
    }

    // create k columns, one for each binary classifier.
    val models = Range(0, numClasses).par.map { index =>
      val label: Double => Double = (label: Double) => {
        if (label.toInt == index) 1.0 else 0.0
      }

      // generate new label metadata for the binary problem.
      // TODO: use when ... otherwise after SPARK-7321 is merged
      val labelUDF = callUDF(label, DoubleType, col($(labelCol)))
      val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
      val labelColName = "mc2b$" + index
      val labelUDFWithNewMeta = labelUDF.as(labelColName, newLabelMeta)
      val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
      val classifier = getClassifier
      classifier.fit(trainingDataset, classifier.labelCol -> labelColName)
    }.toArray[ClassificationModel[_, _]]

    if (handlePersistence) {
      multiclassLabeled.unpersist()
    }

    // extract label metadata from the label column if present, or create a nominal attribute
    // to output the number of labels
    val labelAttribute = Attribute.fromStructField(labelSchema) match {
      case _: NumericAttribute | UnresolvedAttribute => {
        NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
      }
      case attr: Attribute => attr
    }
    copyValues(new OneVsRestModel(this, labelAttribute.toMetadata(), models))
  }
}
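
The prediction step in transform() above boils down to an argmax over the accumulated per-class confidences. As a plain-Scala sketch of just that step (the scores are made up):

    // class index -> confidence from that class's one-vs-rest model
    val scores: Map[Int, Double] = Map(0 -> -1.3, 1 -> 2.7, 2 -> 0.4)
    val prediction: Double = scores.maxBy(_._2)._1.toDouble  // 1.0, the most confident class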


@@ -20,8 +20,7 @@ package org.apache.spark.ml.util
 import scala.collection.immutable.HashMap

 import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, BinaryAttribute, NominalAttribute,
-  NumericAttribute}
+import org.apache.spark.ml.attribute._
 import org.apache.spark.sql.types.StructField

@@ -39,9 +38,9 @@ object MetadataUtils {
    */
   def getNumClasses(labelSchema: StructField): Option[Int] = {
     Attribute.fromStructField(labelSchema) match {
-      case numAttr: NumericAttribute => None
       case binAttr: BinaryAttribute => Some(2)
       case nomAttr: NominalAttribute => nomAttr.getNumValues
+      case _: NumericAttribute | UnresolvedAttribute => None
     }
   }

@@ -65,7 +64,7 @@ object MetadataUtils {
         Iterator()
       } else {
         attr match {
-          case numAttr: NumericAttribute => Iterator()
+          case _: NumericAttribute | UnresolvedAttribute => Iterator()
           case binAttr: BinaryAttribute => Iterator(idx -> 2)
           case nomAttr: NominalAttribute =>
             nomAttr.getNumValues match {
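
With this change, getNumClasses degrades gracefully when the label column carries no ML metadata, returning None instead of failing; the caller then falls back to computing the count from the data, as OneVsRest.fit does. A sketch (labelSchema is the label column's StructField; computeNumClasses() is a hypothetical fallback such as max(label).toInt + 1):

    import org.apache.spark.ml.util.MetadataUtils

    // None for a bare Double label column; the hypothetical fallback kicks in.
    val numClasses = MetadataUtils.getNumClasses(labelSchema).getOrElse(computeNumClasses())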


@@ -0,0 +1,85 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.reduction;

import java.io.Serializable;
import java.util.List;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import static scala.collection.JavaConversions.seqAsJavaList;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.LogisticRegression;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;

public class JavaOneVsRestSuite implements Serializable {

  private transient JavaSparkContext jsc;
  private transient SQLContext jsql;
  private transient DataFrame dataset;
  private transient JavaRDD<LabeledPoint> datasetRDD;

  @Before
  public void setUp() {
    jsc = new JavaSparkContext("local", "JavaLOneVsRestSuite");
    jsql = new SQLContext(jsc);
    int nPoints = 3;

    // The following weights and xMean/xVariance are computed from the iris dataset
    // with lambda = 0.2. As a result, we are actually drawing samples from the
    // probability distribution of the built model.
    double[] weights = {
        -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
        -0.16624, -0.84355, -0.048509, -0.301789, 4.170682 };
    double[] xMean = {5.843, 3.057, 3.758, 1.199};
    double[] xVariance = {0.6856, 0.1899, 3.116, 0.581};
    List<LabeledPoint> points = seqAsJavaList(generateMultinomialLogisticInput(
        weights, xMean, xVariance, true, nPoints, 42));
    datasetRDD = jsc.parallelize(points, 2);
    dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
  }

  @After
  public void tearDown() {
    jsc.stop();
    jsc = null;
  }

  @Test
  public void oneVsRestDefaultParams() {
    OneVsRest ova = new OneVsRest();
    ova.setClassifier(new LogisticRegression());
    Assert.assertEquals(ova.getLabelCol(), "label");
    Assert.assertEquals(ova.getPredictionCol(), "prediction");
    OneVsRestModel ovaModel = ova.fit(dataset);
    DataFrame predictions = ovaModel.transform(dataset).select("label", "prediction");
    predictions.collectAsList();
    Assert.assertEquals(ovaModel.getLabelCol(), "label");
    Assert.assertEquals(ovaModel.getPredictionCol(), "prediction");
  }
}


@@ -19,7 +19,7 @@ package org.apache.spark.ml.attribute

 import org.scalatest.FunSuite

-import org.apache.spark.sql.types.{DoubleType, MetadataBuilder, Metadata}
+import org.apache.spark.sql.types._

 class AttributeSuite extends FunSuite {
@@ -209,4 +209,12 @@ class AttributeSuite extends FunSuite {
     intercept[IllegalArgumentException](attr.withName(""))
     intercept[IllegalArgumentException](attr.withIndex(-1))
   }
+
+  test("attribute from struct field") {
+    val metadata = NumericAttribute.defaultAttr.withName("label").toMetadata()
+    val fldWithoutMeta = new StructField("x", DoubleType, false, Metadata.empty)
+    assert(Attribute.fromStructField(fldWithoutMeta) == UnresolvedAttribute)
+    val fldWithMeta = new StructField("x", DoubleType, false, metadata)
+    assert(Attribute.fromStructField(fldWithMeta).isNumeric)
+  }
 }


@@ -0,0 +1,113 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.reduction

import org.scalatest.FunSuite

import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegression}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}

class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {

  @transient var sqlContext: SQLContext = _
  @transient var dataset: DataFrame = _
  @transient var rdd: RDD[LabeledPoint] = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sqlContext = new SQLContext(sc)
    val nPoints = 1000

    // The following weights and xMean/xVariance are computed from the iris dataset
    // with lambda = 0.2. As a result, we are actually drawing samples from the
    // probability distribution of the built model.
    val weights = Array(
      -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
      -0.16624, -0.84355, -0.048509, -0.301789, 4.170682)

    val xMean = Array(5.843, 3.057, 3.758, 1.199)
    val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
    rdd = sc.parallelize(generateMultinomialLogisticInput(
      weights, xMean, xVariance, true, nPoints, 42), 2)
    dataset = sqlContext.createDataFrame(rdd)
  }

  test("one-vs-rest: default params") {
    val numClasses = 3
    val ova = new OneVsRest()
    ova.setClassifier(new LogisticRegression)
    assert(ova.getLabelCol === "label")
    assert(ova.getPredictionCol === "prediction")
    val ovaModel = ova.fit(dataset)
    assert(ovaModel.models.size === numClasses)
    val transformedDataset = ovaModel.transform(dataset)

    // check for label metadata in the prediction col
    val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol)
    assert(MetadataUtils.getNumClasses(predictionColSchema) === Some(3))

    val ovaResults = transformedDataset
      .select("prediction", "label")
      .map(row => (row.getDouble(0), row.getDouble(1)))

    val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(numClasses)
    lr.optimizer.setRegParam(0.1).setNumIterations(100)
    val model = lr.run(rdd)
    val results = model.predict(rdd.map(_.features)).zip(rdd.map(_.label))

    // compare the confusion matrix against multinomial logistic regression,
    // bounding how much error we allow relative to it.
    val expectedMetrics = new MulticlassMetrics(results)
    val ovaMetrics = new MulticlassMetrics(ovaResults)
    assert(expectedMetrics.confusionMatrix ~== ovaMetrics.confusionMatrix absTol 400)
  }

  test("one-vs-rest: pass label metadata correctly during train") {
    val numClasses = 3
    val ova = new OneVsRest()
    ova.setClassifier(new MockLogisticRegression)

    val labelMetadata = NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
    val labelWithMetadata = dataset("label").as("label", labelMetadata.toMetadata())
    val features = dataset("features").as("features")
    val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features)
    ova.fit(datasetWithLabelMetadata)
  }
}

private class MockLogisticRegression extends LogisticRegression {

  setMaxIter(1)

  override protected def train(dataset: DataFrame): LogisticRegressionModel = {
    val labelSchema = dataset.schema($(labelCol))
    // check for label attribute propagation.
    assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2))
    super.train(dataset)
  }
}