[SPARK-23870][ML] Forward RFormula handleInvalid Param to VectorAssembler to handle invalid values in non-string columns

## What changes were proposed in this pull request?

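The `handleInvalid` Param of RFormula is now forwarded to the VectorAssembler that RFormula uses internally, so that invalid values in non-string columns are handled according to the same setting.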

## How was this patch tested?

Added a test and ran all tests for RFormula and VectorAssembler.

Author: Yogesh Garg <yogesh(dot)garg()databricks(dot)com>

Closes #20970 from yogeshg/spark_23562.
This commit is contained in:
Yogesh Garg 2018-04-05 19:55:42 -07:00 committed by Joseph K. Bradley
parent 4807d381bb
commit f2ac087956
2 changed files with 24 additions and 0 deletions

View file

@ -278,6 +278,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
encoderStages += new VectorAssembler(uid)
.setInputCols(encodedTerms.toArray)
.setOutputCol($(featuresCol))
.setHandleInvalid($(handleInvalid))
encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap)
encoderStages += new ColumnPruner(tempColumns.toSet)

View file

@ -17,6 +17,7 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkException
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
@ -592,4 +593,26 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
assert(features.toArray === a +: b.toArray)
}
}
test("SPARK-23562 RFormula handleInvalid should handle invalid values in non-string columns.") {
  // Left side of the join: two fully populated rows.
  val left = Seq((1001L, "a"), (1002L, "b")).toDF("id1", "c1")

  // Right side: two populated rows plus one row whose columns are both null.
  // Boxed java.lang.Long is used so the numeric column can hold a null.
  val right = Seq[(java.lang.Long, String)](
    (20001L, "x"),
    (20002L, "y"),
    (null, null)
  ).toDF("id2", "c2")

  // 2 x 3 = 6 rows, 2 of which carry a null in the non-string column id2.
  val dataset = left.crossJoin(right)

  // Fits an RFormula with the given handleInvalid mode and returns the transformed
  // (features, label) columns of the same dataset.
  def transformWith(mode: String): DataFrame = {
    val model = new RFormula()
      .setFormula("c1 ~ id2")
      .setHandleInvalid(mode)
      .fit(dataset)
    model.transform(dataset).select("features", "label")
  }

  // "error": materializing the result must fail on the null feature values.
  val thrown = intercept[SparkException] {
    transformWith("error").collect()
  }
  assert(thrown.getMessage.contains("Encountered null while assembling a row"))

  // "skip": only the 4 rows without nulls are expected to remain.
  assert(transformWith("skip").count() == 4)

  // "keep": all 6 rows are expected to remain.
  assert(transformWith("keep").count() == 6)
}
}