diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index 2b0862c60f..c4daf64dfc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -75,30 +75,40 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
     val schema = dataset.schema
     val inputType = schema($(inputCol)).dataType
     val td = $(threshold)
-
-    val binarizerDouble = udf { in: Double => if (in > td) 1.0 else 0.0 }
-    val binarizerVector = udf { (data: Vector) =>
-      val indices = ArrayBuilder.make[Int]
-      val values = ArrayBuilder.make[Double]
-
-      data.foreachActive { (index, value) =>
-        if (value > td) {
-          indices += index
-          values += 1.0
-        }
-      }
-
-      Vectors.sparse(data.size, indices.result(), values.result()).compressed
-    }
-
     val metadata = outputSchema($(outputCol)).metadata
 
-    inputType match {
+    val binarizerUDF = inputType match {
       case DoubleType =>
-        dataset.select(col("*"), binarizerDouble(col($(inputCol))).as($(outputCol), metadata))
-      case _: VectorUDT =>
-        dataset.select(col("*"), binarizerVector(col($(inputCol))).as($(outputCol), metadata))
+        udf { in: Double => if (in > td) 1.0 else 0.0 }
+
+      case _: VectorUDT if td >= 0 =>
+        udf { vector: Vector =>
+          val indices = ArrayBuilder.make[Int]
+          val values = ArrayBuilder.make[Double]
+          vector.foreachActive { (index, value) =>
+            if (value > td) {
+              indices += index
+              values += 1.0
+            }
+          }
+          Vectors.sparse(vector.size, indices.result(), values.result()).compressed
+        }
+
+      case _: VectorUDT if td < 0 =>
+        this.logWarning(s"Binarization operations on sparse dataset with negative threshold " +
+          s"$td will build a dense output, so take care when applying to sparse input.")
+        udf { vector: Vector =>
+          val values = Array.fill(vector.size)(1.0)
+          vector.foreachActive { (index, value) =>
+            if (value <= td) {
+              values(index) = 0.0
+            }
+          }
+          Vectors.dense(values).compressed
+        }
     }
+
+    dataset.withColumn($(outputCol), binarizerUDF(col($(inputCol))), metadata)
   }
 
   @Since("1.4.0")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
index 05d4a6ee2d..91bec50fb9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -101,6 +101,20 @@ class BinarizerSuite extends MLTest with DefaultReadWriteTest {
     }
   }
 
+  test("Binarizer should support sparse vector with negative threshold") {
+    val data = Seq(
+      (Vectors.sparse(3, Array(1), Array(0.5)), Vectors.dense(Array(1.0, 1.0, 1.0))),
+      (Vectors.dense(Array(0.0, 0.5, 0.0)), Vectors.dense(Array(1.0, 1.0, 1.0))))
+    val df = data.toDF("feature", "expected")
+    val binarizer = new Binarizer()
+      .setInputCol("feature")
+      .setOutputCol("binarized_feature")
+      .setThreshold(-0.5)
+    binarizer.transform(df).select("binarized_feature", "expected").collect().foreach {
+      case Row(x: Vector, y: Vector) =>
+        assert(x == y, "The feature value is not correct after binarization.")
+    }
+  }
   test("read/write") {
     val t = new Binarizer()
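
For context, a minimal usage sketch of the behavior this patch enables (not part of the patch itself). It assumes an existing `SparkSession` named `spark`; with a negative threshold, the implicit zeros of a sparse vector also exceed the threshold, so the new `td < 0` branch produces a dense all-ones result instead of dropping those positions.

```scala
import org.apache.spark.ml.feature.Binarizer
import org.apache.spark.ml.linalg.Vectors

// Assumes an existing SparkSession named `spark`.
import spark.implicits._

// A sparse vector: only index 1 is explicitly stored; indices 0 and 2 are implicit zeros.
val df = Seq(Tuple1(Vectors.sparse(3, Array(1), Array(0.5)))).toDF("feature")

val binarizer = new Binarizer()
  .setInputCol("feature")
  .setOutputCol("binarized_feature")
  .setThreshold(-0.5) // negative threshold: even implicit zeros exceed it

// The td < 0 branch fills a dense array of 1.0 and only resets entries <= threshold,
// so the output row is [1.0, 1.0, 1.0].
binarizer.transform(df).show(false)
```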