[SPARK-13030][ML] Create OneHotEncoderEstimator for OneHotEncoder as Estimator

## What changes were proposed in this pull request?

This patch adds a new class `OneHotEncoderEstimator` which extends `Estimator`. The `fit` method returns `OneHotEncoderModel`.

Common methods between existing `OneHotEncoder` and new `OneHotEncoderEstimator`, such as transforming schema, are extracted and put into `OneHotEncoderCommon` to reduce code duplication.

### Multi-column support

`OneHotEncoderEstimator` adds simpler multi-column support because it is new API and can be free from backward compatibility.

### handleInvalid Param support

`OneHotEncoderEstimator` supports `handleInvalid` Param. It supports `error` and `keep`.

## How was this patch tested?

Added new test suite `OneHotEncoderEstimatorSuite`.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #19527 from viirya/SPARK-13030.
This commit is contained in:
Liang-Chi Hsieh 2017-12-31 15:28:59 -08:00 committed by Joseph K. Bradley
parent 5955a2d0fb
commit 994065d891
3 changed files with 960 additions and 66 deletions

View file

@ -41,8 +41,12 @@ import org.apache.spark.sql.types.{DoubleType, NumericType, StructType}
* The output vectors are sparse.
*
* @see `StringIndexer` for converting categorical values into category indices
* @deprecated `OneHotEncoderEstimator` will be renamed `OneHotEncoder` and this `OneHotEncoder`
* will be removed in 3.0.0.
*/
@Since("1.4.0")
@deprecated("`OneHotEncoderEstimator` will be renamed `OneHotEncoder` and this `OneHotEncoder`" +
" will be removed in 3.0.0.", "2.3.0")
class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Transformer
with HasInputCol with HasOutputCol with DefaultParamsWritable {
@ -78,56 +82,16 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e
override def transformSchema(schema: StructType): StructType = {
val inputColName = $(inputCol)
val outputColName = $(outputCol)
val inputFields = schema.fields
require(schema(inputColName).dataType.isInstanceOf[NumericType],
s"Input column must be of type NumericType but got ${schema(inputColName).dataType}")
val inputFields = schema.fields
require(!inputFields.exists(_.name == outputColName),
s"Output column $outputColName already exists.")
val inputAttr = Attribute.fromStructField(schema(inputColName))
val outputAttrNames: Option[Array[String]] = inputAttr match {
case nominal: NominalAttribute =>
if (nominal.values.isDefined) {
nominal.values
} else if (nominal.numValues.isDefined) {
nominal.numValues.map(n => Array.tabulate(n)(_.toString))
} else {
None
}
case binary: BinaryAttribute =>
if (binary.values.isDefined) {
binary.values
} else {
Some(Array.tabulate(2)(_.toString))
}
case _: NumericAttribute =>
throw new RuntimeException(
s"The input column $inputColName cannot be numeric.")
case _ =>
None // optimistic about unknown attributes
}
val filteredOutputAttrNames = outputAttrNames.map { names =>
if ($(dropLast)) {
require(names.length > 1,
s"The input column $inputColName should have at least two distinct values.")
names.dropRight(1)
} else {
names
}
}
val outputAttrGroup = if (filteredOutputAttrNames.isDefined) {
val attrs: Array[Attribute] = filteredOutputAttrNames.get.map { name =>
BinaryAttribute.defaultAttr.withName(name)
}
new AttributeGroup($(outputCol), attrs)
} else {
new AttributeGroup($(outputCol))
}
val outputFields = inputFields :+ outputAttrGroup.toStructField()
val outputField = OneHotEncoderCommon.transformOutputColumnSchema(
schema(inputColName), outputColName, $(dropLast))
val outputFields = inputFields :+ outputField
StructType(outputFields)
}
@ -136,30 +100,17 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e
// schema transformation
val inputColName = $(inputCol)
val outputColName = $(outputCol)
val shouldDropLast = $(dropLast)
var outputAttrGroup = AttributeGroup.fromStructField(
val outputAttrGroupFromSchema = AttributeGroup.fromStructField(
transformSchema(dataset.schema)(outputColName))
if (outputAttrGroup.size < 0) {
// If the number of attributes is unknown, we check the values from the input column.
val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).rdd.map(_.getDouble(0))
.treeAggregate(0.0)(
(m, x) => {
assert(x <= Int.MaxValue,
s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x")
assert(x >= 0.0 && x == x.toInt,
s"Values from column $inputColName must be indices, but got $x.")
math.max(m, x)
},
(m0, m1) => {
math.max(m0, m1)
}
).toInt + 1
val outputAttrNames = Array.tabulate(numAttrs)(_.toString)
val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames
val outputAttrs: Array[Attribute] =
filtered.map(name => BinaryAttribute.defaultAttr.withName(name))
outputAttrGroup = new AttributeGroup(outputColName, outputAttrs)
val outputAttrGroup = if (outputAttrGroupFromSchema.size < 0) {
OneHotEncoderCommon.getOutputAttrGroupFromData(
dataset, Seq(inputColName), Seq(outputColName), $(dropLast))(0)
} else {
outputAttrGroupFromSchema
}
val metadata = outputAttrGroup.toMetadata()
// data transformation

View file

@ -0,0 +1,522 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.feature
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCols, HasOutputCols}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.{col, lit, udf}
import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType}
/** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */
private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
with HasInputCols with HasOutputCols {
/**
* Param for how to handle invalid data.
* Options are 'keep' (invalid data presented as an extra categorical feature) or
* 'error' (throw an error).
* Default: "error"
* @group param
*/
@Since("2.3.0")
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
"How to handle invalid data " +
"Options are 'keep' (invalid data presented as an extra categorical feature) " +
"or error (throw an error).",
ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids))
setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID)
/**
* Whether to drop the last category in the encoded vector (default: true)
* @group param
*/
@Since("2.3.0")
final val dropLast: BooleanParam =
new BooleanParam(this, "dropLast", "whether to drop the last category")
setDefault(dropLast -> true)
/** @group getParam */
@Since("2.3.0")
def getDropLast: Boolean = $(dropLast)
protected def validateAndTransformSchema(
schema: StructType, dropLast: Boolean, keepInvalid: Boolean): StructType = {
val inputColNames = $(inputCols)
val outputColNames = $(outputCols)
val existingFields = schema.fields
require(inputColNames.length == outputColNames.length,
s"The number of input columns ${inputColNames.length} must be the same as the number of " +
s"output columns ${outputColNames.length}.")
// Input columns must be NumericType.
inputColNames.foreach(SchemaUtils.checkNumericType(schema, _))
// Prepares output columns with proper attributes by examining input columns.
val inputFields = $(inputCols).map(schema(_))
val outputFields = inputFields.zip(outputColNames).map { case (inputField, outputColName) =>
OneHotEncoderCommon.transformOutputColumnSchema(
inputField, outputColName, dropLast, keepInvalid)
}
outputFields.foldLeft(schema) { case (newSchema, outputField) =>
SchemaUtils.appendColumn(newSchema, outputField)
}
}
}
/**
* A one-hot encoder that maps a column of category indices to a column of binary vectors, with
* at most a single one-value per row that indicates the input category index.
* For example with 5 categories, an input value of 2.0 would map to an output vector of
* `[0.0, 0.0, 1.0, 0.0]`.
* The last category is not included by default (configurable via `dropLast`),
* because it makes the vector entries sum up to one, and hence linearly dependent.
* So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`.
*
* @note This is different from scikit-learn's OneHotEncoder, which keeps all categories.
* The output vectors are sparse.
*
* When `handleInvalid` is configured to 'keep', an extra "category" indicating invalid values is
* added as last category. So when `dropLast` is true, invalid values are encoded as all-zeros
* vector.
*
* @note When encoding multi-column by using `inputCols` and `outputCols` params, input/output cols
* come in pairs, specified by the order in the arrays, and each pair is treated independently.
*
* @see `StringIndexer` for converting categorical values into category indices
*/
@Since("2.3.0")
class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid: String)
extends Estimator[OneHotEncoderModel] with OneHotEncoderBase with DefaultParamsWritable {
@Since("2.3.0")
def this() = this(Identifiable.randomUID("oneHotEncoder"))
/** @group setParam */
@Since("2.3.0")
def setInputCols(values: Array[String]): this.type = set(inputCols, values)
/** @group setParam */
@Since("2.3.0")
def setOutputCols(values: Array[String]): this.type = set(outputCols, values)
/** @group setParam */
@Since("2.3.0")
def setDropLast(value: Boolean): this.type = set(dropLast, value)
/** @group setParam */
@Since("2.3.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
@Since("2.3.0")
override def transformSchema(schema: StructType): StructType = {
val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID
validateAndTransformSchema(schema, dropLast = $(dropLast),
keepInvalid = keepInvalid)
}
@Since("2.3.0")
override def fit(dataset: Dataset[_]): OneHotEncoderModel = {
transformSchema(dataset.schema)
// Compute the plain number of categories without `handleInvalid` and
// `dropLast` taken into account.
val transformedSchema = validateAndTransformSchema(dataset.schema, dropLast = false,
keepInvalid = false)
val categorySizes = new Array[Int]($(outputCols).length)
val columnToScanIndices = $(outputCols).zipWithIndex.flatMap { case (outputColName, idx) =>
val numOfAttrs = AttributeGroup.fromStructField(
transformedSchema(outputColName)).size
if (numOfAttrs < 0) {
Some(idx)
} else {
categorySizes(idx) = numOfAttrs
None
}
}
// Some input columns don't have attributes or their attributes don't have necessary info.
// We need to scan the data to get the number of values for each column.
if (columnToScanIndices.length > 0) {
val inputColNames = columnToScanIndices.map($(inputCols)(_))
val outputColNames = columnToScanIndices.map($(outputCols)(_))
// When fitting data, we want the plain number of categories without `handleInvalid` and
// `dropLast` taken into account.
val attrGroups = OneHotEncoderCommon.getOutputAttrGroupFromData(
dataset, inputColNames, outputColNames, dropLast = false)
attrGroups.zip(columnToScanIndices).foreach { case (attrGroup, idx) =>
categorySizes(idx) = attrGroup.size
}
}
val model = new OneHotEncoderModel(uid, categorySizes).setParent(this)
copyValues(model)
}
@Since("2.3.0")
override def copy(extra: ParamMap): OneHotEncoderEstimator = defaultCopy(extra)
}
@Since("2.3.0")
object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimator] {
private[feature] val KEEP_INVALID: String = "keep"
private[feature] val ERROR_INVALID: String = "error"
private[feature] val supportedHandleInvalids: Array[String] = Array(KEEP_INVALID, ERROR_INVALID)
@Since("2.3.0")
override def load(path: String): OneHotEncoderEstimator = super.load(path)
}
@Since("2.3.0")
class OneHotEncoderModel private[ml] (
@Since("2.3.0") override val uid: String,
@Since("2.3.0") val categorySizes: Array[Int])
extends Model[OneHotEncoderModel] with OneHotEncoderBase with MLWritable {
import OneHotEncoderModel._
// Returns the category size for a given index with `dropLast` and `handleInvalid`
// taken into account.
private def configedCategorySize(orgCategorySize: Int, idx: Int): Int = {
val dropLast = getDropLast
val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID
if (!dropLast && keepInvalid) {
// When `handleInvalid` is "keep", an extra category is added as last category
// for invalid data.
orgCategorySize + 1
} else if (dropLast && !keepInvalid) {
// When `dropLast` is true, the last category is removed.
orgCategorySize - 1
} else {
// When `dropLast` is true and `handleInvalid` is "keep", the extra category for invalid
// data is removed. Thus, it is the same as the plain number of categories.
orgCategorySize
}
}
private def encoder: UserDefinedFunction = {
val oneValue = Array(1.0)
val emptyValues = Array.empty[Double]
val emptyIndices = Array.empty[Int]
val dropLast = getDropLast
val handleInvalid = getHandleInvalid
val keepInvalid = handleInvalid == OneHotEncoderEstimator.KEEP_INVALID
// The udf performed on input data. The first parameter is the input value. The second
// parameter is the index of input.
udf { (label: Double, idx: Int) =>
val plainNumCategories = categorySizes(idx)
val size = configedCategorySize(plainNumCategories, idx)
if (label < 0) {
throw new SparkException(s"Negative value: $label. Input can't be negative.")
} else if (label == size && dropLast && !keepInvalid) {
// When `dropLast` is true and `handleInvalid` is not "keep",
// the last category is removed.
Vectors.sparse(size, emptyIndices, emptyValues)
} else if (label >= plainNumCategories && keepInvalid) {
// When `handleInvalid` is "keep", encodes invalid data to last category (and removed
// if `dropLast` is true)
if (dropLast) {
Vectors.sparse(size, emptyIndices, emptyValues)
} else {
Vectors.sparse(size, Array(size - 1), oneValue)
}
} else if (label < plainNumCategories) {
Vectors.sparse(size, Array(label.toInt), oneValue)
} else {
assert(handleInvalid == OneHotEncoderEstimator.ERROR_INVALID)
throw new SparkException(s"Unseen value: $label. To handle unseen values, " +
s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.")
}
}
}
/** @group setParam */
@Since("2.3.0")
def setInputCols(values: Array[String]): this.type = set(inputCols, values)
/** @group setParam */
@Since("2.3.0")
def setOutputCols(values: Array[String]): this.type = set(outputCols, values)
/** @group setParam */
@Since("2.3.0")
def setDropLast(value: Boolean): this.type = set(dropLast, value)
/** @group setParam */
@Since("2.3.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
@Since("2.3.0")
override def transformSchema(schema: StructType): StructType = {
val inputColNames = $(inputCols)
val outputColNames = $(outputCols)
require(inputColNames.length == categorySizes.length,
s"The number of input columns ${inputColNames.length} must be the same as the number of " +
s"features ${categorySizes.length} during fitting.")
val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID
val transformedSchema = validateAndTransformSchema(schema, dropLast = $(dropLast),
keepInvalid = keepInvalid)
verifyNumOfValues(transformedSchema)
}
/**
* If the metadata of input columns also specifies the number of categories, we need to
* compare with expected category number with `handleInvalid` and `dropLast` taken into
* account. Mismatched numbers will cause exception.
*/
private def verifyNumOfValues(schema: StructType): StructType = {
$(outputCols).zipWithIndex.foreach { case (outputColName, idx) =>
val inputColName = $(inputCols)(idx)
val attrGroup = AttributeGroup.fromStructField(schema(outputColName))
// If the input metadata specifies number of category for output column,
// comparing with expected category number with `handleInvalid` and
// `dropLast` taken into account.
if (attrGroup.attributes.nonEmpty) {
val numCategories = configedCategorySize(categorySizes(idx), idx)
require(attrGroup.size == numCategories, "OneHotEncoderModel expected " +
s"$numCategories categorical values for input column ${inputColName}, " +
s"but the input column had metadata specifying ${attrGroup.size} values.")
}
}
schema
}
@Since("2.3.0")
override def transform(dataset: Dataset[_]): DataFrame = {
val transformedSchema = transformSchema(dataset.schema, logging = true)
val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID
val encodedColumns = (0 until $(inputCols).length).map { idx =>
val inputColName = $(inputCols)(idx)
val outputColName = $(outputCols)(idx)
val outputAttrGroupFromSchema =
AttributeGroup.fromStructField(transformedSchema(outputColName))
val metadata = if (outputAttrGroupFromSchema.size < 0) {
OneHotEncoderCommon.createAttrGroupForAttrNames(outputColName,
categorySizes(idx), $(dropLast), keepInvalid).toMetadata()
} else {
outputAttrGroupFromSchema.toMetadata()
}
encoder(col(inputColName).cast(DoubleType), lit(idx))
.as(outputColName, metadata)
}
dataset.withColumns($(outputCols), encodedColumns)
}
@Since("2.3.0")
override def copy(extra: ParamMap): OneHotEncoderModel = {
val copied = new OneHotEncoderModel(uid, categorySizes)
copyValues(copied, extra).setParent(parent)
}
@Since("2.3.0")
override def write: MLWriter = new OneHotEncoderModelWriter(this)
}
@Since("2.3.0")
object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] {
private[OneHotEncoderModel]
class OneHotEncoderModelWriter(instance: OneHotEncoderModel) extends MLWriter {
private case class Data(categorySizes: Array[Int])
override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = Data(instance.categorySizes)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
private class OneHotEncoderModelReader extends MLReader[OneHotEncoderModel] {
private val className = classOf[OneHotEncoderModel].getName
override def load(path: String): OneHotEncoderModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
.select("categorySizes")
.head()
val categorySizes = data.getAs[Seq[Int]](0).toArray
val model = new OneHotEncoderModel(metadata.uid, categorySizes)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}
@Since("2.3.0")
override def read: MLReader[OneHotEncoderModel] = new OneHotEncoderModelReader
@Since("2.3.0")
override def load(path: String): OneHotEncoderModel = super.load(path)
}
/**
* Provides some helper methods used by both `OneHotEncoder` and `OneHotEncoderEstimator`.
*/
private[feature] object OneHotEncoderCommon {
private def genOutputAttrNames(inputCol: StructField): Option[Array[String]] = {
val inputAttr = Attribute.fromStructField(inputCol)
inputAttr match {
case nominal: NominalAttribute =>
if (nominal.values.isDefined) {
nominal.values
} else if (nominal.numValues.isDefined) {
nominal.numValues.map(n => Array.tabulate(n)(_.toString))
} else {
None
}
case binary: BinaryAttribute =>
if (binary.values.isDefined) {
binary.values
} else {
Some(Array.tabulate(2)(_.toString))
}
case _: NumericAttribute =>
throw new RuntimeException(
s"The input column ${inputCol.name} cannot be continuous-value.")
case _ =>
None // optimistic about unknown attributes
}
}
/** Creates an `AttributeGroup` filled by the `BinaryAttribute` named as required. */
private def genOutputAttrGroup(
outputAttrNames: Option[Array[String]],
outputColName: String): AttributeGroup = {
outputAttrNames.map { attrNames =>
val attrs: Array[Attribute] = attrNames.map { name =>
BinaryAttribute.defaultAttr.withName(name)
}
new AttributeGroup(outputColName, attrs)
}.getOrElse{
new AttributeGroup(outputColName)
}
}
/**
* Prepares the `StructField` with proper metadata for `OneHotEncoder`'s output column.
*/
def transformOutputColumnSchema(
inputCol: StructField,
outputColName: String,
dropLast: Boolean,
keepInvalid: Boolean = false): StructField = {
val outputAttrNames = genOutputAttrNames(inputCol)
val filteredOutputAttrNames = outputAttrNames.map { names =>
if (dropLast && !keepInvalid) {
require(names.length > 1,
s"The input column ${inputCol.name} should have at least two distinct values.")
names.dropRight(1)
} else if (!dropLast && keepInvalid) {
names ++ Seq("invalidValues")
} else {
names
}
}
genOutputAttrGroup(filteredOutputAttrNames, outputColName).toStructField()
}
/**
* This method is called when we want to generate `AttributeGroup` from actual data for
* one-hot encoder.
*/
def getOutputAttrGroupFromData(
dataset: Dataset[_],
inputColNames: Seq[String],
outputColNames: Seq[String],
dropLast: Boolean): Seq[AttributeGroup] = {
// The RDD approach has advantage of early-stop if any values are invalid. It seems that
// DataFrame ops don't have equivalent functions.
val columns = inputColNames.map { inputColName =>
col(inputColName).cast(DoubleType)
}
val numOfColumns = columns.length
val numAttrsArray = dataset.select(columns: _*).rdd.map { row =>
(0 until numOfColumns).map(idx => row.getDouble(idx)).toArray
}.treeAggregate(new Array[Double](numOfColumns))(
(maxValues, curValues) => {
(0 until numOfColumns).foreach { idx =>
val x = curValues(idx)
assert(x <= Int.MaxValue,
s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x.")
assert(x >= 0.0 && x == x.toInt,
s"Values from column ${inputColNames(idx)} must be indices, but got $x.")
maxValues(idx) = math.max(maxValues(idx), x)
}
maxValues
},
(m0, m1) => {
(0 until numOfColumns).foreach { idx =>
m0(idx) = math.max(m0(idx), m1(idx))
}
m0
}
).map(_.toInt + 1)
outputColNames.zip(numAttrsArray).map { case (outputColName, numAttrs) =>
createAttrGroupForAttrNames(outputColName, numAttrs, dropLast, keepInvalid = false)
}
}
/** Creates an `AttributeGroup` with the required number of `BinaryAttribute`. */
def createAttrGroupForAttrNames(
outputColName: String,
numAttrs: Int,
dropLast: Boolean,
keepInvalid: Boolean): AttributeGroup = {
val outputAttrNames = Array.tabulate(numAttrs)(_.toString)
val filtered = if (dropLast && !keepInvalid) {
outputAttrNames.dropRight(1)
} else if (!dropLast && keepInvalid) {
outputAttrNames ++ Seq("invalidValues")
} else {
outputAttrNames
}
genOutputAttrGroup(Some(filtered), outputColName)
}
}

View file

@ -0,0 +1,421 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.feature
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._
class OneHotEncoderEstimatorSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
import testImplicits._
test("params") {
ParamsSuite.checkParams(new OneHotEncoderEstimator)
}
test("OneHotEncoderEstimator dropLast = false") {
val data = Seq(
Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))),
Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))),
Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))))
val schema = StructType(Array(
StructField("input", DoubleType),
StructField("expected", new VectorUDT)))
val df = spark.createDataFrame(sc.parallelize(data), schema)
val encoder = new OneHotEncoderEstimator()
.setInputCols(Array("input"))
.setOutputCols(Array("output"))
assert(encoder.getDropLast === true)
encoder.setDropLast(false)
assert(encoder.getDropLast === false)
val model = encoder.fit(df)
val encoded = model.transform(df)
encoded.select("output", "expected").rdd.map { r =>
(r.getAs[Vector](0), r.getAs[Vector](1))
}.collect().foreach { case (vec1, vec2) =>
assert(vec1 === vec2)
}
}
test("OneHotEncoderEstimator dropLast = true") {
val data = Seq(
Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))),
Row(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
Row(2.0, Vectors.sparse(2, Seq())),
Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))),
Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))),
Row(2.0, Vectors.sparse(2, Seq())))
val schema = StructType(Array(
StructField("input", DoubleType),
StructField("expected", new VectorUDT)))
val df = spark.createDataFrame(sc.parallelize(data), schema)
val encoder = new OneHotEncoderEstimator()
.setInputCols(Array("input"))
.setOutputCols(Array("output"))
val model = encoder.fit(df)
val encoded = model.transform(df)
encoded.select("output", "expected").rdd.map { r =>
(r.getAs[Vector](0), r.getAs[Vector](1))
}.collect().foreach { case (vec1, vec2) =>
assert(vec1 === vec2)
}
}
test("input column with ML attribute") {
val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("size")
.select(col("size").as("size", attr.toMetadata()))
val encoder = new OneHotEncoderEstimator()
.setInputCols(Array("size"))
.setOutputCols(Array("encoded"))
val model = encoder.fit(df)
val output = model.transform(df)
val group = AttributeGroup.fromStructField(output.schema("encoded"))
assert(group.size === 2)
assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
}
test("input column without ML attribute") {
val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("index")
val encoder = new OneHotEncoderEstimator()
.setInputCols(Array("index"))
.setOutputCols(Array("encoded"))
val model = encoder.fit(df)
val output = model.transform(df)
val group = AttributeGroup.fromStructField(output.schema("encoded"))
assert(group.size === 2)
assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0))
assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1))
}
test("read/write") {
val encoder = new OneHotEncoderEstimator()
.setInputCols(Array("index"))
.setOutputCols(Array("encoded"))
testDefaultReadWrite(encoder)
}
test("OneHotEncoderModel read/write") {
val instance = new OneHotEncoderModel("myOneHotEncoderModel", Array(1, 2, 3))
val newInstance = testDefaultReadWrite(instance)
assert(newInstance.categorySizes === instance.categorySizes)
}
test("OneHotEncoderEstimator with varying types") {
val data = Seq(
Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))),
Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))),
Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))))
val schema = StructType(Array(
StructField("input", DoubleType),
StructField("expected", new VectorUDT)))
val df = spark.createDataFrame(sc.parallelize(data), schema)
val dfWithTypes = df
.withColumn("shortInput", df("input").cast(ShortType))
.withColumn("longInput", df("input").cast(LongType))
.withColumn("intInput", df("input").cast(IntegerType))
.withColumn("floatInput", df("input").cast(FloatType))
.withColumn("decimalInput", df("input").cast(DecimalType(10, 0)))
val cols = Array("input", "shortInput", "longInput", "intInput",
"floatInput", "decimalInput")
for (col <- cols) {
val encoder = new OneHotEncoderEstimator()
.setInputCols(Array(col))
.setOutputCols(Array("output"))
.setDropLast(false)
val model = encoder.fit(dfWithTypes)
val encoded = model.transform(dfWithTypes)
encoded.select("output", "expected").rdd.map { r =>
(r.getAs[Vector](0), r.getAs[Vector](1))
}.collect().foreach { case (vec1, vec2) =>
assert(vec1 === vec2)
}
}
}
test("OneHotEncoderEstimator: encoding multiple columns and dropLast = false") {
val data = Seq(
Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0)))),
Row(1.0, Vectors.sparse(3, Seq((1, 1.0))), 3.0, Vectors.sparse(4, Seq((3, 1.0)))),
Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), 0.0, Vectors.sparse(4, Seq((0, 1.0)))),
Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 1.0, Vectors.sparse(4, Seq((1, 1.0)))),
Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 0.0, Vectors.sparse(4, Seq((0, 1.0)))),
Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0)))))
val schema = StructType(Array(
StructField("input1", DoubleType),
StructField("expected1", new VectorUDT),
StructField("input2", DoubleType),
StructField("expected2", new VectorUDT)))
val df = spark.createDataFrame(sc.parallelize(data), schema)
val encoder = new OneHotEncoderEstimator()
.setInputCols(Array("input1", "input2"))
.setOutputCols(Array("output1", "output2"))
assert(encoder.getDropLast === true)
encoder.setDropLast(false)
assert(encoder.getDropLast === false)
val model = encoder.fit(df)
val encoded = model.transform(df)
encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r =>
(r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3))
}.collect().foreach { case (vec1, vec2, vec3, vec4) =>
assert(vec1 === vec2)
assert(vec3 === vec4)
}
}
test("OneHotEncoderEstimator: encoding multiple columns and dropLast = true") {
val data = Seq(
Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 2.0, Vectors.sparse(3, Seq((2, 1.0)))),
Row(1.0, Vectors.sparse(2, Seq((1, 1.0))), 3.0, Vectors.sparse(3, Seq())),
Row(2.0, Vectors.sparse(2, Seq()), 0.0, Vectors.sparse(3, Seq((0, 1.0)))),
Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 1.0, Vectors.sparse(3, Seq((1, 1.0)))),
Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 0.0, Vectors.sparse(3, Seq((0, 1.0)))),
Row(2.0, Vectors.sparse(2, Seq()), 2.0, Vectors.sparse(3, Seq((2, 1.0)))))
val schema = StructType(Array(
StructField("input1", DoubleType),
StructField("expected1", new VectorUDT),
StructField("input2", DoubleType),
StructField("expected2", new VectorUDT)))
val df = spark.createDataFrame(sc.parallelize(data), schema)
val encoder = new OneHotEncoderEstimator()
.setInputCols(Array("input1", "input2"))
.setOutputCols(Array("output1", "output2"))
val model = encoder.fit(df)
val encoded = model.transform(df)
encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r =>
(r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3))
}.collect().foreach { case (vec1, vec2, vec3, vec4) =>
assert(vec1 === vec2)
assert(vec3 === vec4)
}
}
test("Throw error on invalid values") {
val trainingData = Seq((0, 0), (1, 1), (2, 2))
val trainingDF = trainingData.toDF("id", "a")
val testData = Seq((0, 0), (1, 2), (1, 3))
val testDF = testData.toDF("id", "a")
val encoder = new OneHotEncoderEstimator()
.setInputCols(Array("a"))
.setOutputCols(Array("encoded"))
val model = encoder.fit(trainingDF)
val err = intercept[SparkException] {
model.transform(testDF).show
}
err.getMessage.contains("Unseen value: 3.0. To handle unseen values")
}
test("Can't transform on negative input") {
val trainingDF = Seq((0, 0), (1, 1), (2, 2)).toDF("a", "b")
val testDF = Seq((0, 0), (-1, 2), (1, 3)).toDF("a", "b")
val encoder = new OneHotEncoderEstimator()
.setInputCols(Array("a"))
.setOutputCols(Array("encoded"))
val model = encoder.fit(trainingDF)
val err = intercept[SparkException] {
model.transform(testDF).collect()
}
err.getMessage.contains("Negative value: -1.0. Input can't be negative")
}
test("Keep on invalid values: dropLast = false") {
val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input")
val testData = Seq(
Row(0.0, Vectors.sparse(4, Seq((0, 1.0)))),
Row(1.0, Vectors.sparse(4, Seq((1, 1.0)))),
Row(3.0, Vectors.sparse(4, Seq((3, 1.0)))))
val schema = StructType(Array(
StructField("input", DoubleType),
StructField("expected", new VectorUDT)))
val testDF = spark.createDataFrame(sc.parallelize(testData), schema)
val encoder = new OneHotEncoderEstimator()
.setInputCols(Array("input"))
.setOutputCols(Array("output"))
.setHandleInvalid("keep")
.setDropLast(false)
val model = encoder.fit(trainingDF)
val encoded = model.transform(testDF)
encoded.select("output", "expected").rdd.map { r =>
(r.getAs[Vector](0), r.getAs[Vector](1))
}.collect().foreach { case (vec1, vec2) =>
assert(vec1 === vec2)
}
}
test("Keep on invalid values: dropLast = true") {
val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input")
val testData = Seq(
Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))),
Row(3.0, Vectors.sparse(3, Seq())))
val schema = StructType(Array(
StructField("input", DoubleType),
StructField("expected", new VectorUDT)))
val testDF = spark.createDataFrame(sc.parallelize(testData), schema)
val encoder = new OneHotEncoderEstimator()
.setInputCols(Array("input"))
.setOutputCols(Array("output"))
.setHandleInvalid("keep")
.setDropLast(true)
val model = encoder.fit(trainingDF)
val encoded = model.transform(testDF)
encoded.select("output", "expected").rdd.map { r =>
(r.getAs[Vector](0), r.getAs[Vector](1))
}.collect().foreach { case (vec1, vec2) =>
assert(vec1 === vec2)
}
}
test("OneHotEncoderModel changes dropLast") {
val data = Seq(
Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))),
Row(1.0, Vectors.sparse(3, Seq((1, 1.0))), Vectors.sparse(2, Seq((1, 1.0)))),
Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), Vectors.sparse(2, Seq())),
Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))),
Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))),
Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), Vectors.sparse(2, Seq())))
val schema = StructType(Array(
StructField("input", DoubleType),
StructField("expected1", new VectorUDT),
StructField("expected2", new VectorUDT)))
val df = spark.createDataFrame(sc.parallelize(data), schema)
val encoder = new OneHotEncoderEstimator()
.setInputCols(Array("input"))
.setOutputCols(Array("output"))
val model = encoder.fit(df)
model.setDropLast(false)
val encoded1 = model.transform(df)
encoded1.select("output", "expected1").rdd.map { r =>
(r.getAs[Vector](0), r.getAs[Vector](1))
}.collect().foreach { case (vec1, vec2) =>
assert(vec1 === vec2)
}
model.setDropLast(true)
val encoded2 = model.transform(df)
encoded2.select("output", "expected2").rdd.map { r =>
(r.getAs[Vector](0), r.getAs[Vector](1))
}.collect().foreach { case (vec1, vec2) =>
assert(vec1 === vec2)
}
}
test("OneHotEncoderModel changes handleInvalid") {
val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input")
val testData = Seq(
Row(0.0, Vectors.sparse(4, Seq((0, 1.0)))),
Row(1.0, Vectors.sparse(4, Seq((1, 1.0)))),
Row(3.0, Vectors.sparse(4, Seq((3, 1.0)))))
val schema = StructType(Array(
StructField("input", DoubleType),
StructField("expected", new VectorUDT)))
val testDF = spark.createDataFrame(sc.parallelize(testData), schema)
val encoder = new OneHotEncoderEstimator()
.setInputCols(Array("input"))
.setOutputCols(Array("output"))
val model = encoder.fit(trainingDF)
model.setHandleInvalid("error")
val err = intercept[SparkException] {
model.transform(testDF).collect()
}
err.getMessage.contains("Unseen value: 3.0. To handle unseen values")
model.setHandleInvalid("keep")
model.transform(testDF).collect()
}
test("Transforming on mismatched attributes") {
val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("size")
.select(col("size").as("size", attr.toMetadata()))
val encoder = new OneHotEncoderEstimator()
.setInputCols(Array("size"))
.setOutputCols(Array("encoded"))
val model = encoder.fit(df)
val testAttr = NominalAttribute.defaultAttr.withValues("tiny", "small", "medium", "large")
val testDF = Seq(0.0, 1.0, 2.0, 3.0).map(Tuple1.apply).toDF("size")
.select(col("size").as("size", testAttr.toMetadata()))
val err = intercept[Exception] {
model.transform(testDF).collect()
}
err.getMessage.contains("OneHotEncoderModel expected 2 categorical values")
}
}