[SPARK-8764] [ML] string indexer should take option to handle unseen values

As a precursor to adding a public constructor add an option to handle unseen values by skipping rather than throwing an exception (default remains throwing an exception),

Author: Holden Karau <holden@pigscanfly.ca>

Closes #7266 from holdenk/SPARK-8764-string-indexer-should-take-option-to-handle-unseen-values and squashes the following commits:

38a4de9 [Holden Karau] fix long line
045bf22 [Holden Karau] Add a second b entry so b gets 0 for sure
81dd312 [Holden Karau] Update the docs for handleInvalid param to be more descriptive
7f37f6e [Holden Karau] remove extra space (scala style)
414e249 [Holden Karau] And switch to using handleInvalid instead of skipInvalid
1e53f9b [Holden Karau] update the param (codegen side)
7a22215 [Holden Karau] fix typo
100a39b [Holden Karau] Merge in master
aa5b093 [Holden Karau] Since we filter we should never go down this code path if getSkipInvalid is true
75ffa69 [Holden Karau] Remove extra newline
d69ef5e [Holden Karau] Add a test
b5734be [Holden Karau] Add support for unseen labels
afecd4e [Holden Karau] Add a param to skip invalid entries.
This commit is contained in:
Holden Karau 2015-08-11 11:33:36 -07:00 committed by Joseph K. Bradley
parent 8cad854ef6
commit dbd778d84d
4 changed files with 73 additions and 4 deletions

View file

@ -33,7 +33,8 @@ import org.apache.spark.util.collection.OpenHashMap
/**
* Base trait for [[StringIndexer]] and [[StringIndexerModel]].
*/
private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {
private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol
with HasHandleInvalid {
/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
@ -65,13 +66,16 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
def this() = this(Identifiable.randomUID("strIdx"))
/** @group setParam */
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
setDefault(handleInvalid, "error")
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
// TODO: handle unseen labels
override def fit(dataset: DataFrame): StringIndexerModel = {
val counts = dataset.select(col($(inputCol)).cast(StringType))
@ -111,6 +115,10 @@ class StringIndexerModel private[ml] (
map
}
/** @group setParam */
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
setDefault(handleInvalid, "error")
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@ -128,14 +136,24 @@ class StringIndexerModel private[ml] (
if (labelToIndex.contains(label)) {
labelToIndex(label)
} else {
// TODO: handle unseen labels
throw new SparkException(s"Unseen label: $label.")
}
}
val outputColName = $(outputCol)
val metadata = NominalAttribute.defaultAttr
.withName(outputColName).withValues(labels).toMetadata()
dataset.select(col("*"),
// If we are skipping invalid records, filter them out.
val filteredDataset = (getHandleInvalid) match {
case "skip" => {
val filterer = udf { label: String =>
labelToIndex.contains(label)
}
dataset.where(filterer(dataset($(inputCol))))
}
case _ => dataset
}
filteredDataset.select(col("*"),
indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, metadata))
}

View file

@ -59,6 +59,10 @@ private[shared] object SharedParamsCodeGen {
ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)",
isValid = "ParamValidators.gtEq(1)"),
ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " +
"will filter out rows with bad values), or error (which will throw an errror). More " +
"options may be added later.",
isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"),
ParamDesc[Boolean]("standardization", "whether to standardize the training features" +
" before fitting the model.", Some("true")),
ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")),

View file

@ -247,6 +247,21 @@ private[ml] trait HasFitIntercept extends Params {
final def getFitIntercept: Boolean = $(fitIntercept)
}
/**
* Trait for shared param handleInvalid.
*/
private[ml] trait HasHandleInvalid extends Params {
/**
* Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later..
* @group param
*/
final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", ParamValidators.inArray(Array("skip", "error")))
/** @group getParam */
final def getHandleInvalid: String = $(handleInvalid)
}
/**
* Trait for shared param standardization (default: true).
*/

View file

@ -17,6 +17,7 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkException
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
@ -62,6 +63,37 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet)
}
test("StringIndexerUnseen") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2)
val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label")
val df2 = sqlContext.createDataFrame(data2).toDF("id", "label")
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
.fit(df)
// Verify we throw by default with unseen values
intercept[SparkException] {
indexer.transform(df2).collect()
}
val indexerSkipInvalid = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
.setHandleInvalid("skip")
.fit(df)
// Verify that we skip the c record
val transformed = indexerSkipInvalid.transform(df2)
val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
.asInstanceOf[NominalAttribute]
assert(attr.values.get === Array("b", "a"))
val output = transformed.select("id", "labelIndex").map { r =>
(r.getInt(0), r.getDouble(1))
}.collect().toSet
// a -> 1, b -> 0
val expected = Set((0, 1.0), (1, 0.0))
assert(output === expected)
}
test("StringIndexer with a numeric input column") {
val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label")