[SPARK-15064][ML] Locale support in StopWordsRemover

## What changes were proposed in this pull request?

Add locale support for `StopWordsRemover`: a new `locale` param controls which locale is used to lowercase terms for case-insensitive matching (default: the JVM's default locale).
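
A minimal usage sketch of the new param (assuming a `SparkSession` is in scope as `spark`; the column names and stop words below just mirror the new test cases):

```scala
import org.apache.spark.ml.feature.StopWordsRemover

val remover = new StopWordsRemover()
  .setInputCol("raw")
  .setOutputCol("filtered")
  .setStopWords(Array("milk", "cookie"))
  .setCaseSensitive(false)
  .setLocale("tr")  // lowercase terms with the Turkish locale before matching

import spark.implicits._
val df = Seq(Seq("mİlk", "and", "nuts")).toDF("raw")

// Under the Turkish locale "mİlk" lowercases to "milk", so it is removed.
remover.transform(df).show(false)
```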

## How was this patch tested?

Scala and Python unit tests.

Author: Lee Dongjin <dongjin@apache.org>

Closes #21501 from dongjinleekr/feature/SPARK-15064.
Lee Dongjin 2018-06-12 08:16:37 -07:00 committed by Xiangrui Meng
parent 1d7db65e96
commit 5d6a53d983
4 changed files with 109 additions and 9 deletions


@@ -17,9 +17,11 @@
package org.apache.spark.ml.feature
import java.util.Locale
import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
@@ -84,7 +86,27 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
@Since("1.5.0")
def getCaseSensitive: Boolean = $(caseSensitive)
setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"), caseSensitive -> false)
/**
* Locale of the input for case insensitive matching. Ignored when [[caseSensitive]]
* is true.
* Default: Locale.getDefault.toString
* @group param
*/
@Since("2.4.0")
val locale: Param[String] = new Param[String](this, "locale",
"Locale of the input for case insensitive matching. Ignored when caseSensitive is true.",
ParamValidators.inArray[String](Locale.getAvailableLocales.map(_.toString)))
/** @group setParam */
@Since("2.4.0")
def setLocale(value: String): this.type = set(locale, value)
/** @group getParam */
@Since("2.4.0")
def getLocale: String = $(locale)
setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"),
caseSensitive -> false, locale -> Locale.getDefault.toString)
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
@@ -95,8 +117,8 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
terms.filter(s => !stopWordsSet.contains(s))
}
} else {
// TODO: support user locale (SPARK-15064)
val toLower = (s: String) => if (s != null) s.toLowerCase else s
val lc = new Locale($(locale))
val toLower = (s: String) => if (s != null) s.toLowerCase(lc) else s
val lowerStopWords = $(stopWords).map(toLower(_)).toSet
udf { terms: Seq[String] =>
terms.filter(s => !lowerStopWords.contains(toLower(s)))
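
The explicit locale matters because `String.toLowerCase` is locale-sensitive, so case-insensitive matching previously depended on the JVM's default locale. A standalone illustration (not part of the patch) of the Turkish-I behavior the new tests exercise:

```scala
import java.util.Locale

val tr = new Locale("tr")
// Dotted capital 'İ' lowercases to a plain 'i' only under the Turkish locale.
"mİlk".toLowerCase(tr)          // "milk"   -> matches the stop word "milk"
"mİlk".toLowerCase(Locale.ROOT) // "mi̇lk"   -> 'i' plus a combining dot, no match
// Plain capital 'I' lowercases to dotless 'ı' under the Turkish locale.
"cookIe".toLowerCase(tr)        // "cookıe" -> does not match "cookie"
```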


@@ -65,6 +65,57 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest {
testStopWordsRemover(remover, dataSet)
}
test("StopWordsRemover with localed input (case insensitive)") {
val stopWords = Array("milk", "cookie")
val remover = new StopWordsRemover()
.setInputCol("raw")
.setOutputCol("filtered")
.setStopWords(stopWords)
.setCaseSensitive(false)
.setLocale("tr") // Turkish alphabet: has no Q, W, X but has dotted and dotless 'I's.
val dataSet = Seq(
// scalastyle:off
(Seq("mİlk", "and", "nuts"), Seq("and", "nuts")),
// scalastyle:on
(Seq("cookIe", "and", "nuts"), Seq("cookIe", "and", "nuts")),
(Seq(null), Seq(null)),
(Seq(), Seq())
).toDF("raw", "expected")
testStopWordsRemover(remover, dataSet)
}
test("StopWordsRemover with localed input (case sensitive)") {
val stopWords = Array("milk", "cookie")
val remover = new StopWordsRemover()
.setInputCol("raw")
.setOutputCol("filtered")
.setStopWords(stopWords)
.setCaseSensitive(true)
.setLocale("tr") // Turkish alphabet: has no Q, W, X but has dotted and dotless 'I's.
val dataSet = Seq(
// scalastyle:off
(Seq("mİlk", "and", "nuts"), Seq("mİlk", "and", "nuts")),
// scalastyle:on
(Seq("cookIe", "and", "nuts"), Seq("cookIe", "and", "nuts")),
(Seq(null), Seq(null)),
(Seq(), Seq())
).toDF("raw", "expected")
testStopWordsRemover(remover, dataSet)
}
test("StopWordsRemover with invalid locale") {
intercept[IllegalArgumentException] {
val stopWords = Array("test", "a", "an", "the")
new StopWordsRemover()
.setInputCol("raw")
.setOutputCol("filtered")
.setStopWords(stopWords)
.setLocale("rt") // invalid locale
}
}
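
For reference, the `locale` param is validated with `ParamValidators.inArray` over `Locale.getAvailableLocales`, so only locale strings the JVM knows about are accepted. A quick sketch of that check (the exact set depends on the JDK and its locale providers):

```scala
import java.util.Locale

val available = Locale.getAvailableLocales.map(_.toString).toSet
available.contains("tr")  // true on a typical JDK, so setLocale("tr") is accepted
available.contains("rt")  // false, so setLocale("rt") throws IllegalArgumentException
```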
test("StopWordsRemover case sensitive") {
val remover = new StopWordsRemover()
.setInputCol("raw")


@@ -2582,25 +2582,31 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
typeConverter=TypeConverters.toListString)
caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " +
"comparison over the stop words", typeConverter=TypeConverters.toBoolean)
locale = Param(Params._dummy(), "locale", "locale of the input. ignored when case sensitive " +
"is true", typeConverter=TypeConverters.toString)
@keyword_only
def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False):
def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
locale=None):
"""
__init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false)
__init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \
locale=None)
"""
super(StopWordsRemover, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
self.uid)
self._setDefault(stopWords=StopWordsRemover.loadDefaultStopWords("english"),
caseSensitive=False)
caseSensitive=False, locale=self._java_obj.getLocale())
kwargs = self._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("1.6.0")
def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False):
def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
locale=None):
"""
setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false)
setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \
locale=None)
Sets params for this StopWordRemover.
"""
kwargs = self._input_kwargs
@@ -2634,6 +2640,20 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
"""
return self.getOrDefault(self.caseSensitive)
@since("2.4.0")
def setLocale(self, value):
"""
Sets the value of :py:attr:`locale`.
"""
return self._set(locale=value)
@since("2.4.0")
def getLocale(self):
"""
Gets the value of :py:attr:`locale`.
"""
return self.getOrDefault(self.locale)
@staticmethod
@since("2.0.0")
def loadDefaultStopWords(language):


@@ -681,6 +681,13 @@ class FeatureTests(SparkSessionTestCase):
self.assertEqual(stopWordRemover.getStopWords(), stopwords)
transformedDF = stopWordRemover.transform(dataset)
self.assertEqual(transformedDF.head().output, [])
# with locale
stopwords = ["BELKİ"]
dataset = self.spark.createDataFrame([Row(input=["belki"])])
stopWordRemover.setStopWords(stopwords).setLocale("tr")
self.assertEqual(stopWordRemover.getStopWords(), stopwords)
transformedDF = stopWordRemover.transform(dataset)
self.assertEqual(transformedDF.head().output, [])
def test_count_vectorizer_with_binary(self):
dataset = self.spark.createDataFrame([