[SPARK-13030][ML] Create OneHotEncoderEstimator for OneHotEncoder as Estimator
## What changes were proposed in this pull request?

This patch adds a new class `OneHotEncoderEstimator`, which extends `Estimator`; its `fit` method returns a `OneHotEncoderModel`. Methods common to the existing `OneHotEncoder` and the new `OneHotEncoderEstimator`, such as schema transformation, are extracted into `OneHotEncoderCommon` to reduce code duplication.

### Multi-column support

`OneHotEncoderEstimator` adds simpler multi-column support because it is a new API and is free from backward-compatibility constraints.

### handleInvalid Param support

`OneHotEncoderEstimator` supports the `handleInvalid` Param with two settings: `error` and `keep`.

## How was this patch tested?

Added a new test suite, `OneHotEncoderEstimatorSuite`.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #19527 from viirya/SPARK-13030.
Commit 994065d891 (parent 5955a2d0fb)
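For orientation, here is a minimal usage sketch of the API this patch introduces, showing multi-column encoding and the `keep` setting of `handleInvalid`. This is not part of the patch; the column names are invented and `spark` is assumed to be an active `SparkSession` (e.g. in `spark-shell`):

```scala
import org.apache.spark.ml.feature.OneHotEncoderEstimator

// Two categorical-index columns, encoded independently as a pair.
val df = spark.createDataFrame(Seq(
  (0.0, 1.0), (1.0, 0.0), (2.0, 1.0)
)).toDF("color", "size")

val encoder = new OneHotEncoderEstimator()
  .setInputCols(Array("color", "size"))
  .setOutputCols(Array("colorVec", "sizeVec"))
  .setHandleInvalid("keep") // unseen indices map to an extra "invalid" category

// fit() determines category sizes from column metadata, or by scanning
// the data, and returns a OneHotEncoderModel.
val model = encoder.fit(df)
model.transform(df).show()
```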
Changes to the existing `OneHotEncoder`:

```diff
@@ -41,8 +41,12 @@ import org.apache.spark.sql.types.{DoubleType, NumericType, StructType}
  * The output vectors are sparse.
  *
  * @see `StringIndexer` for converting categorical values into category indices
+ * @deprecated `OneHotEncoderEstimator` will be renamed `OneHotEncoder` and this `OneHotEncoder`
+ *   will be removed in 3.0.0.
  */
 @Since("1.4.0")
+@deprecated("`OneHotEncoderEstimator` will be renamed `OneHotEncoder` and this `OneHotEncoder`" +
+  " will be removed in 3.0.0.", "2.3.0")
 class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Transformer
   with HasInputCol with HasOutputCol with DefaultParamsWritable {
 
```
```diff
@@ -78,56 +82,16 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e
   override def transformSchema(schema: StructType): StructType = {
     val inputColName = $(inputCol)
     val outputColName = $(outputCol)
-    val inputFields = schema.fields
 
     require(schema(inputColName).dataType.isInstanceOf[NumericType],
       s"Input column must be of type NumericType but got ${schema(inputColName).dataType}")
+    val inputFields = schema.fields
     require(!inputFields.exists(_.name == outputColName),
       s"Output column $outputColName already exists.")
 
-    val inputAttr = Attribute.fromStructField(schema(inputColName))
-    val outputAttrNames: Option[Array[String]] = inputAttr match {
-      case nominal: NominalAttribute =>
-        if (nominal.values.isDefined) {
-          nominal.values
-        } else if (nominal.numValues.isDefined) {
-          nominal.numValues.map(n => Array.tabulate(n)(_.toString))
-        } else {
-          None
-        }
-      case binary: BinaryAttribute =>
-        if (binary.values.isDefined) {
-          binary.values
-        } else {
-          Some(Array.tabulate(2)(_.toString))
-        }
-      case _: NumericAttribute =>
-        throw new RuntimeException(
-          s"The input column $inputColName cannot be numeric.")
-      case _ =>
-        None // optimistic about unknown attributes
-    }
-
-    val filteredOutputAttrNames = outputAttrNames.map { names =>
-      if ($(dropLast)) {
-        require(names.length > 1,
-          s"The input column $inputColName should have at least two distinct values.")
-        names.dropRight(1)
-      } else {
-        names
-      }
-    }
-
-    val outputAttrGroup = if (filteredOutputAttrNames.isDefined) {
-      val attrs: Array[Attribute] = filteredOutputAttrNames.get.map { name =>
-        BinaryAttribute.defaultAttr.withName(name)
-      }
-      new AttributeGroup($(outputCol), attrs)
-    } else {
-      new AttributeGroup($(outputCol))
-    }
-
-    val outputFields = inputFields :+ outputAttrGroup.toStructField()
+    val outputField = OneHotEncoderCommon.transformOutputColumnSchema(
+      schema(inputColName), outputColName, $(dropLast))
+    val outputFields = inputFields :+ outputField
     StructType(outputFields)
   }
 
```
```diff
@@ -136,30 +100,17 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e
     // schema transformation
     val inputColName = $(inputCol)
     val outputColName = $(outputCol)
-    val shouldDropLast = $(dropLast)
-    var outputAttrGroup = AttributeGroup.fromStructField(
+
+    val outputAttrGroupFromSchema = AttributeGroup.fromStructField(
       transformSchema(dataset.schema)(outputColName))
-    if (outputAttrGroup.size < 0) {
-      // If the number of attributes is unknown, we check the values from the input column.
-      val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).rdd.map(_.getDouble(0))
-        .treeAggregate(0.0)(
-          (m, x) => {
-            assert(x <= Int.MaxValue,
-              s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x")
-            assert(x >= 0.0 && x == x.toInt,
-              s"Values from column $inputColName must be indices, but got $x.")
-            math.max(m, x)
-          },
-          (m0, m1) => {
-            math.max(m0, m1)
-          }
-        ).toInt + 1
-      val outputAttrNames = Array.tabulate(numAttrs)(_.toString)
-      val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames
-      val outputAttrs: Array[Attribute] =
-        filtered.map(name => BinaryAttribute.defaultAttr.withName(name))
-      outputAttrGroup = new AttributeGroup(outputColName, outputAttrs)
-    }
+
+    val outputAttrGroup = if (outputAttrGroupFromSchema.size < 0) {
+      OneHotEncoderCommon.getOutputAttrGroupFromData(
+        dataset, Seq(inputColName), Seq(outputColName), $(dropLast))(0)
+    } else {
+      outputAttrGroupFromSchema
+    }
+
     val metadata = outputAttrGroup.toMetadata()
 
     // data transformation
```
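Since the hunks above deprecate the transformer-style API, here is a side-by-side migration sketch. It is illustrative only and assumes a DataFrame `df` with a numeric index column named "categoryIndex":

```scala
// Deprecated transformer-style API (to be removed in 3.0.0):
val oldEncoded = new OneHotEncoder()
  .setInputCol("categoryIndex")
  .setOutputCol("categoryVec")
  .transform(df)

// New estimator-style API: fit() first, then transform with the model.
val newEncoded = new OneHotEncoderEstimator()
  .setInputCols(Array("categoryIndex"))
  .setOutputCols(Array("categoryVec"))
  .fit(df)
  .transform(df)
```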
New file (522 lines) adding `OneHotEncoderEstimator`, `OneHotEncoderModel`, and the shared `OneHotEncoderCommon` helpers:

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import org.apache.hadoop.fs.Path

import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCols, HasOutputCols}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.{col, lit, udf}
import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType}

/** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */
private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
    with HasInputCols with HasOutputCols {

  /**
   * Param for how to handle invalid data.
   * Options are 'keep' (invalid data presented as an extra categorical feature) or
   * 'error' (throw an error).
   * Default: "error"
   * @group param
   */
  @Since("2.3.0")
  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
    "How to handle invalid data " +
    "Options are 'keep' (invalid data presented as an extra categorical feature) " +
    "or error (throw an error).",
    ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids))

  setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID)

  /**
   * Whether to drop the last category in the encoded vector (default: true)
   * @group param
   */
  @Since("2.3.0")
  final val dropLast: BooleanParam =
    new BooleanParam(this, "dropLast", "whether to drop the last category")
  setDefault(dropLast -> true)

  /** @group getParam */
  @Since("2.3.0")
  def getDropLast: Boolean = $(dropLast)

  protected def validateAndTransformSchema(
      schema: StructType, dropLast: Boolean, keepInvalid: Boolean): StructType = {
    val inputColNames = $(inputCols)
    val outputColNames = $(outputCols)
    val existingFields = schema.fields

    require(inputColNames.length == outputColNames.length,
      s"The number of input columns ${inputColNames.length} must be the same as the number of " +
        s"output columns ${outputColNames.length}.")

    // Input columns must be NumericType.
    inputColNames.foreach(SchemaUtils.checkNumericType(schema, _))

    // Prepares output columns with proper attributes by examining input columns.
    val inputFields = $(inputCols).map(schema(_))

    val outputFields = inputFields.zip(outputColNames).map { case (inputField, outputColName) =>
      OneHotEncoderCommon.transformOutputColumnSchema(
        inputField, outputColName, dropLast, keepInvalid)
    }
    outputFields.foldLeft(schema) { case (newSchema, outputField) =>
      SchemaUtils.appendColumn(newSchema, outputField)
    }
  }
}

/**
 * A one-hot encoder that maps a column of category indices to a column of binary vectors, with
 * at most a single one-value per row that indicates the input category index.
 * For example with 5 categories, an input value of 2.0 would map to an output vector of
 * `[0.0, 0.0, 1.0, 0.0]`.
 * The last category is not included by default (configurable via `dropLast`),
 * because it makes the vector entries sum up to one, and hence linearly dependent.
 * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`.
 *
 * @note This is different from scikit-learn's OneHotEncoder, which keeps all categories.
 * The output vectors are sparse.
 *
 * When `handleInvalid` is configured to 'keep', an extra "category" indicating invalid values is
 * added as last category. So when `dropLast` is true, invalid values are encoded as all-zeros
 * vector.
 *
 * @note When encoding multi-column by using `inputCols` and `outputCols` params, input/output cols
 * come in pairs, specified by the order in the arrays, and each pair is treated independently.
 *
 * @see `StringIndexer` for converting categorical values into category indices
 */
@Since("2.3.0")
class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid: String)
    extends Estimator[OneHotEncoderModel] with OneHotEncoderBase with DefaultParamsWritable {

  @Since("2.3.0")
  def this() = this(Identifiable.randomUID("oneHotEncoder"))

  /** @group setParam */
  @Since("2.3.0")
  def setInputCols(values: Array[String]): this.type = set(inputCols, values)

  /** @group setParam */
  @Since("2.3.0")
  def setOutputCols(values: Array[String]): this.type = set(outputCols, values)

  /** @group setParam */
  @Since("2.3.0")
  def setDropLast(value: Boolean): this.type = set(dropLast, value)

  /** @group setParam */
  @Since("2.3.0")
  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)

  @Since("2.3.0")
  override def transformSchema(schema: StructType): StructType = {
    val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID
    validateAndTransformSchema(schema, dropLast = $(dropLast),
      keepInvalid = keepInvalid)
  }

  @Since("2.3.0")
  override def fit(dataset: Dataset[_]): OneHotEncoderModel = {
    transformSchema(dataset.schema)

    // Compute the plain number of categories without `handleInvalid` and
    // `dropLast` taken into account.
    val transformedSchema = validateAndTransformSchema(dataset.schema, dropLast = false,
      keepInvalid = false)
    val categorySizes = new Array[Int]($(outputCols).length)

    val columnToScanIndices = $(outputCols).zipWithIndex.flatMap { case (outputColName, idx) =>
      val numOfAttrs = AttributeGroup.fromStructField(
        transformedSchema(outputColName)).size
      if (numOfAttrs < 0) {
        Some(idx)
      } else {
        categorySizes(idx) = numOfAttrs
        None
      }
    }

    // Some input columns don't have attributes or their attributes don't have necessary info.
    // We need to scan the data to get the number of values for each column.
    if (columnToScanIndices.length > 0) {
      val inputColNames = columnToScanIndices.map($(inputCols)(_))
      val outputColNames = columnToScanIndices.map($(outputCols)(_))

      // When fitting data, we want the plain number of categories without `handleInvalid` and
      // `dropLast` taken into account.
      val attrGroups = OneHotEncoderCommon.getOutputAttrGroupFromData(
        dataset, inputColNames, outputColNames, dropLast = false)
      attrGroups.zip(columnToScanIndices).foreach { case (attrGroup, idx) =>
        categorySizes(idx) = attrGroup.size
      }
    }

    val model = new OneHotEncoderModel(uid, categorySizes).setParent(this)
    copyValues(model)
  }

  @Since("2.3.0")
  override def copy(extra: ParamMap): OneHotEncoderEstimator = defaultCopy(extra)
}

@Since("2.3.0")
object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimator] {

  private[feature] val KEEP_INVALID: String = "keep"
  private[feature] val ERROR_INVALID: String = "error"
  private[feature] val supportedHandleInvalids: Array[String] = Array(KEEP_INVALID, ERROR_INVALID)

  @Since("2.3.0")
  override def load(path: String): OneHotEncoderEstimator = super.load(path)
}

@Since("2.3.0")
class OneHotEncoderModel private[ml] (
    @Since("2.3.0") override val uid: String,
    @Since("2.3.0") val categorySizes: Array[Int])
  extends Model[OneHotEncoderModel] with OneHotEncoderBase with MLWritable {

  import OneHotEncoderModel._

  // Returns the category size for a given index with `dropLast` and `handleInvalid`
  // taken into account.
  private def configedCategorySize(orgCategorySize: Int, idx: Int): Int = {
    val dropLast = getDropLast
    val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID

    if (!dropLast && keepInvalid) {
      // When `handleInvalid` is "keep", an extra category is added as last category
      // for invalid data.
      orgCategorySize + 1
    } else if (dropLast && !keepInvalid) {
      // When `dropLast` is true, the last category is removed.
      orgCategorySize - 1
    } else {
      // When `dropLast` is true and `handleInvalid` is "keep", the extra category for invalid
      // data is removed. Thus, it is the same as the plain number of categories.
      orgCategorySize
    }
  }

  private def encoder: UserDefinedFunction = {
    val oneValue = Array(1.0)
    val emptyValues = Array.empty[Double]
    val emptyIndices = Array.empty[Int]
    val dropLast = getDropLast
    val handleInvalid = getHandleInvalid
    val keepInvalid = handleInvalid == OneHotEncoderEstimator.KEEP_INVALID

    // The udf performed on input data. The first parameter is the input value. The second
    // parameter is the index of input.
    udf { (label: Double, idx: Int) =>
      val plainNumCategories = categorySizes(idx)
      val size = configedCategorySize(plainNumCategories, idx)

      if (label < 0) {
        throw new SparkException(s"Negative value: $label. Input can't be negative.")
      } else if (label == size && dropLast && !keepInvalid) {
        // When `dropLast` is true and `handleInvalid` is not "keep",
        // the last category is removed.
        Vectors.sparse(size, emptyIndices, emptyValues)
      } else if (label >= plainNumCategories && keepInvalid) {
        // When `handleInvalid` is "keep", encodes invalid data to last category (and removed
        // if `dropLast` is true)
        if (dropLast) {
          Vectors.sparse(size, emptyIndices, emptyValues)
        } else {
          Vectors.sparse(size, Array(size - 1), oneValue)
        }
      } else if (label < plainNumCategories) {
        Vectors.sparse(size, Array(label.toInt), oneValue)
      } else {
        assert(handleInvalid == OneHotEncoderEstimator.ERROR_INVALID)
        throw new SparkException(s"Unseen value: $label. To handle unseen values, " +
          s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.")
      }
    }
  }

  /** @group setParam */
  @Since("2.3.0")
  def setInputCols(values: Array[String]): this.type = set(inputCols, values)

  /** @group setParam */
  @Since("2.3.0")
  def setOutputCols(values: Array[String]): this.type = set(outputCols, values)

  /** @group setParam */
  @Since("2.3.0")
  def setDropLast(value: Boolean): this.type = set(dropLast, value)

  /** @group setParam */
  @Since("2.3.0")
  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)

  @Since("2.3.0")
  override def transformSchema(schema: StructType): StructType = {
    val inputColNames = $(inputCols)
    val outputColNames = $(outputCols)

    require(inputColNames.length == categorySizes.length,
      s"The number of input columns ${inputColNames.length} must be the same as the number of " +
        s"features ${categorySizes.length} during fitting.")

    val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID
    val transformedSchema = validateAndTransformSchema(schema, dropLast = $(dropLast),
      keepInvalid = keepInvalid)
    verifyNumOfValues(transformedSchema)
  }

  /**
   * If the metadata of input columns also specifies the number of categories, we need to
   * compare with expected category number with `handleInvalid` and `dropLast` taken into
   * account. Mismatched numbers will cause exception.
   */
  private def verifyNumOfValues(schema: StructType): StructType = {
    $(outputCols).zipWithIndex.foreach { case (outputColName, idx) =>
      val inputColName = $(inputCols)(idx)
      val attrGroup = AttributeGroup.fromStructField(schema(outputColName))

      // If the input metadata specifies number of category for output column,
      // comparing with expected category number with `handleInvalid` and
      // `dropLast` taken into account.
      if (attrGroup.attributes.nonEmpty) {
        val numCategories = configedCategorySize(categorySizes(idx), idx)
        require(attrGroup.size == numCategories, "OneHotEncoderModel expected " +
          s"$numCategories categorical values for input column ${inputColName}, " +
          s"but the input column had metadata specifying ${attrGroup.size} values.")
      }
    }
    schema
  }

  @Since("2.3.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    val transformedSchema = transformSchema(dataset.schema, logging = true)
    val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID

    val encodedColumns = (0 until $(inputCols).length).map { idx =>
      val inputColName = $(inputCols)(idx)
      val outputColName = $(outputCols)(idx)

      val outputAttrGroupFromSchema =
        AttributeGroup.fromStructField(transformedSchema(outputColName))

      val metadata = if (outputAttrGroupFromSchema.size < 0) {
        OneHotEncoderCommon.createAttrGroupForAttrNames(outputColName,
          categorySizes(idx), $(dropLast), keepInvalid).toMetadata()
      } else {
        outputAttrGroupFromSchema.toMetadata()
      }

      encoder(col(inputColName).cast(DoubleType), lit(idx))
        .as(outputColName, metadata)
    }
    dataset.withColumns($(outputCols), encodedColumns)
  }

  @Since("2.3.0")
  override def copy(extra: ParamMap): OneHotEncoderModel = {
    val copied = new OneHotEncoderModel(uid, categorySizes)
    copyValues(copied, extra).setParent(parent)
  }

  @Since("2.3.0")
  override def write: MLWriter = new OneHotEncoderModelWriter(this)
}

@Since("2.3.0")
object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] {

  private[OneHotEncoderModel]
  class OneHotEncoderModelWriter(instance: OneHotEncoderModel) extends MLWriter {

    private case class Data(categorySizes: Array[Int])

    override protected def saveImpl(path: String): Unit = {
      DefaultParamsWriter.saveMetadata(instance, path, sc)
      val data = Data(instance.categorySizes)
      val dataPath = new Path(path, "data").toString
      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
    }
  }

  private class OneHotEncoderModelReader extends MLReader[OneHotEncoderModel] {

    private val className = classOf[OneHotEncoderModel].getName

    override def load(path: String): OneHotEncoderModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
      val dataPath = new Path(path, "data").toString
      val data = sparkSession.read.parquet(dataPath)
        .select("categorySizes")
        .head()
      val categorySizes = data.getAs[Seq[Int]](0).toArray
      val model = new OneHotEncoderModel(metadata.uid, categorySizes)
      DefaultParamsReader.getAndSetParams(model, metadata)
      model
    }
  }

  @Since("2.3.0")
  override def read: MLReader[OneHotEncoderModel] = new OneHotEncoderModelReader

  @Since("2.3.0")
  override def load(path: String): OneHotEncoderModel = super.load(path)
}

/**
 * Provides some helper methods used by both `OneHotEncoder` and `OneHotEncoderEstimator`.
 */
private[feature] object OneHotEncoderCommon {

  private def genOutputAttrNames(inputCol: StructField): Option[Array[String]] = {
    val inputAttr = Attribute.fromStructField(inputCol)
    inputAttr match {
      case nominal: NominalAttribute =>
        if (nominal.values.isDefined) {
          nominal.values
        } else if (nominal.numValues.isDefined) {
          nominal.numValues.map(n => Array.tabulate(n)(_.toString))
        } else {
          None
        }
      case binary: BinaryAttribute =>
        if (binary.values.isDefined) {
          binary.values
        } else {
          Some(Array.tabulate(2)(_.toString))
        }
      case _: NumericAttribute =>
        throw new RuntimeException(
          s"The input column ${inputCol.name} cannot be continuous-value.")
      case _ =>
        None // optimistic about unknown attributes
    }
  }

  /** Creates an `AttributeGroup` filled by the `BinaryAttribute` named as required. */
  private def genOutputAttrGroup(
      outputAttrNames: Option[Array[String]],
      outputColName: String): AttributeGroup = {
    outputAttrNames.map { attrNames =>
      val attrs: Array[Attribute] = attrNames.map { name =>
        BinaryAttribute.defaultAttr.withName(name)
      }
      new AttributeGroup(outputColName, attrs)
    }.getOrElse{
      new AttributeGroup(outputColName)
    }
  }

  /**
   * Prepares the `StructField` with proper metadata for `OneHotEncoder`'s output column.
   */
  def transformOutputColumnSchema(
      inputCol: StructField,
      outputColName: String,
      dropLast: Boolean,
      keepInvalid: Boolean = false): StructField = {
    val outputAttrNames = genOutputAttrNames(inputCol)
    val filteredOutputAttrNames = outputAttrNames.map { names =>
      if (dropLast && !keepInvalid) {
        require(names.length > 1,
          s"The input column ${inputCol.name} should have at least two distinct values.")
        names.dropRight(1)
      } else if (!dropLast && keepInvalid) {
        names ++ Seq("invalidValues")
      } else {
        names
      }
    }

    genOutputAttrGroup(filteredOutputAttrNames, outputColName).toStructField()
  }

  /**
   * This method is called when we want to generate `AttributeGroup` from actual data for
   * one-hot encoder.
   */
  def getOutputAttrGroupFromData(
      dataset: Dataset[_],
      inputColNames: Seq[String],
      outputColNames: Seq[String],
      dropLast: Boolean): Seq[AttributeGroup] = {
    // The RDD approach has advantage of early-stop if any values are invalid. It seems that
    // DataFrame ops don't have equivalent functions.
    val columns = inputColNames.map { inputColName =>
      col(inputColName).cast(DoubleType)
    }
    val numOfColumns = columns.length

    val numAttrsArray = dataset.select(columns: _*).rdd.map { row =>
      (0 until numOfColumns).map(idx => row.getDouble(idx)).toArray
    }.treeAggregate(new Array[Double](numOfColumns))(
      (maxValues, curValues) => {
        (0 until numOfColumns).foreach { idx =>
          val x = curValues(idx)
          assert(x <= Int.MaxValue,
            s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x.")
          assert(x >= 0.0 && x == x.toInt,
            s"Values from column ${inputColNames(idx)} must be indices, but got $x.")
          maxValues(idx) = math.max(maxValues(idx), x)
        }
        maxValues
      },
      (m0, m1) => {
        (0 until numOfColumns).foreach { idx =>
          m0(idx) = math.max(m0(idx), m1(idx))
        }
        m0
      }
    ).map(_.toInt + 1)

    outputColNames.zip(numAttrsArray).map { case (outputColName, numAttrs) =>
      createAttrGroupForAttrNames(outputColName, numAttrs, dropLast, keepInvalid = false)
    }
  }

  /** Creates an `AttributeGroup` with the required number of `BinaryAttribute`. */
  def createAttrGroupForAttrNames(
      outputColName: String,
      numAttrs: Int,
      dropLast: Boolean,
      keepInvalid: Boolean): AttributeGroup = {
    val outputAttrNames = Array.tabulate(numAttrs)(_.toString)
    val filtered = if (dropLast && !keepInvalid) {
      outputAttrNames.dropRight(1)
    } else if (!dropLast && keepInvalid) {
      outputAttrNames ++ Seq("invalidValues")
    } else {
      outputAttrNames
    }
    genOutputAttrGroup(Some(filtered), outputColName)
  }
}
```
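The interaction of `dropLast` and `handleInvalid` in `configedCategorySize` above is the subtlest part of the patch. A standalone sketch (not part of the patch) of the resulting vector sizes:

```scala
// Vector size for a column with n plain categories, mirroring configedCategorySize.
def vectorSize(n: Int, dropLast: Boolean, keepInvalid: Boolean): Int =
  if (!dropLast && keepInvalid) n + 1      // extra last slot for invalid values
  else if (dropLast && !keepInvalid) n - 1 // last category dropped
  else n                                   // the two effects cancel out

// With n = 3 categories:
// (dropLast = true,  keep = false) -> 2   valid index 2 becomes the all-zeros vector
// (dropLast = false, keep = true)  -> 4   invalid values get index 3
// (dropLast = true,  keep = true)  -> 3   invalid values become the all-zeros vector
```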
New file (421 lines) adding the test suite `OneHotEncoderEstimatorSuite` (the `err.getMessage.contains(...)` checks are wrapped in `assert(...)` here, since a bare boolean expression is a no-op assertion):

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._

class OneHotEncoderEstimatorSuite
  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  import testImplicits._

  test("params") {
    ParamsSuite.checkParams(new OneHotEncoderEstimator)
  }

  test("OneHotEncoderEstimator dropLast = false") {
    val data = Seq(
      Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
      Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))),
      Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))),
      Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
      Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
      Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))))

    val schema = StructType(Array(
      StructField("input", DoubleType),
      StructField("expected", new VectorUDT)))

    val df = spark.createDataFrame(sc.parallelize(data), schema)

    val encoder = new OneHotEncoderEstimator()
      .setInputCols(Array("input"))
      .setOutputCols(Array("output"))
    assert(encoder.getDropLast === true)
    encoder.setDropLast(false)
    assert(encoder.getDropLast === false)

    val model = encoder.fit(df)
    val encoded = model.transform(df)
    encoded.select("output", "expected").rdd.map { r =>
      (r.getAs[Vector](0), r.getAs[Vector](1))
    }.collect().foreach { case (vec1, vec2) =>
      assert(vec1 === vec2)
    }
  }

  test("OneHotEncoderEstimator dropLast = true") {
    val data = Seq(
      Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))),
      Row(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
      Row(2.0, Vectors.sparse(2, Seq())),
      Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))),
      Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))),
      Row(2.0, Vectors.sparse(2, Seq())))

    val schema = StructType(Array(
      StructField("input", DoubleType),
      StructField("expected", new VectorUDT)))

    val df = spark.createDataFrame(sc.parallelize(data), schema)

    val encoder = new OneHotEncoderEstimator()
      .setInputCols(Array("input"))
      .setOutputCols(Array("output"))

    val model = encoder.fit(df)
    val encoded = model.transform(df)
    encoded.select("output", "expected").rdd.map { r =>
      (r.getAs[Vector](0), r.getAs[Vector](1))
    }.collect().foreach { case (vec1, vec2) =>
      assert(vec1 === vec2)
    }
  }

  test("input column with ML attribute") {
    val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
    val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("size")
      .select(col("size").as("size", attr.toMetadata()))
    val encoder = new OneHotEncoderEstimator()
      .setInputCols(Array("size"))
      .setOutputCols(Array("encoded"))
    val model = encoder.fit(df)
    val output = model.transform(df)
    val group = AttributeGroup.fromStructField(output.schema("encoded"))
    assert(group.size === 2)
    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
  }

  test("input column without ML attribute") {
    val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("index")
    val encoder = new OneHotEncoderEstimator()
      .setInputCols(Array("index"))
      .setOutputCols(Array("encoded"))
    val model = encoder.fit(df)
    val output = model.transform(df)
    val group = AttributeGroup.fromStructField(output.schema("encoded"))
    assert(group.size === 2)
    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0))
    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1))
  }

  test("read/write") {
    val encoder = new OneHotEncoderEstimator()
      .setInputCols(Array("index"))
      .setOutputCols(Array("encoded"))
    testDefaultReadWrite(encoder)
  }

  test("OneHotEncoderModel read/write") {
    val instance = new OneHotEncoderModel("myOneHotEncoderModel", Array(1, 2, 3))
    val newInstance = testDefaultReadWrite(instance)
    assert(newInstance.categorySizes === instance.categorySizes)
  }

  test("OneHotEncoderEstimator with varying types") {
    val data = Seq(
      Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
      Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))),
      Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))),
      Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
      Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
      Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))))

    val schema = StructType(Array(
      StructField("input", DoubleType),
      StructField("expected", new VectorUDT)))

    val df = spark.createDataFrame(sc.parallelize(data), schema)

    val dfWithTypes = df
      .withColumn("shortInput", df("input").cast(ShortType))
      .withColumn("longInput", df("input").cast(LongType))
      .withColumn("intInput", df("input").cast(IntegerType))
      .withColumn("floatInput", df("input").cast(FloatType))
      .withColumn("decimalInput", df("input").cast(DecimalType(10, 0)))

    val cols = Array("input", "shortInput", "longInput", "intInput",
      "floatInput", "decimalInput")
    for (col <- cols) {
      val encoder = new OneHotEncoderEstimator()
        .setInputCols(Array(col))
        .setOutputCols(Array("output"))
        .setDropLast(false)

      val model = encoder.fit(dfWithTypes)
      val encoded = model.transform(dfWithTypes)

      encoded.select("output", "expected").rdd.map { r =>
        (r.getAs[Vector](0), r.getAs[Vector](1))
      }.collect().foreach { case (vec1, vec2) =>
        assert(vec1 === vec2)
      }
    }
  }

  test("OneHotEncoderEstimator: encoding multiple columns and dropLast = false") {
    val data = Seq(
      Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0)))),
      Row(1.0, Vectors.sparse(3, Seq((1, 1.0))), 3.0, Vectors.sparse(4, Seq((3, 1.0)))),
      Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), 0.0, Vectors.sparse(4, Seq((0, 1.0)))),
      Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 1.0, Vectors.sparse(4, Seq((1, 1.0)))),
      Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 0.0, Vectors.sparse(4, Seq((0, 1.0)))),
      Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0)))))

    val schema = StructType(Array(
      StructField("input1", DoubleType),
      StructField("expected1", new VectorUDT),
      StructField("input2", DoubleType),
      StructField("expected2", new VectorUDT)))

    val df = spark.createDataFrame(sc.parallelize(data), schema)

    val encoder = new OneHotEncoderEstimator()
      .setInputCols(Array("input1", "input2"))
      .setOutputCols(Array("output1", "output2"))
    assert(encoder.getDropLast === true)
    encoder.setDropLast(false)
    assert(encoder.getDropLast === false)

    val model = encoder.fit(df)
    val encoded = model.transform(df)
    encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r =>
      (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3))
    }.collect().foreach { case (vec1, vec2, vec3, vec4) =>
      assert(vec1 === vec2)
      assert(vec3 === vec4)
    }
  }

  test("OneHotEncoderEstimator: encoding multiple columns and dropLast = true") {
    val data = Seq(
      Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 2.0, Vectors.sparse(3, Seq((2, 1.0)))),
      Row(1.0, Vectors.sparse(2, Seq((1, 1.0))), 3.0, Vectors.sparse(3, Seq())),
      Row(2.0, Vectors.sparse(2, Seq()), 0.0, Vectors.sparse(3, Seq((0, 1.0)))),
      Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 1.0, Vectors.sparse(3, Seq((1, 1.0)))),
      Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 0.0, Vectors.sparse(3, Seq((0, 1.0)))),
      Row(2.0, Vectors.sparse(2, Seq()), 2.0, Vectors.sparse(3, Seq((2, 1.0)))))

    val schema = StructType(Array(
      StructField("input1", DoubleType),
      StructField("expected1", new VectorUDT),
      StructField("input2", DoubleType),
      StructField("expected2", new VectorUDT)))

    val df = spark.createDataFrame(sc.parallelize(data), schema)

    val encoder = new OneHotEncoderEstimator()
      .setInputCols(Array("input1", "input2"))
      .setOutputCols(Array("output1", "output2"))

    val model = encoder.fit(df)
    val encoded = model.transform(df)
    encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r =>
      (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3))
    }.collect().foreach { case (vec1, vec2, vec3, vec4) =>
      assert(vec1 === vec2)
      assert(vec3 === vec4)
    }
  }

  test("Throw error on invalid values") {
    val trainingData = Seq((0, 0), (1, 1), (2, 2))
    val trainingDF = trainingData.toDF("id", "a")
    val testData = Seq((0, 0), (1, 2), (1, 3))
    val testDF = testData.toDF("id", "a")

    val encoder = new OneHotEncoderEstimator()
      .setInputCols(Array("a"))
      .setOutputCols(Array("encoded"))

    val model = encoder.fit(trainingDF)
    val err = intercept[SparkException] {
      model.transform(testDF).show
    }
    assert(err.getMessage.contains("Unseen value: 3.0. To handle unseen values"))
  }

  test("Can't transform on negative input") {
    val trainingDF = Seq((0, 0), (1, 1), (2, 2)).toDF("a", "b")
    val testDF = Seq((0, 0), (-1, 2), (1, 3)).toDF("a", "b")

    val encoder = new OneHotEncoderEstimator()
      .setInputCols(Array("a"))
      .setOutputCols(Array("encoded"))

    val model = encoder.fit(trainingDF)
    val err = intercept[SparkException] {
      model.transform(testDF).collect()
    }
    assert(err.getMessage.contains("Negative value: -1.0. Input can't be negative"))
  }

  test("Keep on invalid values: dropLast = false") {
    val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input")

    val testData = Seq(
      Row(0.0, Vectors.sparse(4, Seq((0, 1.0)))),
      Row(1.0, Vectors.sparse(4, Seq((1, 1.0)))),
      Row(3.0, Vectors.sparse(4, Seq((3, 1.0)))))

    val schema = StructType(Array(
      StructField("input", DoubleType),
      StructField("expected", new VectorUDT)))

    val testDF = spark.createDataFrame(sc.parallelize(testData), schema)

    val encoder = new OneHotEncoderEstimator()
      .setInputCols(Array("input"))
      .setOutputCols(Array("output"))
      .setHandleInvalid("keep")
      .setDropLast(false)

    val model = encoder.fit(trainingDF)
    val encoded = model.transform(testDF)
    encoded.select("output", "expected").rdd.map { r =>
      (r.getAs[Vector](0), r.getAs[Vector](1))
    }.collect().foreach { case (vec1, vec2) =>
      assert(vec1 === vec2)
    }
  }

  test("Keep on invalid values: dropLast = true") {
    val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input")

    val testData = Seq(
      Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
      Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))),
      Row(3.0, Vectors.sparse(3, Seq())))

    val schema = StructType(Array(
      StructField("input", DoubleType),
      StructField("expected", new VectorUDT)))

    val testDF = spark.createDataFrame(sc.parallelize(testData), schema)

    val encoder = new OneHotEncoderEstimator()
      .setInputCols(Array("input"))
      .setOutputCols(Array("output"))
      .setHandleInvalid("keep")
      .setDropLast(true)

    val model = encoder.fit(trainingDF)
    val encoded = model.transform(testDF)
    encoded.select("output", "expected").rdd.map { r =>
      (r.getAs[Vector](0), r.getAs[Vector](1))
    }.collect().foreach { case (vec1, vec2) =>
      assert(vec1 === vec2)
    }
  }

  test("OneHotEncoderModel changes dropLast") {
    val data = Seq(
      Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))),
      Row(1.0, Vectors.sparse(3, Seq((1, 1.0))), Vectors.sparse(2, Seq((1, 1.0)))),
      Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), Vectors.sparse(2, Seq())),
      Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))),
      Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))),
      Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), Vectors.sparse(2, Seq())))

    val schema = StructType(Array(
      StructField("input", DoubleType),
      StructField("expected1", new VectorUDT),
      StructField("expected2", new VectorUDT)))

    val df = spark.createDataFrame(sc.parallelize(data), schema)

    val encoder = new OneHotEncoderEstimator()
      .setInputCols(Array("input"))
      .setOutputCols(Array("output"))

    val model = encoder.fit(df)

    model.setDropLast(false)
    val encoded1 = model.transform(df)
    encoded1.select("output", "expected1").rdd.map { r =>
      (r.getAs[Vector](0), r.getAs[Vector](1))
    }.collect().foreach { case (vec1, vec2) =>
      assert(vec1 === vec2)
    }

    model.setDropLast(true)
    val encoded2 = model.transform(df)
    encoded2.select("output", "expected2").rdd.map { r =>
      (r.getAs[Vector](0), r.getAs[Vector](1))
    }.collect().foreach { case (vec1, vec2) =>
      assert(vec1 === vec2)
    }
  }

  test("OneHotEncoderModel changes handleInvalid") {
    val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input")

    val testData = Seq(
      Row(0.0, Vectors.sparse(4, Seq((0, 1.0)))),
      Row(1.0, Vectors.sparse(4, Seq((1, 1.0)))),
      Row(3.0, Vectors.sparse(4, Seq((3, 1.0)))))

    val schema = StructType(Array(
      StructField("input", DoubleType),
      StructField("expected", new VectorUDT)))

    val testDF = spark.createDataFrame(sc.parallelize(testData), schema)

    val encoder = new OneHotEncoderEstimator()
      .setInputCols(Array("input"))
      .setOutputCols(Array("output"))

    val model = encoder.fit(trainingDF)
    model.setHandleInvalid("error")

    val err = intercept[SparkException] {
      model.transform(testDF).collect()
    }
    assert(err.getMessage.contains("Unseen value: 3.0. To handle unseen values"))

    model.setHandleInvalid("keep")
    model.transform(testDF).collect()
  }

  test("Transforming on mismatched attributes") {
    val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
    val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("size")
      .select(col("size").as("size", attr.toMetadata()))
    val encoder = new OneHotEncoderEstimator()
      .setInputCols(Array("size"))
      .setOutputCols(Array("encoded"))
    val model = encoder.fit(df)

    val testAttr = NominalAttribute.defaultAttr.withValues("tiny", "small", "medium", "large")
    val testDF = Seq(0.0, 1.0, 2.0, 3.0).map(Tuple1.apply).toDF("size")
      .select(col("size").as("size", testAttr.toMetadata()))
    val err = intercept[Exception] {
      model.transform(testDF).collect()
    }
    assert(err.getMessage.contains("OneHotEncoderModel expected 2 categorical values"))
  }
}
```
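Both the estimator and the fitted model are persistable through the `DefaultParamsWritable`/`MLWritable` plumbing shown above (the suite exercises this via `testDefaultReadWrite`). A usage sketch with an invented path and an assumed DataFrame `df` containing a numeric column "categoryIndex":

```scala
val model = new OneHotEncoderEstimator()
  .setInputCols(Array("categoryIndex"))
  .setOutputCols(Array("categoryVec"))
  .fit(df)

// categorySizes is stored as Parquet under <path>/data, Params under <path>/metadata.
model.write.overwrite().save("/tmp/ohe-model")
val reloaded = OneHotEncoderModel.load("/tmp/ohe-model")
assert(reloaded.categorySizes.sameElements(model.categorySizes))
```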