[SPARK-29808][ML][PYTHON] StopWordsRemover should support multi-cols

### What changes were proposed in this pull request?
Add multi-column support to StopWordsRemover.

### Why are the changes needed?
As a basic Transformer, StopWordsRemover should support multiple input/output columns.
The `stopWords` Param is applied across all input columns.

### Does this PR introduce any user-facing change?
Yes. Two new setters are added:
- `StopWordsRemover.setInputCols`
- `StopWordsRemover.setOutputCols`
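
For illustration, a minimal Scala sketch of the new multi-column API (not part of this patch; it assumes an active `SparkSession` named `spark`, e.g. in spark-shell):

```scala
import org.apache.spark.ml.feature.StopWordsRemover

import spark.implicits._

val df = Seq(
  (Seq("the", "quick", "fox"), Seq("a", "lazy", "dog"))
).toDF("raw1", "raw2")

// One transformer now cleans both columns in a single pass; the same
// stopWords list (default: English) is applied to every input column.
val remover = new StopWordsRemover()
  .setInputCols(Array("raw1", "raw2"))
  .setOutputCols(Array("clean1", "clean2"))

// Expected: clean1 = [quick, fox], clean2 = [lazy, dog]
remover.transform(df).show(truncate = false)
```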

### How was this patch tested?
Unit tests

Closes #26480 from huaxingao/spark-29808.

Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
3 changed files with 217 additions and 17 deletions

#### mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala

```diff
@@ -22,15 +22,19 @@ import java.util.Locale
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions.{col, udf}
-import org.apache.spark.sql.types.{ArrayType, StringType, StructType}
+import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
 
 /**
  * A feature transformer that filters out stop words from input.
  *
+ * Since 3.0.0, `StopWordsRemover` can filter out multiple columns at once by setting the
+ * `inputCols` parameter. Note that when both the `inputCol` and `inputCols` parameters are set,
+ * an Exception will be thrown.
+ *
  * @note null values from input array are preserved unless adding null to stopWords
  * explicitly.
  *
@@ -38,7 +42,8 @@ import org.apache.spark.sql.types.{ArrayType, StringType, StructType}
  */
 @Since("1.5.0")
 class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String)
-  extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable {
+  extends Transformer with HasInputCol with HasOutputCol with HasInputCols with HasOutputCols
+    with DefaultParamsWritable {
 
   @Since("1.5.0")
   def this() = this(Identifiable.randomUID("stopWords"))
@@ -51,6 +56,14 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
   @Since("1.5.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
+  /** @group setParam */
+  @Since("3.0.0")
+  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
+
+  /** @group setParam */
+  @Since("3.0.0")
+  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
+
   /**
    * The words to be filtered out.
    * Default: English stop words
@@ -121,6 +134,15 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
     }
   }
 
+  /** Returns the input and output column names corresponding in pair. */
+  private[feature] def getInOutCols(): (Array[String], Array[String]) = {
+    if (isSet(inputCol)) {
+      (Array($(inputCol)), Array($(outputCol)))
+    } else {
+      ($(inputCols), $(outputCols))
+    }
+  }
+
   setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"),
     caseSensitive -> false, locale -> getDefaultOrUS.toString)
 
@@ -142,16 +164,38 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
         terms.filter(s => !lowerStopWords.contains(toLower(s)))
       }
     }
-    val metadata = outputSchema($(outputCol)).metadata
-    dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
+
+    val (inputColNames, outputColNames) = getInOutCols()
+    val ouputCols = inputColNames.map { inputColName =>
+      t(col(inputColName))
+    }
+    val ouputMetadata = outputColNames.map(outputSchema(_).metadata)
+    dataset.withColumns(outputColNames, ouputCols, ouputMetadata)
   }
 
   @Since("1.5.0")
   override def transformSchema(schema: StructType): StructType = {
-    val inputType = schema($(inputCol)).dataType
-    require(inputType.sameType(ArrayType(StringType)), "Input type must be " +
-      s"${ArrayType(StringType).catalogString} but got ${inputType.catalogString}.")
-    SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable)
+    ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol),
+      Seq(outputCols))
+
+    if (isSet(inputCols)) {
+      require(getInputCols.length == getOutputCols.length,
+        s"StopWordsRemover $this has mismatched Params " +
+          s"for multi-column transform. Params ($inputCols, $outputCols) should have " +
+          "equal lengths, but they have different lengths: " +
+          s"(${getInputCols.length}, ${getOutputCols.length}).")
+    }
+
+    val (inputColNames, outputColNames) = getInOutCols()
+    val newCols = inputColNames.zip(outputColNames).map { case (inputColName, outputColName) =>
+      require(!schema.fieldNames.contains(outputColName),
+        s"Output Column $outputColName already exists.")
+      val inputType = schema(inputColName).dataType
+      require(inputType.sameType(ArrayType(StringType)), "Input type must be " +
+        s"${ArrayType(StringType).catalogString} but got ${inputType.catalogString}.")
+      StructField(outputColName, inputType, schema(inputColName).nullable)
+    }
+    StructType(schema.fields ++ newCols)
   }
 
   @Since("1.5.0")
```

#### mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala

```diff
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
 
 import java.util.Locale
 
+import org.apache.spark.ml.Pipeline
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
 import org.apache.spark.sql.{DataFrame, Row}
 
@@ -181,12 +182,19 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest {
   }
 
   test("read/write") {
-    val t = new StopWordsRemover()
+    val t1 = new StopWordsRemover()
       .setInputCol("myInputCol")
       .setOutputCol("myOutputCol")
       .setStopWords(Array("the", "a"))
       .setCaseSensitive(true)
-    testDefaultReadWrite(t)
+    testDefaultReadWrite(t1)
+
+    val t2 = new StopWordsRemover()
+      .setInputCols(Array("input1", "input2", "input3"))
+      .setOutputCols(Array("result1", "result2", "result3"))
+      .setStopWords(Array("the", "a"))
+      .setCaseSensitive(true)
+    testDefaultReadWrite(t2)
   }
 
   test("StopWordsRemover output column already exists") {
@@ -199,7 +207,7 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest {
     testTransformerByInterceptingException[(Array[String], Array[String])](
       dataSet,
       remover,
-      s"requirement failed: Column $outputCol already exists.",
+      s"requirement failed: Output Column $outputCol already exists.",
       "expected")
   }
 
@@ -217,4 +225,123 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest {
       Locale.setDefault(oldDefault)
     }
   }
+
+  test("Multiple Columns: StopWordsRemover default") {
+    val remover = new StopWordsRemover()
+      .setInputCols(Array("raw1", "raw2"))
+      .setOutputCols(Array("filtered1", "filtered2"))
+
+    val df = Seq(
+      (Seq("test", "test"), Seq("test1", "test2"), Seq("test", "test"), Seq("test1", "test2")),
+      (Seq("a", "b", "c", "d"), Seq("a", "b"), Seq("b", "c", "d"), Seq("b")),
+      (Seq("a", "the", "an"), Seq("the", "an"), Seq(), Seq()),
+      (Seq("A", "The", "AN"), Seq("A", "The"), Seq(), Seq()),
+      (Seq(null), Seq(null), Seq(null), Seq(null)),
+      (Seq(), Seq(), Seq(), Seq())
+    ).toDF("raw1", "raw2", "expected1", "expected2")
+
+    remover.transform(df)
+      .select("filtered1", "expected1", "filtered2", "expected2")
+      .collect().foreach {
+        case Row(r1: Seq[String], e1: Seq[String], r2: Seq[String], e2: Seq[String]) =>
+          assert(r1 === e1,
+            s"The result value is not correct after bucketing. Expected $e1 but found $r1")
+          assert(r2 === e2,
+            s"The result value is not correct after bucketing. Expected $e2 but found $r2")
+      }
+  }
+
+  test("Multiple Columns: StopWordsRemover with particular stop words list") {
+    val stopWords = Array("test", "a", "an", "the")
+    val remover = new StopWordsRemover()
+      .setInputCols(Array("raw1", "raw2"))
+      .setOutputCols(Array("filtered1", "filtered2"))
+      .setStopWords(stopWords)
+
+    val df = Seq(
+      (Seq("test", "test"), Seq("test1", "test2"), Seq(), Seq("test1", "test2")),
+      (Seq("a", "b", "c", "d"), Seq("a", "b"), Seq("b", "c", "d"), Seq("b")),
+      (Seq("a", "the", "an"), Seq("a", "the", "test1"), Seq(), Seq("test1")),
+      (Seq("A", "The", "AN"), Seq("A", "The", "AN"), Seq(), Seq()),
+      (Seq(null), Seq(null), Seq(null), Seq(null)),
+      (Seq(), Seq(), Seq(), Seq())
+    ).toDF("raw1", "raw2", "expected1", "expected2")
+
+    remover.transform(df)
+      .select("filtered1", "expected1", "filtered2", "expected2")
+      .collect().foreach {
+        case Row(r1: Seq[String], e1: Seq[String], r2: Seq[String], e2: Seq[String]) =>
+          assert(r1 === e1,
+            s"The result value is not correct after bucketing. Expected $e1 but found $r1")
+          assert(r2 === e2,
+            s"The result value is not correct after bucketing. Expected $e2 but found $r2")
+      }
+  }
+
+  test("Compare single/multiple column(s) StopWordsRemover in pipeline") {
+    val df = Seq(
+      (Seq("test", "test"), Seq("test1", "test2")),
+      (Seq("a", "b", "c", "d"), Seq("a", "b")),
+      (Seq("a", "the", "an"), Seq("a", "the", "test1")),
+      (Seq("A", "The", "AN"), Seq("A", "The", "AN")),
+      (Seq(null), Seq(null)),
+      (Seq(), Seq())
+    ).toDF("input1", "input2")
+
+    val multiColsRemover = new StopWordsRemover()
+      .setInputCols(Array("input1", "input2"))
+      .setOutputCols(Array("output1", "output2"))
+
+    val plForMultiCols = new Pipeline()
+      .setStages(Array(multiColsRemover))
+      .fit(df)
+
+    val removerForCol1 = new StopWordsRemover()
+      .setInputCol("input1")
+      .setOutputCol("output1")
+    val removerForCol2 = new StopWordsRemover()
+      .setInputCol("input2")
+      .setOutputCol("output2")
+
+    val plForSingleCol = new Pipeline()
+      .setStages(Array(removerForCol1, removerForCol2))
+      .fit(df)
+
+    val resultForSingleCol = plForSingleCol.transform(df)
+      .select("output1", "output2")
+      .collect()
+    val resultForMultiCols = plForMultiCols.transform(df)
+      .select("output1", "output2")
+      .collect()
+
+    resultForSingleCol.zip(resultForMultiCols).foreach {
+      case (rowForSingle, rowForMultiCols) =>
+        assert(rowForSingle === rowForMultiCols)
+    }
+  }
+
+  test("Multiple Columns: Mismatched sizes of inputCols/outputCols") {
+    val remover = new StopWordsRemover()
+      .setInputCols(Array("input1"))
+      .setOutputCols(Array("result1", "result2"))
+    val df = Seq(
+      (Seq("A"), Seq("A")),
+      (Seq("The", "the"), Seq("The"))
+    ).toDF("input1", "input2")
+    intercept[IllegalArgumentException] {
+      remover.transform(df).count()
+    }
+  }
+
+  test("Multiple Columns: Set both of inputCol/inputCols") {
+    val remover = new StopWordsRemover()
+      .setInputCols(Array("input1", "input2"))
+      .setOutputCols(Array("result1", "result2"))
+      .setInputCol("input1")
+    val df = Seq(
+      (Seq("A"), Seq("A")),
+      (Seq("The", "the"), Seq("The"))
+    ).toDF("input1", "input2")
+    intercept[IllegalArgumentException] {
+      remover.transform(df).count()
+    }
+  }
 }
```

#### python/pyspark/ml/feature.py

```diff
@@ -3774,9 +3774,13 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
         return self._set(outputCol=value)
 
 
-class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
+class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols,
+                       JavaMLReadable, JavaMLWritable):
     """
     A feature transformer that filters out stop words from input.
+    Since 3.0.0, :py:class:`StopWordsRemover` can filter out multiple columns at once by setting
+    the :py:attr:`inputCols` parameter. Note that when both the :py:attr:`inputCol` and
+    :py:attr:`inputCols` parameters are set, an Exception will be thrown.
 
     .. note:: null values from input array are preserved unless adding null to stopWords explicitly.
 
@@ -3795,6 +3799,17 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
     True
     >>> loadedRemover.getCaseSensitive() == remover.getCaseSensitive()
     True
+    >>> df2 = spark.createDataFrame([(["a", "b", "c"], ["a", "b"])], ["text1", "text2"])
+    >>> remover2 = StopWordsRemover(stopWords=["b"])
+    >>> remover2.setInputCols(["text1", "text2"]).setOutputCols(["words1", "words2"])
+    StopWordsRemover...
+    >>> remover2.transform(df2).show()
+    +---------+------+------+------+
+    |    text1| text2|words1|words2|
+    +---------+------+------+------+
+    |[a, b, c]|[a, b]|[a, c]|   [a]|
+    +---------+------+------+------+
+    ...
 
     .. versionadded:: 1.6.0
     """
@@ -3808,10 +3823,10 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
     @keyword_only
     def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
-                 locale=None):
+                 locale=None, inputCols=None, outputCols=None):
         """
         __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \
-                 locale=None)
+                 locale=None, inputCols=None, outputCols=None)
         """
         super(StopWordsRemover, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
                                             self.uid)
@@ -3824,10 +3839,10 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
     @keyword_only
     @since("1.6.0")
     def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
-                  locale=None):
+                  locale=None, inputCols=None, outputCols=None):
         """
         setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \
-                  locale=None)
+                  locale=None, inputCols=None, outputCols=None)
         Sets params for this StopWordRemover.
         """
         kwargs = self._input_kwargs
@@ -3887,6 +3902,20 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
         """
         return self._set(outputCol=value)
 
+    @since("3.0.0")
+    def setInputCols(self, value):
+        """
+        Sets the value of :py:attr:`inputCols`.
+        """
+        return self._set(inputCols=value)
+
+    @since("3.0.0")
+    def setOutputCols(self, value):
+        """
+        Sets the value of :py:attr:`outputCols`.
+        """
+        return self._set(outputCols=value)
+
     @staticmethod
     @since("2.0.0")
     def loadDefaultStopWords(language):
```