[SPARK-23690][ML] Add handleInvalid to VectorAssembler
## What changes were proposed in this pull request?

Introduce a `handleInvalid` parameter in `VectorAssembler` that accepts the options `"keep"`, `"skip"`, and `"error"`. "error" throws an error on seeing a row containing a `null`, "skip" filters out all such rows, and "keep" adds the relevant number of NaN values to the output vector. To find out how many NaNs should be added for a null cell, "keep" looks up the column length (from ML attribute metadata, or from an example row) and throws an error when no such number can be found.

## How was this patch tested?

Unit tests were added to check the behavior of `assemble` on specific rows, and the transformer is called on `DataFrame`s of different configurations to test the corner cases.

Author: Yogesh Garg <yogesh(dot)garg()databricks(dot)com>
Author: Bago Amirbekian <bago@databricks.com>
Author: Yogesh Garg <1059168+yogeshg@users.noreply.github.com>

Closes #20829 from yogeshg/rformula_handleinvalid.
parent 28ea4e3142
commit a1351828d3
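For context, a minimal usage sketch of the new parameter (column names and data are illustrative, not from this commit; assumes a `SparkSession` named `spark` with `spark.implicits._` imported):

```scala
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.{Vector, Vectors}

// Illustrative frame: the second row has a null in the scalar column "hour".
val df = Seq[(java.lang.Double, Vector)](
  (0.0, Vectors.dense(1.0, 2.0)),
  (null, Vectors.dense(3.0, 4.0)))
  .toDF("hour", "userFeatures")

val assembler = new VectorAssembler()
  .setInputCols(Array("hour", "userFeatures"))
  .setOutputCol("features")
  .setHandleInvalid("skip") // "error" (the default) throws; "keep" emits NaNs instead

assembler.transform(df).show() // under "skip" the null row is filtered out
```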
`mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala`:

```diff
@@ -234,7 +234,7 @@ class StringIndexerModel (
     val metadata = NominalAttribute.defaultAttr
       .withName($(outputCol)).withValues(filteredLabels).toMetadata()
     // If we are skipping invalid records, filter them out.
-    val (filteredDataset, keepInvalid) = getHandleInvalid match {
+    val (filteredDataset, keepInvalid) = $(handleInvalid) match {
       case StringIndexer.SKIP_INVALID =>
         val filterer = udf { label: String =>
           labelToIndex.contains(label)
```
`mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala`:

```diff
@@ -17,14 +17,17 @@
 
 package org.apache.spark.ml.feature
 
-import scala.collection.mutable.ArrayBuilder
+import java.util.NoSuchElementException
+
+import scala.collection.mutable
+import scala.language.existentials
 
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
 import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
-import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
@@ -33,10 +36,14 @@ import org.apache.spark.sql.types._
 
 /**
  * A feature transformer that merges multiple columns into a vector column.
+ *
+ * This requires one pass over the entire dataset. In case we need to infer column lengths from the
+ * data we require an additional call to the 'first' Dataset method, see 'handleInvalid' parameter.
  */
 @Since("1.4.0")
 class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
-  extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable {
+  extends Transformer with HasInputCols with HasOutputCol with HasHandleInvalid
+    with DefaultParamsWritable {
 
   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("vecAssembler"))
@@ -49,32 +56,63 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   @Since("1.4.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
+  /** @group setParam */
+  @Since("2.4.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
+  /**
+   * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
+   * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
+   * output). Column lengths are taken from the size of ML Attribute Group, which can be set using
+   * `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
+   * from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
+   * Default: "error"
+   * @group param
+   */
+  @Since("2.4.0")
+  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
+    """Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
+      |invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
+      |output). Column lengths are taken from the size of ML Attribute Group, which can be set using
+      |`VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
+      |from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
+      |""".stripMargin.replaceAll("\n", " "),
+    ParamValidators.inArray(VectorAssembler.supportedHandleInvalids))
+
+  setDefault(handleInvalid, VectorAssembler.ERROR_INVALID)
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     // Schema transformation.
     val schema = dataset.schema
-    lazy val first = dataset.toDF.first()
-    val attrs = $(inputCols).flatMap { c =>
+    val vectorCols = $(inputCols).filter { c =>
+      schema(c).dataType match {
+        case _: VectorUDT => true
+        case _ => false
+      }
+    }
+    val vectorColsLengths = VectorAssembler.getLengths(dataset, vectorCols, $(handleInvalid))
+
+    val featureAttributesMap = $(inputCols).map { c =>
       val field = schema(c)
-      val index = schema.fieldIndex(c)
       field.dataType match {
         case DoubleType =>
-          val attr = Attribute.fromStructField(field)
-          // If the input column doesn't have ML attribute, assume numeric.
-          if (attr == UnresolvedAttribute) {
-            Some(NumericAttribute.defaultAttr.withName(c))
-          } else {
-            Some(attr.withName(c))
-          }
+          val attribute = Attribute.fromStructField(field)
+          attribute match {
+            case UnresolvedAttribute =>
+              Seq(NumericAttribute.defaultAttr.withName(c))
+            case _ =>
+              Seq(attribute.withName(c))
+          }
         case _: NumericType | BooleanType =>
           // If the input column type is a compatible scalar type, assume numeric.
-          Some(NumericAttribute.defaultAttr.withName(c))
+          Seq(NumericAttribute.defaultAttr.withName(c))
         case _: VectorUDT =>
-          val group = AttributeGroup.fromStructField(field)
-          if (group.attributes.isDefined) {
-            // If attributes are defined, copy them with updated names.
-            group.attributes.get.zipWithIndex.map { case (attr, i) =>
+          val attributeGroup = AttributeGroup.fromStructField(field)
+          if (attributeGroup.attributes.isDefined) {
+            attributeGroup.attributes.get.zipWithIndex.toSeq.map { case (attr, i) =>
               if (attr.name.isDefined) {
                 // TODO: Define a rigorous naming scheme.
                 attr.withName(c + "_" + attr.name.get)
@@ -85,18 +123,25 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
           } else {
             // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes
             // from metadata, check the first row.
-            val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
-            Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i))
+            (0 until vectorColsLengths(c)).map { i =>
+              NumericAttribute.defaultAttr.withName(c + "_" + i)
+            }
           }
         case otherType =>
           throw new SparkException(s"VectorAssembler does not support the $otherType type")
       }
     }
-    val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()
+    val featureAttributes = featureAttributesMap.flatten[Attribute].toArray
+    val lengths = featureAttributesMap.map(a => a.length).toArray
+    val metadata = new AttributeGroup($(outputCol), featureAttributes).toMetadata()
+    val (filteredDataset, keepInvalid) = $(handleInvalid) match {
+      case VectorAssembler.SKIP_INVALID => (dataset.na.drop($(inputCols)), false)
+      case VectorAssembler.KEEP_INVALID => (dataset, true)
+      case VectorAssembler.ERROR_INVALID => (dataset, false)
+    }
     // Data transformation.
     val assembleFunc = udf { r: Row =>
-      VectorAssembler.assemble(r.toSeq: _*)
+      VectorAssembler.assemble(lengths, keepInvalid)(r.toSeq: _*)
     }.asNondeterministic()
     val args = $(inputCols).map { c =>
       schema(c).dataType match {
@@ -106,7 +151,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
       }
     }
 
-    dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
+    filteredDataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
   }
 
   @Since("1.4.0")
@@ -136,34 +181,117 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
 @Since("1.6.0")
 object VectorAssembler extends DefaultParamsReadable[VectorAssembler] {
 
+  private[feature] val SKIP_INVALID: String = "skip"
+  private[feature] val ERROR_INVALID: String = "error"
+  private[feature] val KEEP_INVALID: String = "keep"
+  private[feature] val supportedHandleInvalids: Array[String] =
+    Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
+
+  /**
+   * Infers lengths of vector columns from the first row of the dataset
+   * @param dataset the dataset
+   * @param columns name of vector columns whose lengths need to be inferred
+   * @return map of column names to lengths
+   */
+  private[feature] def getVectorLengthsFromFirstRow(
+      dataset: Dataset[_],
+      columns: Seq[String]): Map[String, Int] = {
+    try {
+      val first_row = dataset.toDF().select(columns.map(col): _*).first()
+      columns.zip(first_row.toSeq).map {
+        case (c, x) => c -> x.asInstanceOf[Vector].size
+      }.toMap
+    } catch {
+      case e: NullPointerException => throw new NullPointerException(
+        s"""Encountered null value while inferring lengths from the first row. Consider using
+           |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
+          .stripMargin.replaceAll("\n", " ") + e.toString)
+      case e: NoSuchElementException => throw new NoSuchElementException(
+        s"""Encountered empty dataframe while inferring lengths from the first row. Consider using
+           |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
+          .stripMargin.replaceAll("\n", " ") + e.toString)
+    }
+  }
+
+  private[feature] def getLengths(
+      dataset: Dataset[_],
+      columns: Seq[String],
+      handleInvalid: String): Map[String, Int] = {
+    val groupSizes = columns.map { c =>
+      c -> AttributeGroup.fromStructField(dataset.schema(c)).size
+    }.toMap
+    val missingColumns = groupSizes.filter(_._2 == -1).keys.toSeq
+    val firstSizes = (missingColumns.nonEmpty, handleInvalid) match {
+      case (true, VectorAssembler.ERROR_INVALID) =>
+        getVectorLengthsFromFirstRow(dataset, missingColumns)
+      case (true, VectorAssembler.SKIP_INVALID) =>
+        getVectorLengthsFromFirstRow(dataset.na.drop(missingColumns), missingColumns)
+      case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException(
+        s"""Can not infer column lengths with handleInvalid = "keep". Consider using VectorSizeHint
+           |to add metadata for columns: ${columns.mkString("[", ", ", "]")}."""
+          .stripMargin.replaceAll("\n", " "))
+      case (_, _) => Map.empty
+    }
+    groupSizes ++ firstSizes
+  }
+
   @Since("1.6.0")
   override def load(path: String): VectorAssembler = super.load(path)
 
-  private[feature] def assemble(vv: Any*): Vector = {
-    val indices = ArrayBuilder.make[Int]
-    val values = ArrayBuilder.make[Double]
-    var cur = 0
+  /**
+   * Returns a function that has the required information to assemble each row.
+   * @param lengths an array of lengths of input columns, whose size should be equal to the number
+   *                of cells in the row (vv)
+   * @param keepInvalid indicate whether to throw an error or not on seeing a null in the rows
+   * @return a udf that can be applied on each row
+   */
+  private[feature] def assemble(lengths: Array[Int], keepInvalid: Boolean)(vv: Any*): Vector = {
+    val indices = mutable.ArrayBuilder.make[Int]
+    val values = mutable.ArrayBuilder.make[Double]
+    var featureIndex = 0
+
+    var inputColumnIndex = 0
     vv.foreach {
       case v: Double =>
-        if (v != 0.0) {
-          indices += cur
+        if (v.isNaN && !keepInvalid) {
+          throw new SparkException(
+            s"""Encountered NaN while assembling a row with handleInvalid = "error". Consider
+               |removing NaNs from dataset or using handleInvalid = "keep" or "skip"."""
+              .stripMargin)
+        } else if (v != 0.0) {
+          indices += featureIndex
           values += v
         }
-        cur += 1
+        inputColumnIndex += 1
+        featureIndex += 1
      case vec: Vector =>
         vec.foreachActive { case (i, v) =>
           if (v != 0.0) {
-            indices += cur + i
+            indices += featureIndex + i
             values += v
           }
         }
-        cur += vec.size
+        inputColumnIndex += 1
+        featureIndex += vec.size
       case null =>
-        // TODO: output Double.NaN?
-        throw new SparkException("Values to assemble cannot be null.")
+        if (keepInvalid) {
+          val length: Int = lengths(inputColumnIndex)
+          Array.range(0, length).foreach { i =>
+            indices += featureIndex + i
+            values += Double.NaN
+          }
+          inputColumnIndex += 1
+          featureIndex += length
+        } else {
+          throw new SparkException(
+            s"""Encountered null while assembling a row with handleInvalid = "keep". Consider
+               |removing nulls from dataset or using handleInvalid = "keep" or "skip"."""
+              .stripMargin)
+        }
       case o =>
         throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.")
     }
-    Vectors.sparse(cur, indices.result(), values.result()).compressed
+    Vectors.sparse(featureIndex, indices.result(), values.result()).compressed
   }
 }
```
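The parameter doc above recommends `VectorSizeHint` when column lengths cannot be inferred under `handleInvalid = "keep"`. A hedged sketch of that pairing (column names and size are illustrative, not part of this commit; `VectorSizeHint` is the existing `org.apache.spark.ml.feature.VectorSizeHint` transformer):

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{VectorAssembler, VectorSizeHint}

// Declare the size of the vector column up front, so "keep" never needs to
// inspect a first row (which fails when that column is null in every row).
val sizeHint = new VectorSizeHint()
  .setInputCol("userFeatures")
  .setSize(3)

val assembler = new VectorAssembler()
  .setInputCols(Array("hour", "userFeatures"))
  .setOutputCol("features")
  .setHandleInvalid("keep") // a null "userFeatures" cell becomes three NaNs

val pipeline = new Pipeline().setStages(Array(sizeHint, assembler))
```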
`mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala`:

```diff
@@ -18,12 +18,12 @@
 package org.apache.spark.ml.feature
 
 import org.apache.spark.{SparkException, SparkFunSuite}
-import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute, NumericAttribute}
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.sql.functions.{col, udf}
 
 class VectorAssemblerSuite
@@ -31,30 +31,49 @@ class VectorAssemblerSuite
 
   import testImplicits._
 
+  @transient var dfWithNullsAndNaNs: Dataset[_] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val sv = Vectors.sparse(2, Array(1), Array(3.0))
+    dfWithNullsAndNaNs = Seq[(Long, Long, java.lang.Double, Vector, String, Vector, Long, String)](
+      (1, 2, 0.0, Vectors.dense(1.0, 2.0), "a", sv, 7L, null),
+      (2, 1, 0.0, null, "a", sv, 6L, null),
+      (3, 3, null, Vectors.dense(1.0, 2.0), "a", sv, 8L, null),
+      (4, 4, null, null, "a", sv, 9L, null),
+      (5, 5, java.lang.Double.NaN, Vectors.dense(1.0, 2.0), "a", sv, 7L, null),
+      (6, 6, java.lang.Double.NaN, null, "a", sv, 8L, null))
+      .toDF("id1", "id2", "x", "y", "name", "z", "n", "nulls")
+  }
+
   test("params") {
     ParamsSuite.checkParams(new VectorAssembler)
   }
 
   test("assemble") {
     import org.apache.spark.ml.feature.VectorAssembler.assemble
-    assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty))
-    assert(assemble(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0)))
+    assert(assemble(Array(1), keepInvalid = true)(0.0)
+      === Vectors.sparse(1, Array.empty, Array.empty))
+    assert(assemble(Array(1, 1), keepInvalid = true)(0.0, 1.0)
+      === Vectors.sparse(2, Array(1), Array(1.0)))
     val dv = Vectors.dense(2.0, 0.0)
-    assert(assemble(0.0, dv, 1.0) === Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0)))
+    assert(assemble(Array(1, 2, 1), keepInvalid = true)(0.0, dv, 1.0) ===
+      Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0)))
     val sv = Vectors.sparse(2, Array(0, 1), Array(3.0, 4.0))
-    assert(assemble(0.0, dv, 1.0, sv) ===
+    assert(assemble(Array(1, 2, 1, 2), keepInvalid = true)(0.0, dv, 1.0, sv) ===
       Vectors.sparse(6, Array(1, 3, 4, 5), Array(2.0, 1.0, 3.0, 4.0)))
-    for (v <- Seq(1, "a", null)) {
-      intercept[SparkException](assemble(v))
-      intercept[SparkException](assemble(1.0, v))
+    for (v <- Seq(1, "a")) {
+      intercept[SparkException](assemble(Array(1), keepInvalid = true)(v))
+      intercept[SparkException](assemble(Array(1, 1), keepInvalid = true)(1.0, v))
     }
   }
 
   test("assemble should compress vectors") {
     import org.apache.spark.ml.feature.VectorAssembler.assemble
-    val v1 = assemble(0.0, 0.0, 0.0, Vectors.dense(4.0))
+    val v1 = assemble(Array(1, 1, 1, 1), keepInvalid = true)(0.0, 0.0, 0.0, Vectors.dense(4.0))
     assert(v1.isInstanceOf[SparseVector])
-    val v2 = assemble(1.0, 2.0, 3.0, Vectors.sparse(1, Array(0), Array(4.0)))
+    val sv = Vectors.sparse(1, Array(0), Array(4.0))
+    val v2 = assemble(Array(1, 1, 1, 1), keepInvalid = true)(1.0, 2.0, 3.0, sv)
     assert(v2.isInstanceOf[DenseVector])
   }
 
@@ -147,4 +166,94 @@ class VectorAssemblerSuite
       .filter(vectorUDF($"features") > 1)
       .count() == 1)
   }
+
+  test("assemble should keep nulls when keepInvalid is true") {
+    import org.apache.spark.ml.feature.VectorAssembler.assemble
+    assert(assemble(Array(1, 1), keepInvalid = true)(1.0, null) === Vectors.dense(1.0, Double.NaN))
+    assert(assemble(Array(1, 2), keepInvalid = true)(1.0, null)
+      === Vectors.dense(1.0, Double.NaN, Double.NaN))
+    assert(assemble(Array(1), keepInvalid = true)(null) === Vectors.dense(Double.NaN))
+    assert(assemble(Array(2), keepInvalid = true)(null) === Vectors.dense(Double.NaN, Double.NaN))
+  }
+
+  test("assemble should throw errors when keepInvalid is false") {
+    import org.apache.spark.ml.feature.VectorAssembler.assemble
+    intercept[SparkException](assemble(Array(1, 1), keepInvalid = false)(1.0, null))
+    intercept[SparkException](assemble(Array(1, 2), keepInvalid = false)(1.0, null))
+    intercept[SparkException](assemble(Array(1), keepInvalid = false)(null))
+    intercept[SparkException](assemble(Array(2), keepInvalid = false)(null))
+  }
+
+  test("get lengths functions") {
+    import org.apache.spark.ml.feature.VectorAssembler._
+    val df = dfWithNullsAndNaNs
+    assert(getVectorLengthsFromFirstRow(df, Seq("y")) === Map("y" -> 2))
+    assert(intercept[NullPointerException](getVectorLengthsFromFirstRow(df.sort("id2"), Seq("y")))
+      .getMessage.contains("VectorSizeHint"))
+    assert(intercept[NoSuchElementException](getVectorLengthsFromFirstRow(df.filter("id1 > 6"),
+      Seq("y"))).getMessage.contains("VectorSizeHint"))
+
+    assert(getLengths(df.sort("id2"), Seq("y"), SKIP_INVALID).exists(_ == "y" -> 2))
+    assert(intercept[NullPointerException](getLengths(df.sort("id2"), Seq("y"), ERROR_INVALID))
+      .getMessage.contains("VectorSizeHint"))
+    assert(intercept[RuntimeException](getLengths(df.sort("id2"), Seq("y"), KEEP_INVALID))
+      .getMessage.contains("VectorSizeHint"))
+  }
+
+  test("Handle Invalid should behave properly") {
+    val assembler = new VectorAssembler()
+      .setInputCols(Array("x", "y", "z", "n"))
+      .setOutputCol("features")
+
+    def runWithMetadata(mode: String, additional_filter: String = "true"): Dataset[_] = {
+      val attributeY = new AttributeGroup("y", 2)
+      val attributeZ = new AttributeGroup(
+        "z",
+        Array[Attribute](
+          NumericAttribute.defaultAttr.withName("foo"),
+          NumericAttribute.defaultAttr.withName("bar")))
+      val dfWithMetadata = dfWithNullsAndNaNs.withColumn("y", col("y"), attributeY.toMetadata())
+        .withColumn("z", col("z"), attributeZ.toMetadata()).filter(additional_filter)
+      val output = assembler.setHandleInvalid(mode).transform(dfWithMetadata)
+      output.collect()
+      output
+    }
+
+    def runWithFirstRow(mode: String): Dataset[_] = {
+      val output = assembler.setHandleInvalid(mode).transform(dfWithNullsAndNaNs)
+      output.collect()
+      output
+    }
+
+    def runWithAllNullVectors(mode: String): Dataset[_] = {
+      val output = assembler.setHandleInvalid(mode)
+        .transform(dfWithNullsAndNaNs.filter("0 == id1 % 2"))
+      output.collect()
+      output
+    }
+
+    // behavior when vector size hint is given
+    assert(runWithMetadata("keep").count() == 6, "should keep all rows")
+    assert(runWithMetadata("skip").count() == 1, "should skip rows with nulls")
+    // should throw error with nulls
+    intercept[SparkException](runWithMetadata("error"))
+    // should throw error with NaNs
+    intercept[SparkException](runWithMetadata("error", additional_filter = "id1 > 4"))
+
+    // behavior when first row has information
+    assert(intercept[RuntimeException](runWithFirstRow("keep").count())
+      .getMessage.contains("VectorSizeHint"), "should suggest to use metadata")
+    assert(runWithFirstRow("skip").count() == 1, "should infer size and skip rows with nulls")
+    intercept[SparkException](runWithFirstRow("error"))
+
+    // behavior when vector column is all null
+    assert(intercept[RuntimeException](runWithAllNullVectors("skip"))
+      .getMessage.contains("VectorSizeHint"), "should suggest to use metadata")
+    assert(intercept[NullPointerException](runWithAllNullVectors("error"))
+      .getMessage.contains("VectorSizeHint"), "should suggest to use metadata")
+
+    // behavior when scalar column is all null
+    assert(runWithMetadata("keep", additional_filter = "id1 > 2").count() == 4)
+  }
+
 }
```
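To summarize the end-to-end semantics these tests exercise, a hedged sketch (toy data and illustrative names, not from this commit; assumes `spark.implicits._` is imported):

```scala
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.functions.col

// Toy frame: "y" is a 2-element vector column that is null in the second row.
val df = Seq[(Double, Vector)](
  (1.0, Vectors.dense(2.0, 3.0)),
  (4.0, null))
  .toDF("x", "y")
  // Attach size metadata so "keep" does not have to infer lengths from data.
  .withColumn("y", col("y"), new AttributeGroup("y", 2).toMetadata())

val out = new VectorAssembler()
  .setInputCols(Array("x", "y"))
  .setOutputCol("features")
  .setHandleInvalid("keep")
  .transform(df)

// Expected, per the tests above:
//   row 1 -> [1.0, 2.0, 3.0]
//   row 2 -> [4.0, NaN, NaN]   (the null vector becomes size-many NaNs)
// With "skip" the second row is dropped; with "error" evaluating the
// result throws a SparkException.
```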