[SPARK-29808][ML][PYTHON] StopWordsRemover should support multi-cols
### What changes were proposed in this pull request?
Add multi-column support to `StopWordsRemover`.

### Why are the changes needed?
As a basic Transformer, `StopWordsRemover` should support multiple columns. The `stopWords` param can be applied across all columns.

### Does this PR introduce any user-facing change?
Two new setters: `StopWordsRemover.setInputCols` and `StopWordsRemover.setOutputCols`.

### How was this patch tested?
Unit tests.

Closes #26480 from huaxingao/spark-29808.

Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
This commit is contained in:
parent
8c2bf64743
commit
1f4075d29e
|
@ -22,15 +22,19 @@ import java.util.Locale
|
||||||
import org.apache.spark.annotation.Since
|
import org.apache.spark.annotation.Since
|
||||||
import org.apache.spark.ml.Transformer
|
import org.apache.spark.ml.Transformer
|
||||||
import org.apache.spark.ml.param._
|
import org.apache.spark.ml.param._
|
||||||
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
|
import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.sql.{DataFrame, Dataset}
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.functions.{col, udf}
|
import org.apache.spark.sql.functions.{col, udf}
|
||||||
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}
|
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A feature transformer that filters out stop words from input.
|
* A feature transformer that filters out stop words from input.
|
||||||
*
|
*
|
||||||
|
* Since 3.0.0, `StopWordsRemover` can filter out stop words from multiple columns at once by setting the
|
||||||
|
* `inputCols` parameter. Note that when both the `inputCol` and `inputCols` parameters are set,
|
||||||
|
* an Exception will be thrown.
|
||||||
|
*
|
||||||
* @note null values from input array are preserved unless adding null to stopWords
|
* @note null values from input array are preserved unless adding null to stopWords
|
||||||
* explicitly.
|
* explicitly.
|
||||||
*
|
*
|
||||||
|
@ -38,7 +42,8 @@ import org.apache.spark.sql.types.{ArrayType, StringType, StructType}
|
||||||
*/
|
*/
|
||||||
@Since("1.5.0")
|
@Since("1.5.0")
|
||||||
class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String)
|
class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String)
|
||||||
extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable {
|
extends Transformer with HasInputCol with HasOutputCol with HasInputCols with HasOutputCols
|
||||||
|
with DefaultParamsWritable {
|
||||||
|
|
||||||
@Since("1.5.0")
|
@Since("1.5.0")
|
||||||
def this() = this(Identifiable.randomUID("stopWords"))
|
def this() = this(Identifiable.randomUID("stopWords"))
|
||||||
|
@ -51,6 +56,14 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
|
||||||
@Since("1.5.0")
|
@Since("1.5.0")
|
||||||
def setOutputCol(value: String): this.type = set(outputCol, value)
|
def setOutputCol(value: String): this.type = set(outputCol, value)
|
||||||
|
|
||||||
|
/** @group setParam */
|
||||||
|
@Since("3.0.0")
|
||||||
|
def setInputCols(value: Array[String]): this.type = set(inputCols, value)
|
||||||
|
|
||||||
|
/** @group setParam */
|
||||||
|
@Since("3.0.0")
|
||||||
|
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The words to be filtered out.
|
* The words to be filtered out.
|
||||||
* Default: English stop words
|
* Default: English stop words
|
||||||
|
@ -121,6 +134,15 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Returns the input and output column names as two arrays paired up by index. */
|
||||||
|
private[feature] def getInOutCols(): (Array[String], Array[String]) = {
|
||||||
|
if (isSet(inputCol)) {
|
||||||
|
(Array($(inputCol)), Array($(outputCol)))
|
||||||
|
} else {
|
||||||
|
($(inputCols), $(outputCols))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"),
|
setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"),
|
||||||
caseSensitive -> false, locale -> getDefaultOrUS.toString)
|
caseSensitive -> false, locale -> getDefaultOrUS.toString)
|
||||||
|
|
||||||
|
@ -142,16 +164,38 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
|
||||||
terms.filter(s => !lowerStopWords.contains(toLower(s)))
|
terms.filter(s => !lowerStopWords.contains(toLower(s)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
val metadata = outputSchema($(outputCol)).metadata
|
|
||||||
dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
|
val (inputColNames, outputColNames) = getInOutCols()
|
||||||
|
val ouputCols = inputColNames.map { inputColName =>
|
||||||
|
t(col(inputColName))
|
||||||
|
}
|
||||||
|
val ouputMetadata = outputColNames.map(outputSchema(_).metadata)
|
||||||
|
dataset.withColumns(outputColNames, ouputCols, ouputMetadata)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.5.0")
|
@Since("1.5.0")
|
||||||
override def transformSchema(schema: StructType): StructType = {
|
override def transformSchema(schema: StructType): StructType = {
|
||||||
val inputType = schema($(inputCol)).dataType
|
ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol),
|
||||||
require(inputType.sameType(ArrayType(StringType)), "Input type must be " +
|
Seq(outputCols))
|
||||||
s"${ArrayType(StringType).catalogString} but got ${inputType.catalogString}.")
|
|
||||||
SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable)
|
if (isSet(inputCols)) {
|
||||||
|
require(getInputCols.length == getOutputCols.length,
|
||||||
|
s"StopWordsRemover $this has mismatched Params " +
|
||||||
|
s"for multi-column transform. Params ($inputCols, $outputCols) should have " +
|
||||||
|
"equal lengths, but they have different lengths: " +
|
||||||
|
s"(${getInputCols.length}, ${getOutputCols.length}).")
|
||||||
|
}
|
||||||
|
|
||||||
|
val (inputColNames, outputColNames) = getInOutCols()
|
||||||
|
val newCols = inputColNames.zip(outputColNames).map { case (inputColName, outputColName) =>
|
||||||
|
require(!schema.fieldNames.contains(outputColName),
|
||||||
|
s"Output Column $outputColName already exists.")
|
||||||
|
val inputType = schema(inputColName).dataType
|
||||||
|
require(inputType.sameType(ArrayType(StringType)), "Input type must be " +
|
||||||
|
s"${ArrayType(StringType).catalogString} but got ${inputType.catalogString}.")
|
||||||
|
StructField(outputColName, inputType, schema(inputColName).nullable)
|
||||||
|
}
|
||||||
|
StructType(schema.fields ++ newCols)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.5.0")
|
@Since("1.5.0")
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
|
||||||
|
|
||||||
import java.util.Locale
|
import java.util.Locale
|
||||||
|
|
||||||
|
import org.apache.spark.ml.Pipeline
|
||||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
|
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Row}
|
||||||
|
|
||||||
|
@ -181,12 +182,19 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
test("read/write") {
|
test("read/write") {
|
||||||
val t = new StopWordsRemover()
|
val t1 = new StopWordsRemover()
|
||||||
.setInputCol("myInputCol")
|
.setInputCol("myInputCol")
|
||||||
.setOutputCol("myOutputCol")
|
.setOutputCol("myOutputCol")
|
||||||
.setStopWords(Array("the", "a"))
|
.setStopWords(Array("the", "a"))
|
||||||
.setCaseSensitive(true)
|
.setCaseSensitive(true)
|
||||||
testDefaultReadWrite(t)
|
testDefaultReadWrite(t1)
|
||||||
|
|
||||||
|
val t2 = new StopWordsRemover()
|
||||||
|
.setInputCols(Array("input1", "input2", "input3"))
|
||||||
|
.setOutputCols(Array("result1", "result2", "result3"))
|
||||||
|
.setStopWords(Array("the", "a"))
|
||||||
|
.setCaseSensitive(true)
|
||||||
|
testDefaultReadWrite(t2)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("StopWordsRemover output column already exists") {
|
test("StopWordsRemover output column already exists") {
|
||||||
|
@ -199,7 +207,7 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest {
|
||||||
testTransformerByInterceptingException[(Array[String], Array[String])](
|
testTransformerByInterceptingException[(Array[String], Array[String])](
|
||||||
dataSet,
|
dataSet,
|
||||||
remover,
|
remover,
|
||||||
s"requirement failed: Column $outputCol already exists.",
|
s"requirement failed: Output Column $outputCol already exists.",
|
||||||
"expected")
|
"expected")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -217,4 +225,123 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest {
|
||||||
Locale.setDefault(oldDefault)
|
Locale.setDefault(oldDefault)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("Multiple Columns: StopWordsRemover default") {
|
||||||
|
val remover = new StopWordsRemover()
|
||||||
|
.setInputCols(Array("raw1", "raw2"))
|
||||||
|
.setOutputCols(Array("filtered1", "filtered2"))
|
||||||
|
val df = Seq(
|
||||||
|
(Seq("test", "test"), Seq("test1", "test2"), Seq("test", "test"), Seq("test1", "test2")),
|
||||||
|
(Seq("a", "b", "c", "d"), Seq("a", "b"), Seq("b", "c", "d"), Seq("b")),
|
||||||
|
(Seq("a", "the", "an"), Seq("the", "an"), Seq(), Seq()),
|
||||||
|
(Seq("A", "The", "AN"), Seq("A", "The"), Seq(), Seq()),
|
||||||
|
(Seq(null), Seq(null), Seq(null), Seq(null)),
|
||||||
|
(Seq(), Seq(), Seq(), Seq())
|
||||||
|
).toDF("raw1", "raw2", "expected1", "expected2")
|
||||||
|
|
||||||
|
remover.transform(df)
|
||||||
|
.select("filtered1", "expected1", "filtered2", "expected2")
|
||||||
|
.collect().foreach {
|
||||||
|
case Row(r1: Seq[String], e1: Seq[String], r2: Seq[String], e2: Seq[String]) =>
|
||||||
|
assert(r1 === e1,
|
||||||
|
s"The result value is not correct after bucketing. Expected $e1 but found $r1")
|
||||||
|
assert(r2 === e2,
|
||||||
|
s"The result value is not correct after bucketing. Expected $e2 but found $r2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("Multiple Columns: StopWordsRemover with particular stop words list") {
|
||||||
|
val stopWords = Array("test", "a", "an", "the")
|
||||||
|
val remover = new StopWordsRemover()
|
||||||
|
.setInputCols(Array("raw1", "raw2"))
|
||||||
|
.setOutputCols(Array("filtered1", "filtered2"))
|
||||||
|
.setStopWords(stopWords)
|
||||||
|
val df = Seq(
|
||||||
|
(Seq("test", "test"), Seq("test1", "test2"), Seq(), Seq("test1", "test2")),
|
||||||
|
(Seq("a", "b", "c", "d"), Seq("a", "b"), Seq("b", "c", "d"), Seq("b")),
|
||||||
|
(Seq("a", "the", "an"), Seq("a", "the", "test1"), Seq(), Seq("test1")),
|
||||||
|
(Seq("A", "The", "AN"), Seq("A", "The", "AN"), Seq(), Seq()),
|
||||||
|
(Seq(null), Seq(null), Seq(null), Seq(null)),
|
||||||
|
(Seq(), Seq(), Seq(), Seq())
|
||||||
|
).toDF("raw1", "raw2", "expected1", "expected2")
|
||||||
|
|
||||||
|
remover.transform(df)
|
||||||
|
.select("filtered1", "expected1", "filtered2", "expected2")
|
||||||
|
.collect().foreach {
|
||||||
|
case Row(r1: Seq[String], e1: Seq[String], r2: Seq[String], e2: Seq[String]) =>
|
||||||
|
assert(r1 === e1,
|
||||||
|
s"The result value is not correct after bucketing. Expected $e1 but found $r1")
|
||||||
|
assert(r2 === e2,
|
||||||
|
s"The result value is not correct after bucketing. Expected $e2 but found $r2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("Compare single/multiple column(s) StopWordsRemover in pipeline") {
|
||||||
|
val df = Seq(
|
||||||
|
(Seq("test", "test"), Seq("test1", "test2")),
|
||||||
|
(Seq("a", "b", "c", "d"), Seq("a", "b")),
|
||||||
|
(Seq("a", "the", "an"), Seq("a", "the", "test1")),
|
||||||
|
(Seq("A", "The", "AN"), Seq("A", "The", "AN")),
|
||||||
|
(Seq(null), Seq(null)),
|
||||||
|
(Seq(), Seq())
|
||||||
|
).toDF("input1", "input2")
|
||||||
|
|
||||||
|
val multiColsRemover = new StopWordsRemover()
|
||||||
|
.setInputCols(Array("input1", "input2"))
|
||||||
|
.setOutputCols(Array("output1", "output2"))
|
||||||
|
|
||||||
|
val plForMultiCols = new Pipeline()
|
||||||
|
.setStages(Array(multiColsRemover))
|
||||||
|
.fit(df)
|
||||||
|
|
||||||
|
val removerForCol1 = new StopWordsRemover()
|
||||||
|
.setInputCol("input1")
|
||||||
|
.setOutputCol("output1")
|
||||||
|
val removerForCol2 = new StopWordsRemover()
|
||||||
|
.setInputCol("input2")
|
||||||
|
.setOutputCol("output2")
|
||||||
|
|
||||||
|
val plForSingleCol = new Pipeline()
|
||||||
|
.setStages(Array(removerForCol1, removerForCol2))
|
||||||
|
.fit(df)
|
||||||
|
|
||||||
|
val resultForSingleCol = plForSingleCol.transform(df)
|
||||||
|
.select("output1", "output2")
|
||||||
|
.collect()
|
||||||
|
val resultForMultiCols = plForMultiCols.transform(df)
|
||||||
|
.select("output1", "output2")
|
||||||
|
.collect()
|
||||||
|
|
||||||
|
resultForSingleCol.zip(resultForMultiCols).foreach {
|
||||||
|
case (rowForSingle, rowForMultiCols) =>
|
||||||
|
assert(rowForSingle === rowForMultiCols)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("Multiple Columns: Mismatched sizes of inputCols/outputCols") {
|
||||||
|
val remover = new StopWordsRemover()
|
||||||
|
.setInputCols(Array("input1"))
|
||||||
|
.setOutputCols(Array("result1", "result2"))
|
||||||
|
val df = Seq(
|
||||||
|
(Seq("A"), Seq("A")),
|
||||||
|
(Seq("The", "the"), Seq("The"))
|
||||||
|
).toDF("input1", "input2")
|
||||||
|
intercept[IllegalArgumentException] {
|
||||||
|
remover.transform(df).count()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("Multiple Columns: Set both of inputCol/inputCols") {
|
||||||
|
val remover = new StopWordsRemover()
|
||||||
|
.setInputCols(Array("input1", "input2"))
|
||||||
|
.setOutputCols(Array("result1", "result2"))
|
||||||
|
.setInputCol("input1")
|
||||||
|
val df = Seq(
|
||||||
|
(Seq("A"), Seq("A")),
|
||||||
|
(Seq("The", "the"), Seq("The"))
|
||||||
|
).toDF("input1", "input2")
|
||||||
|
intercept[IllegalArgumentException] {
|
||||||
|
remover.transform(df).count()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3774,9 +3774,13 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
|
||||||
return self._set(outputCol=value)
|
return self._set(outputCol=value)
|
||||||
|
|
||||||
|
|
||||||
class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
|
class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols,
|
||||||
|
JavaMLReadable, JavaMLWritable):
|
||||||
"""
|
"""
|
||||||
A feature transformer that filters out stop words from input.
|
A feature transformer that filters out stop words from input.
|
||||||
|
Since 3.0.0, :py:class:`StopWordsRemover` can filter out stop words from multiple columns at once by setting
|
||||||
|
the :py:attr:`inputCols` parameter. Note that when both the :py:attr:`inputCol` and
|
||||||
|
:py:attr:`inputCols` parameters are set, an Exception will be thrown.
|
||||||
|
|
||||||
.. note:: null values from input array are preserved unless adding null to stopWords explicitly.
|
.. note:: null values from input array are preserved unless adding null to stopWords explicitly.
|
||||||
|
|
||||||
|
@ -3795,6 +3799,17 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
|
||||||
True
|
True
|
||||||
>>> loadedRemover.getCaseSensitive() == remover.getCaseSensitive()
|
>>> loadedRemover.getCaseSensitive() == remover.getCaseSensitive()
|
||||||
True
|
True
|
||||||
|
>>> df2 = spark.createDataFrame([(["a", "b", "c"], ["a", "b"])], ["text1", "text2"])
|
||||||
|
>>> remover2 = StopWordsRemover(stopWords=["b"])
|
||||||
|
>>> remover2.setInputCols(["text1", "text2"]).setOutputCols(["words1", "words2"])
|
||||||
|
StopWordsRemover...
|
||||||
|
>>> remover2.transform(df2).show()
|
||||||
|
+---------+------+------+------+
|
||||||
|
| text1| text2|words1|words2|
|
||||||
|
+---------+------+------+------+
|
||||||
|
|[a, b, c]|[a, b]|[a, c]| [a]|
|
||||||
|
+---------+------+------+------+
|
||||||
|
...
|
||||||
|
|
||||||
.. versionadded:: 1.6.0
|
.. versionadded:: 1.6.0
|
||||||
"""
|
"""
|
||||||
|
@ -3808,10 +3823,10 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
|
||||||
|
|
||||||
@keyword_only
|
@keyword_only
|
||||||
def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
|
def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
|
||||||
locale=None):
|
locale=None, inputCols=None, outputCols=None):
|
||||||
"""
|
"""
|
||||||
__init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \
|
__init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \
|
||||||
locale=None)
|
locale=None, inputCols=None, outputCols=None)
|
||||||
"""
|
"""
|
||||||
super(StopWordsRemover, self).__init__()
|
super(StopWordsRemover, self).__init__()
|
||||||
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
|
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
|
||||||
|
@ -3824,10 +3839,10 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
|
||||||
@keyword_only
|
@keyword_only
|
||||||
@since("1.6.0")
|
@since("1.6.0")
|
||||||
def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
|
def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
|
||||||
locale=None):
|
locale=None, inputCols=None, outputCols=None):
|
||||||
"""
|
"""
|
||||||
setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \
|
setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \
|
||||||
locale=None)
|
locale=None, inputCols=None, outputCols=None)
|
||||||
Sets params for this StopWordsRemover.
|
Sets params for this StopWordsRemover.
|
||||||
"""
|
"""
|
||||||
kwargs = self._input_kwargs
|
kwargs = self._input_kwargs
|
||||||
|
@ -3887,6 +3902,20 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
|
||||||
"""
|
"""
|
||||||
return self._set(outputCol=value)
|
return self._set(outputCol=value)
|
||||||
|
|
||||||
|
@since("3.0.0")
|
||||||
|
def setInputCols(self, value):
|
||||||
|
"""
|
||||||
|
Sets the value of :py:attr:`inputCols`.
|
||||||
|
"""
|
||||||
|
return self._set(inputCols=value)
|
||||||
|
|
||||||
|
@since("3.0.0")
|
||||||
|
def setOutputCols(self, value):
|
||||||
|
"""
|
||||||
|
Sets the value of :py:attr:`outputCols`.
|
||||||
|
"""
|
||||||
|
return self._set(outputCols=value)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@since("2.0.0")
|
@since("2.0.0")
|
||||||
def loadDefaultStopWords(language):
|
def loadDefaultStopWords(language):
|
||||||
|
|
Loading…
Reference in a new issue