diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 3070012266..7bc5e56aae 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -233,7 +233,7 @@ object VectorAssembler extends DefaultParamsReadable[VectorAssembler] { getVectorLengthsFromFirstRow(dataset.na.drop(missingColumns), missingColumns) case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException( s"""Can not infer column lengths with handleInvalid = "keep". Consider using VectorSizeHint - |to add metadata for columns: ${columns.mkString("[", ", ", "]")}.""" + |to add metadata for columns: ${missingColumns.mkString("[", ", ", "]")}.""" .stripMargin.replaceAll("\n", " ")) case (_, _) => Map.empty } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index a4d388fd32..4957f6f1f4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -261,4 +261,15 @@ class VectorAssemblerSuite val output = vectorAssembler.transform(dfWithNullsAndNaNs) assert(output.select("a").limit(1).collect().head == Row(Vectors.sparse(0, Seq.empty))) } + + test("SPARK-31671: should give explicit error message when can not infer column lengths") { + val df = Seq( + (Vectors.dense(1.0), Vectors.dense(2.0)) + ).toDF("n1", "n2") + val hintedDf = new VectorSizeHint().setInputCol("n1").setSize(1).transform(df) + val assembler = new VectorAssembler() + .setInputCols(Array("n1", "n2")).setOutputCol("features") + assert(!intercept[RuntimeException](assembler.setHandleInvalid("keep").transform(hintedDf)) + .getMessage.contains("n1"), "should only show no vector size columns' name") + } }