[SPARK-5888] [MLLIB] Add OneHotEncoder as a Transformer
This patch adds a one hot encoder for categorical features. Planning to add documentation and another test after getting feedback on the approach. A couple choices made here: * There's an `includeFirst` option which, if false, creates numCategories - 1 columns and, if true, creates numCategories columns. The default is true, which is the behavior in scikit-learn. * The user is expected to pass a `Seq` of category names when instantiating a `OneHotEncoder`. These can be easily gotten from a `StringIndexer`. The names are used for the output column names, which take the form colName_categoryName. Author: Sandy Ryza <sandy@cloudera.com> Closes #5500 from sryza/sandy-spark-5888 and squashes the following commits: f383250 [Sandy Ryza] Infer label names automatically 6e257b9 [Sandy Ryza] Review comments 7c539cf [Sandy Ryza] Vector transformers 1c182dd [Sandy Ryza] SPARK-5888. [MLLIB]. Add OneHotEncoder as a Transformer
This commit is contained in:
parent
ee374e89cd
commit
47728db7cf
|
@ -0,0 +1,107 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.feature
|
||||
|
||||
import org.apache.spark.SparkException
|
||||
import org.apache.spark.annotation.AlphaComponent
|
||||
import org.apache.spark.ml.UnaryTransformer
|
||||
import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute}
|
||||
import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
|
||||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
|
||||
import org.apache.spark.ml.util.SchemaUtils
|
||||
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
|
||||
|
||||
/**
|
||||
* A one-hot encoder that maps a column of label indices to a column of binary vectors, with
|
||||
* at most a single one-value. By default, the binary vector has an element for each category, so
|
||||
* with 5 categories, an input value of 2.0 would map to an output vector of
|
||||
* (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so the
|
||||
* output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value
|
||||
* of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns
|
||||
* linearly dependent because they sum up to one.
|
||||
*/
|
||||
@AlphaComponent
|
||||
class OneHotEncoder extends UnaryTransformer[Double, Vector, OneHotEncoder]
|
||||
with HasInputCol with HasOutputCol {
|
||||
|
||||
/**
|
||||
* Whether to include a component in the encoded vectors for the first category, defaults to true.
|
||||
* @group param
|
||||
*/
|
||||
final val includeFirst: BooleanParam =
|
||||
new BooleanParam(this, "includeFirst", "include first category")
|
||||
setDefault(includeFirst -> true)
|
||||
|
||||
private var categories: Array[String] = _
|
||||
|
||||
/** @group setParam */
|
||||
def setIncludeFirst(value: Boolean): this.type = set(includeFirst, value)
|
||||
|
||||
/** @group setParam */
|
||||
override def setInputCol(value: String): this.type = set(inputCol, value)
|
||||
|
||||
/** @group setParam */
|
||||
override def setOutputCol(value: String): this.type = set(outputCol, value)
|
||||
|
||||
override def transformSchema(schema: StructType): StructType = {
|
||||
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
|
||||
val inputFields = schema.fields
|
||||
val outputColName = $(outputCol)
|
||||
require(inputFields.forall(_.name != $(outputCol)),
|
||||
s"Output column ${$(outputCol)} already exists.")
|
||||
|
||||
val inputColAttr = Attribute.fromStructField(schema($(inputCol)))
|
||||
categories = inputColAttr match {
|
||||
case nominal: NominalAttribute =>
|
||||
nominal.values.getOrElse((0 until nominal.numValues.get).map(_.toString).toArray)
|
||||
case binary: BinaryAttribute => binary.values.getOrElse(Array("0", "1"))
|
||||
case _ =>
|
||||
throw new SparkException(s"OneHotEncoder input column ${$(inputCol)} is not nominal")
|
||||
}
|
||||
|
||||
val attrValues = (if ($(includeFirst)) categories else categories.drop(1)).toArray
|
||||
val attr = NominalAttribute.defaultAttr.withName(outputColName).withValues(attrValues)
|
||||
val outputFields = inputFields :+ attr.toStructField()
|
||||
StructType(outputFields)
|
||||
}
|
||||
|
||||
protected override def createTransformFunc(): (Double) => Vector = {
|
||||
val first = $(includeFirst)
|
||||
val vecLen = if (first) categories.length else categories.length - 1
|
||||
val oneValue = Array(1.0)
|
||||
val emptyValues = Array[Double]()
|
||||
val emptyIndices = Array[Int]()
|
||||
label: Double => {
|
||||
val values = if (first || label != 0.0) oneValue else emptyValues
|
||||
val indices = if (first) {
|
||||
Array(label.toInt)
|
||||
} else if (label != 0.0) {
|
||||
Array(label.toInt - 1)
|
||||
} else {
|
||||
emptyIndices
|
||||
}
|
||||
Vectors.sparse(vecLen, indices, values)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the data type of the output column.
|
||||
*/
|
||||
protected def outputDataType: DataType = new VectorUDT
|
||||
}
|
|
@ -0,0 +1,80 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.feature
|
||||
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
import org.apache.spark.mllib.linalg.Vector
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.sql.{DataFrame, SQLContext}
|
||||
|
||||
|
||||
class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
|
||||
private var sqlContext: SQLContext = _
|
||||
|
||||
override def beforeAll(): Unit = {
|
||||
super.beforeAll()
|
||||
sqlContext = new SQLContext(sc)
|
||||
}
|
||||
|
||||
def stringIndexed(): DataFrame = {
|
||||
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
|
||||
val df = sqlContext.createDataFrame(data).toDF("id", "label")
|
||||
val indexer = new StringIndexer()
|
||||
.setInputCol("label")
|
||||
.setOutputCol("labelIndex")
|
||||
.fit(df)
|
||||
indexer.transform(df)
|
||||
}
|
||||
|
||||
test("OneHotEncoder includeFirst = true") {
|
||||
val transformed = stringIndexed()
|
||||
val encoder = new OneHotEncoder()
|
||||
.setInputCol("labelIndex")
|
||||
.setOutputCol("labelVec")
|
||||
val encoded = encoder.transform(transformed)
|
||||
|
||||
val output = encoded.select("id", "labelVec").map { r =>
|
||||
val vec = r.get(1).asInstanceOf[Vector]
|
||||
(r.getInt(0), vec(0), vec(1), vec(2))
|
||||
}.collect().toSet
|
||||
// a -> 0, b -> 2, c -> 1
|
||||
val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0),
|
||||
(3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0))
|
||||
assert(output === expected)
|
||||
}
|
||||
|
||||
test("OneHotEncoder includeFirst = false") {
|
||||
val transformed = stringIndexed()
|
||||
val encoder = new OneHotEncoder()
|
||||
.setIncludeFirst(false)
|
||||
.setInputCol("labelIndex")
|
||||
.setOutputCol("labelVec")
|
||||
val encoded = encoder.transform(transformed)
|
||||
|
||||
val output = encoded.select("id", "labelVec").map { r =>
|
||||
val vec = r.get(1).asInstanceOf[Vector]
|
||||
(r.getInt(0), vec(0), vec(1))
|
||||
}.collect().toSet
|
||||
// a -> 0, b -> 2, c -> 1
|
||||
val expected = Set((0, 0.0, 0.0), (1, 0.0, 1.0), (2, 1.0, 0.0),
|
||||
(3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0))
|
||||
assert(output === expected)
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in a new issue