[SPARK-15064][ML] Locale support in StopWordsRemover
## What changes were proposed in this pull request?

Add locale support for `StopWordsRemover`.

## How was this patch tested?

Scala and Python unit tests.

Author: Lee Dongjin <dongjin@apache.org>

Closes #21501 from dongjinleekr/feature/SPARK-15064.
This commit is contained in:
parent
1d7db65e96
commit
5d6a53d983
|
@ -17,9 +17,11 @@
|
|||
|
||||
package org.apache.spark.ml.feature
|
||||
|
||||
import java.util.Locale
|
||||
|
||||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.ml.Transformer
|
||||
import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam}
|
||||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
|
||||
import org.apache.spark.ml.util._
|
||||
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||
|
@ -84,7 +86,27 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
|
|||
/** @group getParam */
@Since("1.5.0")
def getCaseSensitive: Boolean = $(caseSensitive)
|
||||
|
||||
setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"), caseSensitive -> false)
|
||||
/**
 * Locale of the input for case insensitive matching. Ignored when [[caseSensitive]]
 * is true.
 * Default: Locale.getDefault.toString
 * @group param
 */
@Since("2.4.0")
val locale: Param[String] = new Param[String](this, "locale",
  "Locale of the input for case insensitive matching. Ignored when caseSensitive is true.",
  // Restrict values to locales the running JVM supports, so an invalid locale
  // fails fast when the param is set rather than at transform time.
  ParamValidators.inArray[String](Locale.getAvailableLocales.map(_.toString)))
|
||||
|
||||
/** @group setParam */
@Since("2.4.0")
def setLocale(value: String): this.type = set(locale, value)
|
||||
|
||||
/** @group getParam */
@Since("2.4.0")
def getLocale: String = $(locale)
|
||||
|
||||
// Defaults: English stop words, case-insensitive matching, and the JVM's default
// locale for lowercasing during case-insensitive comparison.
setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"),
  caseSensitive -> false, locale -> Locale.getDefault.toString)
|
||||
|
||||
@Since("2.0.0")
|
||||
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||
|
@ -95,8 +117,8 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
|
|||
terms.filter(s => !stopWordsSet.contains(s))
|
||||
}
|
||||
} else {
|
||||
// TODO: support user locale (SPARK-15064)
|
||||
val toLower = (s: String) => if (s != null) s.toLowerCase else s
|
||||
val lc = new Locale($(locale))
|
||||
val toLower = (s: String) => if (s != null) s.toLowerCase(lc) else s
|
||||
val lowerStopWords = $(stopWords).map(toLower(_)).toSet
|
||||
udf { terms: Seq[String] =>
|
||||
terms.filter(s => !lowerStopWords.contains(toLower(s)))
|
||||
|
|
|
@ -65,6 +65,57 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest {
|
|||
testStopWordsRemover(remover, dataSet)
|
||||
}
|
||||
|
||||
test("StopWordsRemover with localed input (case insensitive)") {
  val stopWords = Array("milk", "cookie")
  val remover = new StopWordsRemover()
    .setInputCol("raw")
    .setOutputCol("filtered")
    .setStopWords(stopWords)
    .setCaseSensitive(false)
    .setLocale("tr") // Turkish alphabet: has no Q, W, X but has dotted and dotless 'I's.
  val dataSet = Seq(
    // scalastyle:off
    // "mİlk" lowercases to "milk" under the Turkish locale (İ -> i), so it matches
    // the stop word and is removed.
    (Seq("mİlk", "and", "nuts"), Seq("and", "nuts")),
    // scalastyle:on
    // "cookIe" lowercases to "cookıe" (dotless ı) under the Turkish locale, so it
    // does NOT match the stop word "cookie" and is kept.
    (Seq("cookIe", "and", "nuts"), Seq("cookIe", "and", "nuts")),
    (Seq(null), Seq(null)),
    (Seq(), Seq())
  ).toDF("raw", "expected")

  testStopWordsRemover(remover, dataSet)
}
|
||||
|
||||
test("StopWordsRemover with localed input (case sensitive)") {
  val stopWords = Array("milk", "cookie")
  val remover = new StopWordsRemover()
    .setInputCol("raw")
    .setOutputCol("filtered")
    .setStopWords(stopWords)
    .setCaseSensitive(true)
    .setLocale("tr") // Turkish alphabet: has no Q, W, X but has dotted and dotless 'I's.
  val dataSet = Seq(
    // scalastyle:off
    // With caseSensitive=true the locale is ignored and comparison is exact, so
    // neither "mİlk" nor "cookIe" matches a stop word; everything is kept.
    (Seq("mİlk", "and", "nuts"), Seq("mİlk", "and", "nuts")),
    // scalastyle:on
    (Seq("cookIe", "and", "nuts"), Seq("cookIe", "and", "nuts")),
    (Seq(null), Seq(null)),
    (Seq(), Seq())
  ).toDF("raw", "expected")

  testStopWordsRemover(remover, dataSet)
}
|
||||
|
||||
test("StopWordsRemover with invalid locale") {
  // "rt" is not in Locale.getAvailableLocales, so the locale param's
  // ParamValidators.inArray check must reject it at setLocale time.
  intercept[IllegalArgumentException] {
    val stopWords = Array("test", "a", "an", "the")
    new StopWordsRemover()
      .setInputCol("raw")
      .setOutputCol("filtered")
      .setStopWords(stopWords)
      .setLocale("rt") // invalid locale
  }
}
|
||||
|
||||
test("StopWordsRemover case sensitive") {
|
||||
val remover = new StopWordsRemover()
|
||||
.setInputCol("raw")
|
||||
|
|
|
@ -2582,25 +2582,31 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
|
|||
typeConverter=TypeConverters.toListString)
|
||||
caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " +
|
||||
"comparison over the stop words", typeConverter=TypeConverters.toBoolean)
|
||||
# Locale used for case-insensitive stop word matching; ignored when caseSensitive
# is true. Only a plain string conversion is done here — validation against the
# available JVM locales happens on the Scala side.
locale = Param(Params._dummy(), "locale", "locale of the input. ignored when case sensitive " +
               "is true", typeConverter=TypeConverters.toString)
|
||||
|
||||
@keyword_only
def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
             locale=None):
    """
    __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \
             locale=None)
    """
    super(StopWordsRemover, self).__init__()
    # Wrap the Scala implementation; all transform logic lives on the JVM side.
    self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
                                        self.uid)
    # Defaults mirror the Scala side; the default locale is read back from the JVM
    # object so Python and Scala agree on Locale.getDefault.
    self._setDefault(stopWords=StopWordsRemover.loadDefaultStopWords("english"),
                     caseSensitive=False, locale=self._java_obj.getLocale())
    kwargs = self._input_kwargs
    self.setParams(**kwargs)
|
||||
|
||||
@keyword_only
|
||||
@since("1.6.0")
|
||||
def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False):
|
||||
def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
|
||||
locale=None):
|
||||
"""
|
||||
setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false)
|
||||
setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \
|
||||
locale=None)
|
||||
Sets params for this StopWordsRemover.
|
||||
"""
|
||||
kwargs = self._input_kwargs
|
||||
|
@ -2634,6 +2640,20 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
|
|||
"""
|
||||
return self.getOrDefault(self.caseSensitive)
|
||||
|
||||
@since("2.4.0")
def setLocale(self, value):
    """
    Sets the value of :py:attr:`locale`.
    """
    return self._set(locale=value)
|
||||
|
||||
@since("2.4.0")
def getLocale(self):
    """
    Gets the value of :py:attr:`locale`.
    """
    return self.getOrDefault(self.locale)
|
||||
|
||||
@staticmethod
|
||||
@since("2.0.0")
|
||||
def loadDefaultStopWords(language):
|
||||
|
|
|
@ -681,6 +681,13 @@ class FeatureTests(SparkSessionTestCase):
|
|||
self.assertEqual(stopWordRemover.getStopWords(), stopwords)
|
||||
transformedDF = stopWordRemover.transform(dataset)
|
||||
self.assertEqual(transformedDF.head().output, [])
|
||||
# with locale
|
||||
stopwords = ["BELKİ"]
|
||||
dataset = self.spark.createDataFrame([Row(input=["belki"])])
|
||||
stopWordRemover.setStopWords(stopwords).setLocale("tr")
|
||||
self.assertEqual(stopWordRemover.getStopWords(), stopwords)
|
||||
transformedDF = stopWordRemover.transform(dataset)
|
||||
self.assertEqual(transformedDF.head().output, [])
|
||||
|
||||
def test_count_vectorizer_with_binary(self):
|
||||
dataset = self.spark.createDataFrame([
|
||||
|
|
Loading…
Reference in a new issue