[SPARK-29808][ML][PYTHON] StopWordsRemover should support multi-cols
### What changes were proposed in this pull request?
Add multi-cols support in StopWordsRemover.

### Why are the changes needed?
As a basic Transformer, StopWordsRemover should support multiple columns. The `stopWords` param is applied across all columns.

### Does this PR introduce any user-facing change?
Yes: `StopWordsRemover.setInputCols` and `StopWordsRemover.setOutputCols` are added.

### How was this patch tested?
Unit tests.

Closes #26480 from huaxingao/spark-29808.

Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
This commit: 1f4075d29e (parent: 8c2bf64743)
|
@ -22,15 +22,19 @@ import java.util.Locale
|
|||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.ml.Transformer
|
||||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
|
||||
import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
|
||||
import org.apache.spark.ml.util._
|
||||
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||
import org.apache.spark.sql.functions.{col, udf}
|
||||
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}
|
||||
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
|
||||
|
||||
/**
|
||||
* A feature transformer that filters out stop words from input.
|
||||
*
|
||||
* Since 3.0.0, `StopWordsRemover` can filter out multiple columns at once by setting the
|
||||
* `inputCols` parameter. Note that when both the `inputCol` and `inputCols` parameters are set,
|
||||
* an Exception will be thrown.
|
||||
*
|
||||
* @note null values from input array are preserved unless adding null to stopWords
|
||||
* explicitly.
|
||||
*
|
||||
|
@ -38,7 +42,8 @@ import org.apache.spark.sql.types.{ArrayType, StringType, StructType}
|
|||
*/
|
||||
@Since("1.5.0")
|
||||
class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String)
|
||||
extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable {
|
||||
extends Transformer with HasInputCol with HasOutputCol with HasInputCols with HasOutputCols
|
||||
with DefaultParamsWritable {
|
||||
|
||||
@Since("1.5.0")
|
||||
def this() = this(Identifiable.randomUID("stopWords"))
|
||||
|
@ -51,6 +56,14 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
|
|||
@Since("1.5.0")
|
||||
def setOutputCol(value: String): this.type = set(outputCol, value)
|
||||
|
||||
/** @group setParam */
|
||||
@Since("3.0.0")
|
||||
def setInputCols(value: Array[String]): this.type = set(inputCols, value)
|
||||
|
||||
/** @group setParam */
|
||||
@Since("3.0.0")
|
||||
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
|
||||
|
||||
/**
|
||||
* The words to be filtered out.
|
||||
* Default: English stop words
|
||||
|
@ -121,6 +134,15 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
|
|||
}
|
||||
}
|
||||
|
||||
/** Returns the input and output column names corresponding in pair. */
|
||||
private[feature] def getInOutCols(): (Array[String], Array[String]) = {
|
||||
if (isSet(inputCol)) {
|
||||
(Array($(inputCol)), Array($(outputCol)))
|
||||
} else {
|
||||
($(inputCols), $(outputCols))
|
||||
}
|
||||
}
|
||||
|
||||
setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"),
|
||||
caseSensitive -> false, locale -> getDefaultOrUS.toString)
|
||||
|
||||
|
@ -142,16 +164,38 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
|
|||
terms.filter(s => !lowerStopWords.contains(toLower(s)))
|
||||
}
|
||||
}
|
||||
val metadata = outputSchema($(outputCol)).metadata
|
||||
dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
|
||||
|
||||
val (inputColNames, outputColNames) = getInOutCols()
|
||||
val ouputCols = inputColNames.map { inputColName =>
|
||||
t(col(inputColName))
|
||||
}
|
||||
val ouputMetadata = outputColNames.map(outputSchema(_).metadata)
|
||||
dataset.withColumns(outputColNames, ouputCols, ouputMetadata)
|
||||
}
|
||||
|
||||
@Since("1.5.0")
|
||||
override def transformSchema(schema: StructType): StructType = {
|
||||
val inputType = schema($(inputCol)).dataType
|
||||
ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol),
|
||||
Seq(outputCols))
|
||||
|
||||
if (isSet(inputCols)) {
|
||||
require(getInputCols.length == getOutputCols.length,
|
||||
s"StopWordsRemover $this has mismatched Params " +
|
||||
s"for multi-column transform. Params ($inputCols, $outputCols) should have " +
|
||||
"equal lengths, but they have different lengths: " +
|
||||
s"(${getInputCols.length}, ${getOutputCols.length}).")
|
||||
}
|
||||
|
||||
val (inputColNames, outputColNames) = getInOutCols()
|
||||
val newCols = inputColNames.zip(outputColNames).map { case (inputColName, outputColName) =>
|
||||
require(!schema.fieldNames.contains(outputColName),
|
||||
s"Output Column $outputColName already exists.")
|
||||
val inputType = schema(inputColName).dataType
|
||||
require(inputType.sameType(ArrayType(StringType)), "Input type must be " +
|
||||
s"${ArrayType(StringType).catalogString} but got ${inputType.catalogString}.")
|
||||
SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable)
|
||||
StructField(outputColName, inputType, schema(inputColName).nullable)
|
||||
}
|
||||
StructType(schema.fields ++ newCols)
|
||||
}
|
||||
|
||||
@Since("1.5.0")
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
|
|||
|
||||
import java.util.Locale
|
||||
|
||||
import org.apache.spark.ml.Pipeline
|
||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
|
||||
import org.apache.spark.sql.{DataFrame, Row}
|
||||
|
||||
|
@ -181,12 +182,19 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest {
|
|||
}
|
||||
|
||||
test("read/write") {
|
||||
val t = new StopWordsRemover()
|
||||
val t1 = new StopWordsRemover()
|
||||
.setInputCol("myInputCol")
|
||||
.setOutputCol("myOutputCol")
|
||||
.setStopWords(Array("the", "a"))
|
||||
.setCaseSensitive(true)
|
||||
testDefaultReadWrite(t)
|
||||
testDefaultReadWrite(t1)
|
||||
|
||||
val t2 = new StopWordsRemover()
|
||||
.setInputCols(Array("input1", "input2", "input3"))
|
||||
.setOutputCols(Array("result1", "result2", "result3"))
|
||||
.setStopWords(Array("the", "a"))
|
||||
.setCaseSensitive(true)
|
||||
testDefaultReadWrite(t2)
|
||||
}
|
||||
|
||||
test("StopWordsRemover output column already exists") {
|
||||
|
@ -199,7 +207,7 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest {
|
|||
testTransformerByInterceptingException[(Array[String], Array[String])](
|
||||
dataSet,
|
||||
remover,
|
||||
s"requirement failed: Column $outputCol already exists.",
|
||||
s"requirement failed: Output Column $outputCol already exists.",
|
||||
"expected")
|
||||
}
|
||||
|
||||
|
@ -217,4 +225,123 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest {
|
|||
Locale.setDefault(oldDefault)
|
||||
}
|
||||
}
|
||||
|
||||
test("Multiple Columns: StopWordsRemover default") {
|
||||
val remover = new StopWordsRemover()
|
||||
.setInputCols(Array("raw1", "raw2"))
|
||||
.setOutputCols(Array("filtered1", "filtered2"))
|
||||
val df = Seq(
|
||||
(Seq("test", "test"), Seq("test1", "test2"), Seq("test", "test"), Seq("test1", "test2")),
|
||||
(Seq("a", "b", "c", "d"), Seq("a", "b"), Seq("b", "c", "d"), Seq("b")),
|
||||
(Seq("a", "the", "an"), Seq("the", "an"), Seq(), Seq()),
|
||||
(Seq("A", "The", "AN"), Seq("A", "The"), Seq(), Seq()),
|
||||
(Seq(null), Seq(null), Seq(null), Seq(null)),
|
||||
(Seq(), Seq(), Seq(), Seq())
|
||||
).toDF("raw1", "raw2", "expected1", "expected2")
|
||||
|
||||
remover.transform(df)
|
||||
.select("filtered1", "expected1", "filtered2", "expected2")
|
||||
.collect().foreach {
|
||||
case Row(r1: Seq[String], e1: Seq[String], r2: Seq[String], e2: Seq[String]) =>
|
||||
assert(r1 === e1,
|
||||
s"The result value is not correct after bucketing. Expected $e1 but found $r1")
|
||||
assert(r2 === e2,
|
||||
s"The result value is not correct after bucketing. Expected $e2 but found $r2")
|
||||
}
|
||||
}
|
||||
|
||||
test("Multiple Columns: StopWordsRemover with particular stop words list") {
|
||||
val stopWords = Array("test", "a", "an", "the")
|
||||
val remover = new StopWordsRemover()
|
||||
.setInputCols(Array("raw1", "raw2"))
|
||||
.setOutputCols(Array("filtered1", "filtered2"))
|
||||
.setStopWords(stopWords)
|
||||
val df = Seq(
|
||||
(Seq("test", "test"), Seq("test1", "test2"), Seq(), Seq("test1", "test2")),
|
||||
(Seq("a", "b", "c", "d"), Seq("a", "b"), Seq("b", "c", "d"), Seq("b")),
|
||||
(Seq("a", "the", "an"), Seq("a", "the", "test1"), Seq(), Seq("test1")),
|
||||
(Seq("A", "The", "AN"), Seq("A", "The", "AN"), Seq(), Seq()),
|
||||
(Seq(null), Seq(null), Seq(null), Seq(null)),
|
||||
(Seq(), Seq(), Seq(), Seq())
|
||||
).toDF("raw1", "raw2", "expected1", "expected2")
|
||||
|
||||
remover.transform(df)
|
||||
.select("filtered1", "expected1", "filtered2", "expected2")
|
||||
.collect().foreach {
|
||||
case Row(r1: Seq[String], e1: Seq[String], r2: Seq[String], e2: Seq[String]) =>
|
||||
assert(r1 === e1,
|
||||
s"The result value is not correct after bucketing. Expected $e1 but found $r1")
|
||||
assert(r2 === e2,
|
||||
s"The result value is not correct after bucketing. Expected $e2 but found $r2")
|
||||
}
|
||||
}
|
||||
|
||||
test("Compare single/multiple column(s) StopWordsRemover in pipeline") {
|
||||
val df = Seq(
|
||||
(Seq("test", "test"), Seq("test1", "test2")),
|
||||
(Seq("a", "b", "c", "d"), Seq("a", "b")),
|
||||
(Seq("a", "the", "an"), Seq("a", "the", "test1")),
|
||||
(Seq("A", "The", "AN"), Seq("A", "The", "AN")),
|
||||
(Seq(null), Seq(null)),
|
||||
(Seq(), Seq())
|
||||
).toDF("input1", "input2")
|
||||
|
||||
val multiColsRemover = new StopWordsRemover()
|
||||
.setInputCols(Array("input1", "input2"))
|
||||
.setOutputCols(Array("output1", "output2"))
|
||||
|
||||
val plForMultiCols = new Pipeline()
|
||||
.setStages(Array(multiColsRemover))
|
||||
.fit(df)
|
||||
|
||||
val removerForCol1 = new StopWordsRemover()
|
||||
.setInputCol("input1")
|
||||
.setOutputCol("output1")
|
||||
val removerForCol2 = new StopWordsRemover()
|
||||
.setInputCol("input2")
|
||||
.setOutputCol("output2")
|
||||
|
||||
val plForSingleCol = new Pipeline()
|
||||
.setStages(Array(removerForCol1, removerForCol2))
|
||||
.fit(df)
|
||||
|
||||
val resultForSingleCol = plForSingleCol.transform(df)
|
||||
.select("output1", "output2")
|
||||
.collect()
|
||||
val resultForMultiCols = plForMultiCols.transform(df)
|
||||
.select("output1", "output2")
|
||||
.collect()
|
||||
|
||||
resultForSingleCol.zip(resultForMultiCols).foreach {
|
||||
case (rowForSingle, rowForMultiCols) =>
|
||||
assert(rowForSingle === rowForMultiCols)
|
||||
}
|
||||
}
|
||||
|
||||
test("Multiple Columns: Mismatched sizes of inputCols/outputCols") {
|
||||
val remover = new StopWordsRemover()
|
||||
.setInputCols(Array("input1"))
|
||||
.setOutputCols(Array("result1", "result2"))
|
||||
val df = Seq(
|
||||
(Seq("A"), Seq("A")),
|
||||
(Seq("The", "the"), Seq("The"))
|
||||
).toDF("input1", "input2")
|
||||
intercept[IllegalArgumentException] {
|
||||
remover.transform(df).count()
|
||||
}
|
||||
}
|
||||
|
||||
test("Multiple Columns: Set both of inputCol/inputCols") {
|
||||
val remover = new StopWordsRemover()
|
||||
.setInputCols(Array("input1", "input2"))
|
||||
.setOutputCols(Array("result1", "result2"))
|
||||
.setInputCol("input1")
|
||||
val df = Seq(
|
||||
(Seq("A"), Seq("A")),
|
||||
(Seq("The", "the"), Seq("The"))
|
||||
).toDF("input1", "input2")
|
||||
intercept[IllegalArgumentException] {
|
||||
remover.transform(df).count()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3774,9 +3774,13 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
|
|||
return self._set(outputCol=value)
|
||||
|
||||
|
||||
class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
|
||||
class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols,
|
||||
JavaMLReadable, JavaMLWritable):
|
||||
"""
|
||||
A feature transformer that filters out stop words from input.
|
||||
Since 3.0.0, :py:class:`StopWordsRemover` can filter out multiple columns at once by setting
|
||||
the :py:attr:`inputCols` parameter. Note that when both the :py:attr:`inputCol` and
|
||||
:py:attr:`inputCols` parameters are set, an Exception will be thrown.
|
||||
|
||||
.. note:: null values from input array are preserved unless adding null to stopWords explicitly.
|
||||
|
||||
|
@ -3795,6 +3799,17 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
|
|||
True
|
||||
>>> loadedRemover.getCaseSensitive() == remover.getCaseSensitive()
|
||||
True
|
||||
>>> df2 = spark.createDataFrame([(["a", "b", "c"], ["a", "b"])], ["text1", "text2"])
|
||||
>>> remover2 = StopWordsRemover(stopWords=["b"])
|
||||
>>> remover2.setInputCols(["text1", "text2"]).setOutputCols(["words1", "words2"])
|
||||
StopWordsRemover...
|
||||
>>> remover2.transform(df2).show()
|
||||
+---------+------+------+------+
|
||||
| text1| text2|words1|words2|
|
||||
+---------+------+------+------+
|
||||
|[a, b, c]|[a, b]|[a, c]| [a]|
|
||||
+---------+------+------+------+
|
||||
...
|
||||
|
||||
.. versionadded:: 1.6.0
|
||||
"""
|
||||
|
@ -3808,10 +3823,10 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
|
|||
|
||||
@keyword_only
|
||||
def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
|
||||
locale=None):
|
||||
locale=None, inputCols=None, outputCols=None):
|
||||
"""
|
||||
__init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \
|
||||
locale=None)
|
||||
locale=None, inputCols=None, outputCols=None)
|
||||
"""
|
||||
super(StopWordsRemover, self).__init__()
|
||||
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
|
||||
|
@ -3824,10 +3839,10 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
|
|||
@keyword_only
|
||||
@since("1.6.0")
|
||||
def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
|
||||
locale=None):
|
||||
locale=None, inputCols=None, outputCols=None):
|
||||
"""
|
||||
setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \
|
||||
locale=None)
|
||||
locale=None, inputCols=None, outputCols=None)
|
||||
Sets params for this StopWordRemover.
|
||||
"""
|
||||
kwargs = self._input_kwargs
|
||||
|
@ -3887,6 +3902,20 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
|
|||
"""
|
||||
return self._set(outputCol=value)
|
||||
|
||||
@since("3.0.0")
|
||||
def setInputCols(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`inputCols`.
|
||||
"""
|
||||
return self._set(inputCols=value)
|
||||
|
||||
@since("3.0.0")
|
||||
def setOutputCols(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`outputCols`.
|
||||
"""
|
||||
return self._set(outputCols=value)
|
||||
|
||||
@staticmethod
|
||||
@since("2.0.0")
|
||||
def loadDefaultStopWords(language):
|
||||
|
|
Loading…
Reference in a new issue