[SPARK-23690][ML] Add handleInvalid to VectorAssembler

## What changes were proposed in this pull request?

Introduce a `handleInvalid` parameter on `VectorAssembler` that accepts the options `"keep"`, `"skip"`, and `"error"`. `"error"` throws an exception on seeing a row containing a `null`, `"skip"` filters out all such rows, and `"keep"` outputs the relevant number of `NaN` values in place of the invalid entries. For `"keep"`, the number of `NaN`s to emit for a vector column is taken from its ML attribute group metadata (for example, set with `VectorSizeHint`), or inferred from an example row where that is safe; an error is thrown when no such length can be determined.
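For illustration, a minimal usage sketch (the column names and the `df` `DataFrame` are hypothetical; `VectorSizeHint` supplies the vector-column length that `"keep"` needs):

```scala
import org.apache.spark.ml.feature.{VectorAssembler, VectorSizeHint}

// Hypothetical columns: "hour" and "clicks" are doubles, "userFeatures" is a vector column.
// VectorSizeHint attaches the vector length as metadata, which "keep" uses to decide how many
// NaNs to emit when the vector is null.
val sizeHint = new VectorSizeHint()
  .setInputCol("userFeatures")
  .setSize(3)

val assembler = new VectorAssembler()
  .setInputCols(Array("hour", "clicks", "userFeatures"))
  .setOutputCol("features")
  .setHandleInvalid("keep")  // or "skip" / "error" (the default)

val assembled = assembler.transform(sizeHint.transform(df))  // df is a hypothetical input DataFrame
```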

## How was this patch tested?

Unit tests are added to check the behavior of `assemble` on specific rows, and the transformer is applied to `DataFrame`s with different configurations to exercise corner cases.

Author: Yogesh Garg <yogesh(dot)garg()databricks(dot)com>
Author: Bago Amirbekian <bago@databricks.com>
Author: Yogesh Garg <1059168+yogeshg@users.noreply.github.com>

Closes #20829 from yogeshg/rformula_handleinvalid.
Authored by Yogesh Garg on 2018-04-02 16:41:26 -07:00, committed by Joseph K. Bradley
commit a1351828d3
parent 28ea4e3142
3 changed files with 284 additions and 47 deletions

mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala

@@ -234,7 +234,7 @@ class StringIndexerModel (
     val metadata = NominalAttribute.defaultAttr
       .withName($(outputCol)).withValues(filteredLabels).toMetadata()
     // If we are skipping invalid records, filter them out.
-    val (filteredDataset, keepInvalid) = getHandleInvalid match {
+    val (filteredDataset, keepInvalid) = $(handleInvalid) match {
       case StringIndexer.SKIP_INVALID =>
         val filterer = udf { label: String =>
           labelToIndex.contains(label)

mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala

@@ -17,14 +17,17 @@
 package org.apache.spark.ml.feature

-import scala.collection.mutable.ArrayBuilder
+import java.util.NoSuchElementException
+
+import scala.collection.mutable
+import scala.language.existentials

 import org.apache.spark.SparkException
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
 import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
-import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
@@ -33,10 +36,14 @@ import org.apache.spark.sql.types._

 /**
  * A feature transformer that merges multiple columns into a vector column.
+ *
+ * This requires one pass over the entire dataset. In case we need to infer column lengths from the
+ * data we require an additional call to the 'first' Dataset method, see 'handleInvalid' parameter.
  */
 @Since("1.4.0")
 class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
-  extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable {
+  extends Transformer with HasInputCols with HasOutputCol with HasHandleInvalid
+    with DefaultParamsWritable {

   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("vecAssembler"))
@@ -49,32 +56,63 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   @Since("1.4.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)

+  /** @group setParam */
+  @Since("2.4.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
+  /**
+   * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
+   * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
+   * output). Column lengths are taken from the size of ML Attribute Group, which can be set using
+   * `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
+   * from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
+   * Default: "error"
+   * @group param
+   */
+  @Since("2.4.0")
+  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
+    """Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
+      |invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
+      |output). Column lengths are taken from the size of ML Attribute Group, which can be set using
+      |`VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
+      |from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
+      |""".stripMargin.replaceAll("\n", " "),
+    ParamValidators.inArray(VectorAssembler.supportedHandleInvalids))
+
+  setDefault(handleInvalid, VectorAssembler.ERROR_INVALID)
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     // Schema transformation.
     val schema = dataset.schema
-    lazy val first = dataset.toDF.first()
-    val attrs = $(inputCols).flatMap { c =>
+    val vectorCols = $(inputCols).filter { c =>
+      schema(c).dataType match {
+        case _: VectorUDT => true
+        case _ => false
+      }
+    }
+    val vectorColsLengths = VectorAssembler.getLengths(dataset, vectorCols, $(handleInvalid))
+
+    val featureAttributesMap = $(inputCols).map { c =>
       val field = schema(c)
-      val index = schema.fieldIndex(c)
       field.dataType match {
         case DoubleType =>
-          val attr = Attribute.fromStructField(field)
-          // If the input column doesn't have ML attribute, assume numeric.
-          if (attr == UnresolvedAttribute) {
-            Some(NumericAttribute.defaultAttr.withName(c))
-          } else {
-            Some(attr.withName(c))
+          val attribute = Attribute.fromStructField(field)
+          attribute match {
+            case UnresolvedAttribute =>
+              Seq(NumericAttribute.defaultAttr.withName(c))
+            case _ =>
+              Seq(attribute.withName(c))
           }
         case _: NumericType | BooleanType =>
           // If the input column type is a compatible scalar type, assume numeric.
-          Some(NumericAttribute.defaultAttr.withName(c))
+          Seq(NumericAttribute.defaultAttr.withName(c))
         case _: VectorUDT =>
-          val group = AttributeGroup.fromStructField(field)
-          if (group.attributes.isDefined) {
-            // If attributes are defined, copy them with updated names.
-            group.attributes.get.zipWithIndex.map { case (attr, i) =>
+          val attributeGroup = AttributeGroup.fromStructField(field)
+          if (attributeGroup.attributes.isDefined) {
+            attributeGroup.attributes.get.zipWithIndex.toSeq.map { case (attr, i) =>
               if (attr.name.isDefined) {
                 // TODO: Define a rigorous naming scheme.
                 attr.withName(c + "_" + attr.name.get)
@@ -85,18 +123,25 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
           } else {
             // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes
             // from metadata, check the first row.
-            val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
-            Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i))
+            (0 until vectorColsLengths(c)).map { i =>
+              NumericAttribute.defaultAttr.withName(c + "_" + i)
+            }
           }
         case otherType =>
           throw new SparkException(s"VectorAssembler does not support the $otherType type")
       }
     }
-    val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()
+    val featureAttributes = featureAttributesMap.flatten[Attribute].toArray
+    val lengths = featureAttributesMap.map(a => a.length).toArray
+    val metadata = new AttributeGroup($(outputCol), featureAttributes).toMetadata()
+    val (filteredDataset, keepInvalid) = $(handleInvalid) match {
+      case VectorAssembler.SKIP_INVALID => (dataset.na.drop($(inputCols)), false)
+      case VectorAssembler.KEEP_INVALID => (dataset, true)
+      case VectorAssembler.ERROR_INVALID => (dataset, false)
+    }

     // Data transformation.
     val assembleFunc = udf { r: Row =>
-      VectorAssembler.assemble(r.toSeq: _*)
+      VectorAssembler.assemble(lengths, keepInvalid)(r.toSeq: _*)
     }.asNondeterministic()
     val args = $(inputCols).map { c =>
       schema(c).dataType match {
@@ -106,7 +151,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
       }
     }

-    dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
+    filteredDataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
   }

   @Since("1.4.0")
@@ -136,34 +181,117 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
 @Since("1.6.0")
 object VectorAssembler extends DefaultParamsReadable[VectorAssembler] {

+  private[feature] val SKIP_INVALID: String = "skip"
+  private[feature] val ERROR_INVALID: String = "error"
+  private[feature] val KEEP_INVALID: String = "keep"
+  private[feature] val supportedHandleInvalids: Array[String] =
+    Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
+
+  /**
+   * Infers lengths of vector columns from the first row of the dataset
+   * @param dataset the dataset
+   * @param columns name of vector columns whose lengths need to be inferred
+   * @return map of column names to lengths
+   */
+  private[feature] def getVectorLengthsFromFirstRow(
+      dataset: Dataset[_],
+      columns: Seq[String]): Map[String, Int] = {
+    try {
+      val first_row = dataset.toDF().select(columns.map(col): _*).first()
+      columns.zip(first_row.toSeq).map {
+        case (c, x) => c -> x.asInstanceOf[Vector].size
+      }.toMap
+    } catch {
+      case e: NullPointerException => throw new NullPointerException(
+        s"""Encountered null value while inferring lengths from the first row. Consider using
+           |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
+          .stripMargin.replaceAll("\n", " ") + e.toString)
+      case e: NoSuchElementException => throw new NoSuchElementException(
+        s"""Encountered empty dataframe while inferring lengths from the first row. Consider using
+           |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
+          .stripMargin.replaceAll("\n", " ") + e.toString)
+    }
+  }
+
+  private[feature] def getLengths(
+      dataset: Dataset[_],
+      columns: Seq[String],
+      handleInvalid: String): Map[String, Int] = {
+    val groupSizes = columns.map { c =>
+      c -> AttributeGroup.fromStructField(dataset.schema(c)).size
+    }.toMap
+    val missingColumns = groupSizes.filter(_._2 == -1).keys.toSeq
+    val firstSizes = (missingColumns.nonEmpty, handleInvalid) match {
+      case (true, VectorAssembler.ERROR_INVALID) =>
+        getVectorLengthsFromFirstRow(dataset, missingColumns)
+      case (true, VectorAssembler.SKIP_INVALID) =>
+        getVectorLengthsFromFirstRow(dataset.na.drop(missingColumns), missingColumns)
+      case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException(
+        s"""Can not infer column lengths with handleInvalid = "keep". Consider using VectorSizeHint
+           |to add metadata for columns: ${columns.mkString("[", ", ", "]")}."""
+          .stripMargin.replaceAll("\n", " "))
+      case (_, _) => Map.empty
+    }
+    groupSizes ++ firstSizes
+  }
+
   @Since("1.6.0")
   override def load(path: String): VectorAssembler = super.load(path)

-  private[feature] def assemble(vv: Any*): Vector = {
-    val indices = ArrayBuilder.make[Int]
-    val values = ArrayBuilder.make[Double]
-    var cur = 0
+  /**
+   * Returns a function that has the required information to assemble each row.
+   * @param lengths an array of lengths of input columns, whose size should be equal to the number
+   *                of cells in the row (vv)
+   * @param keepInvalid indicate whether to throw an error or not on seeing a null in the rows
+   * @return a udf that can be applied on each row
+   */
+  private[feature] def assemble(lengths: Array[Int], keepInvalid: Boolean)(vv: Any*): Vector = {
+    val indices = mutable.ArrayBuilder.make[Int]
+    val values = mutable.ArrayBuilder.make[Double]
+    var featureIndex = 0
+    var inputColumnIndex = 0
     vv.foreach {
       case v: Double =>
-        if (v != 0.0) {
-          indices += cur
+        if (v.isNaN && !keepInvalid) {
+          throw new SparkException(
+            s"""Encountered NaN while assembling a row with handleInvalid = "error". Consider
+               |removing NaNs from dataset or using handleInvalid = "keep" or "skip"."""
+              .stripMargin)
+        } else if (v != 0.0) {
+          indices += featureIndex
           values += v
         }
-        cur += 1
+        inputColumnIndex += 1
+        featureIndex += 1
       case vec: Vector =>
         vec.foreachActive { case (i, v) =>
           if (v != 0.0) {
-            indices += cur + i
+            indices += featureIndex + i
             values += v
           }
         }
-        cur += vec.size
+        inputColumnIndex += 1
+        featureIndex += vec.size
       case null =>
-        // TODO: output Double.NaN?
-        throw new SparkException("Values to assemble cannot be null.")
+        if (keepInvalid) {
+          val length: Int = lengths(inputColumnIndex)
+          Array.range(0, length).foreach { i =>
+            indices += featureIndex + i
+            values += Double.NaN
+          }
+          inputColumnIndex += 1
+          featureIndex += length
+        } else {
+          throw new SparkException(
+            s"""Encountered null while assembling a row with handleInvalid = "keep". Consider
+               |removing nulls from dataset or using handleInvalid = "keep" or "skip"."""
+              .stripMargin)
+        }
       case o =>
         throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.")
     }
-    Vectors.sparse(cur, indices.result(), values.result()).compressed
+    Vectors.sparse(featureIndex, indices.result(), values.result()).compressed
   }
 }

mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala

@@ -18,12 +18,12 @@
 package org.apache.spark.ml.feature

 import org.apache.spark.{SparkException, SparkFunSuite}
-import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute, NumericAttribute}
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.sql.functions.{col, udf}

 class VectorAssemblerSuite
@@ -31,30 +31,49 @@ class VectorAssemblerSuite

   import testImplicits._

+  @transient var dfWithNullsAndNaNs: Dataset[_] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val sv = Vectors.sparse(2, Array(1), Array(3.0))
+    dfWithNullsAndNaNs = Seq[(Long, Long, java.lang.Double, Vector, String, Vector, Long, String)](
+      (1, 2, 0.0, Vectors.dense(1.0, 2.0), "a", sv, 7L, null),
+      (2, 1, 0.0, null, "a", sv, 6L, null),
+      (3, 3, null, Vectors.dense(1.0, 2.0), "a", sv, 8L, null),
+      (4, 4, null, null, "a", sv, 9L, null),
+      (5, 5, java.lang.Double.NaN, Vectors.dense(1.0, 2.0), "a", sv, 7L, null),
+      (6, 6, java.lang.Double.NaN, null, "a", sv, 8L, null))
+      .toDF("id1", "id2", "x", "y", "name", "z", "n", "nulls")
+  }
+
   test("params") {
     ParamsSuite.checkParams(new VectorAssembler)
   }

   test("assemble") {
     import org.apache.spark.ml.feature.VectorAssembler.assemble
-    assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty))
-    assert(assemble(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0)))
+    assert(assemble(Array(1), keepInvalid = true)(0.0)
+      === Vectors.sparse(1, Array.empty, Array.empty))
+    assert(assemble(Array(1, 1), keepInvalid = true)(0.0, 1.0)
+      === Vectors.sparse(2, Array(1), Array(1.0)))
     val dv = Vectors.dense(2.0, 0.0)
-    assert(assemble(0.0, dv, 1.0) === Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0)))
+    assert(assemble(Array(1, 2, 1), keepInvalid = true)(0.0, dv, 1.0) ===
+      Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0)))
     val sv = Vectors.sparse(2, Array(0, 1), Array(3.0, 4.0))
-    assert(assemble(0.0, dv, 1.0, sv) ===
+    assert(assemble(Array(1, 2, 1, 2), keepInvalid = true)(0.0, dv, 1.0, sv) ===
       Vectors.sparse(6, Array(1, 3, 4, 5), Array(2.0, 1.0, 3.0, 4.0)))
-    for (v <- Seq(1, "a", null)) {
-      intercept[SparkException](assemble(v))
-      intercept[SparkException](assemble(1.0, v))
+    for (v <- Seq(1, "a")) {
+      intercept[SparkException](assemble(Array(1), keepInvalid = true)(v))
+      intercept[SparkException](assemble(Array(1, 1), keepInvalid = true)(1.0, v))
     }
   }

   test("assemble should compress vectors") {
     import org.apache.spark.ml.feature.VectorAssembler.assemble
-    val v1 = assemble(0.0, 0.0, 0.0, Vectors.dense(4.0))
+    val v1 = assemble(Array(1, 1, 1, 1), keepInvalid = true)(0.0, 0.0, 0.0, Vectors.dense(4.0))
     assert(v1.isInstanceOf[SparseVector])
-    val v2 = assemble(1.0, 2.0, 3.0, Vectors.sparse(1, Array(0), Array(4.0)))
+    val sv = Vectors.sparse(1, Array(0), Array(4.0))
+    val v2 = assemble(Array(1, 1, 1, 1), keepInvalid = true)(1.0, 2.0, 3.0, sv)
     assert(v2.isInstanceOf[DenseVector])
   }
@@ -147,4 +166,94 @@ class VectorAssemblerSuite
       .filter(vectorUDF($"features") > 1)
       .count() == 1)
   }
+
+  test("assemble should keep nulls when keepInvalid is true") {
+    import org.apache.spark.ml.feature.VectorAssembler.assemble
+    assert(assemble(Array(1, 1), keepInvalid = true)(1.0, null) === Vectors.dense(1.0, Double.NaN))
+    assert(assemble(Array(1, 2), keepInvalid = true)(1.0, null)
+      === Vectors.dense(1.0, Double.NaN, Double.NaN))
+    assert(assemble(Array(1), keepInvalid = true)(null) === Vectors.dense(Double.NaN))
+    assert(assemble(Array(2), keepInvalid = true)(null) === Vectors.dense(Double.NaN, Double.NaN))
+  }
+
+  test("assemble should throw errors when keepInvalid is false") {
+    import org.apache.spark.ml.feature.VectorAssembler.assemble
+    intercept[SparkException](assemble(Array(1, 1), keepInvalid = false)(1.0, null))
+    intercept[SparkException](assemble(Array(1, 2), keepInvalid = false)(1.0, null))
+    intercept[SparkException](assemble(Array(1), keepInvalid = false)(null))
+    intercept[SparkException](assemble(Array(2), keepInvalid = false)(null))
+  }
+
+  test("get lengths functions") {
+    import org.apache.spark.ml.feature.VectorAssembler._
+    val df = dfWithNullsAndNaNs
+    assert(getVectorLengthsFromFirstRow(df, Seq("y")) === Map("y" -> 2))
+    assert(intercept[NullPointerException](getVectorLengthsFromFirstRow(df.sort("id2"), Seq("y")))
+      .getMessage.contains("VectorSizeHint"))
+    assert(intercept[NoSuchElementException](getVectorLengthsFromFirstRow(df.filter("id1 > 6"),
+      Seq("y"))).getMessage.contains("VectorSizeHint"))
+
+    assert(getLengths(df.sort("id2"), Seq("y"), SKIP_INVALID).exists(_ == "y" -> 2))
+    assert(intercept[NullPointerException](getLengths(df.sort("id2"), Seq("y"), ERROR_INVALID))
+      .getMessage.contains("VectorSizeHint"))
+    assert(intercept[RuntimeException](getLengths(df.sort("id2"), Seq("y"), KEEP_INVALID))
+      .getMessage.contains("VectorSizeHint"))
+  }
+
+  test("Handle Invalid should behave properly") {
+    val assembler = new VectorAssembler()
+      .setInputCols(Array("x", "y", "z", "n"))
+      .setOutputCol("features")
+
+    def runWithMetadata(mode: String, additional_filter: String = "true"): Dataset[_] = {
+      val attributeY = new AttributeGroup("y", 2)
+      val attributeZ = new AttributeGroup(
+        "z",
+        Array[Attribute](
+          NumericAttribute.defaultAttr.withName("foo"),
+          NumericAttribute.defaultAttr.withName("bar")))
+      val dfWithMetadata = dfWithNullsAndNaNs.withColumn("y", col("y"), attributeY.toMetadata())
+        .withColumn("z", col("z"), attributeZ.toMetadata()).filter(additional_filter)
+      val output = assembler.setHandleInvalid(mode).transform(dfWithMetadata)
+      output.collect()
+      output
+    }
+
+    def runWithFirstRow(mode: String): Dataset[_] = {
+      val output = assembler.setHandleInvalid(mode).transform(dfWithNullsAndNaNs)
+      output.collect()
+      output
+    }
+
+    def runWithAllNullVectors(mode: String): Dataset[_] = {
+      val output = assembler.setHandleInvalid(mode)
+        .transform(dfWithNullsAndNaNs.filter("0 == id1 % 2"))
+      output.collect()
+      output
+    }
+
+    // behavior when vector size hint is given
+    assert(runWithMetadata("keep").count() == 6, "should keep all rows")
+    assert(runWithMetadata("skip").count() == 1, "should skip rows with nulls")
+    // should throw error with nulls
+    intercept[SparkException](runWithMetadata("error"))
+    // should throw error with NaNs
+    intercept[SparkException](runWithMetadata("error", additional_filter = "id1 > 4"))
+
+    // behavior when first row has information
+    assert(intercept[RuntimeException](runWithFirstRow("keep").count())
+      .getMessage.contains("VectorSizeHint"), "should suggest to use metadata")
+    assert(runWithFirstRow("skip").count() == 1, "should infer size and skip rows with nulls")
+    intercept[SparkException](runWithFirstRow("error"))
+
+    // behavior when vector column is all null
+    assert(intercept[RuntimeException](runWithAllNullVectors("skip"))
+      .getMessage.contains("VectorSizeHint"), "should suggest to use metadata")
+    assert(intercept[NullPointerException](runWithAllNullVectors("error"))
+      .getMessage.contains("VectorSizeHint"), "should suggest to use metadata")
+
+    // behavior when scalar column is all null
+    assert(runWithMetadata("keep", additional_filter = "id1 > 2").count() == 4)
+  }
 }