[SPARK-23870][ML] Forward RFormula handleInvalid Param to VectorAssembler to handle invalid values in non-string columns
## What changes were proposed in this pull request? `handleInvalid` Param was forwarded to the VectorAssembler used by RFormula. ## How was this patch tested? added a test and ran all tests for RFormula and VectorAssembler Author: Yogesh Garg <yogesh(dot)garg()databricks(dot)com> Closes #20970 from yogeshg/spark_23562.
This commit is contained in:
parent
4807d381bb
commit
f2ac087956
|
@ -278,6 +278,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
|
|||
encoderStages += new VectorAssembler(uid)
|
||||
.setInputCols(encodedTerms.toArray)
|
||||
.setOutputCol($(featuresCol))
|
||||
.setHandleInvalid($(handleInvalid))
|
||||
encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap)
|
||||
encoderStages += new ColumnPruner(tempColumns.toSet)
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
package org.apache.spark.ml.feature
|
||||
|
||||
import org.apache.spark.SparkException
|
||||
import org.apache.spark.ml.attribute._
|
||||
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.ml.param.ParamsSuite
|
||||
|
@ -592,4 +593,26 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
|
|||
assert(features.toArray === a +: b.toArray)
|
||||
}
|
||||
}
|
||||
|
||||
test("SPARK-23562 RFormula handleInvalid should handle invalid values in non-string columns.") {
|
||||
val d1 = Seq(
|
||||
(1001L, "a"),
|
||||
(1002L, "b")).toDF("id1", "c1")
|
||||
val d2 = Seq[(java.lang.Long, String)](
|
||||
(20001L, "x"),
|
||||
(20002L, "y"),
|
||||
(null, null)).toDF("id2", "c2")
|
||||
val dataset = d1.crossJoin(d2)
|
||||
|
||||
def get_output(mode: String): DataFrame = {
|
||||
val formula = new RFormula().setFormula("c1 ~ id2").setHandleInvalid(mode)
|
||||
formula.fit(dataset).transform(dataset).select("features", "label")
|
||||
}
|
||||
|
||||
assert(intercept[SparkException](get_output("error").collect())
|
||||
.getMessage.contains("Encountered null while assembling a row"))
|
||||
assert(get_output("skip").count() == 4)
|
||||
assert(get_output("keep").count() == 6)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue