[SPARK-9679] [ML] [PYSPARK] Add Python API for Stop Words Remover

Add a Python API for the Stop Words Remover.

Author: Holden Karau <holden@pigscanfly.ca>

Closes #8118 from holdenk/SPARK-9679-python-StopWordsRemover.
Holden Karau authored on 2015-09-01 10:48:57 -07:00; committed by Xiangrui Meng.
parent 391e6be0ae
commit e6e483cc4d
4 changed files with 93 additions and 8 deletions
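
For orientation, a minimal usage sketch of the new Python API (not part of the commit; it mirrors the behavior exercised by the new test in the pyspark ml tests, and assumes a pyspark shell where sc and sqlContext already exist):

    from pyspark.ml.feature import StopWordsRemover
    from pyspark.sql import Row

    # Hypothetical sample data; "a" and "the" are in the default English list.
    df = sqlContext.createDataFrame([Row(raw=["a", "panda", "ate", "the", "bamboo"])])

    # With no stopWords given, the transformer falls back to the default
    # English list fetched from the JVM side (StopWords.English).
    remover = StopWordsRemover(inputCol="raw", outputCol="filtered")
    remover.transform(df).head().filtered
    # ['panda', 'ate', 'bamboo']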

mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala

@@ -29,14 +29,14 @@ import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
/**
* stop words list
*/
-private object StopWords {
+private[spark] object StopWords {
/**
* Use the same default stopwords list as scikit-learn.
* The original list can be found from "Glasgow Information Retrieval Group"
* [[http://ir.dcs.gla.ac.uk/resources/linguistic_utils/stop_words]]
*/
-val EnglishStopWords = Array( "a", "about", "above", "across", "after", "afterwards", "again",
+val English = Array( "a", "about", "above", "across", "after", "afterwards", "again",
"against", "all", "almost", "alone", "along", "already", "also", "although", "always",
"am", "among", "amongst", "amoungst", "amount", "an", "and", "another",
"any", "anyhow", "anyone", "anything", "anyway", "anywhere", "are",
@@ -121,7 +121,7 @@ class StopWordsRemover(override val uid: String)
/** @group getParam */
def getCaseSensitive: Boolean = $(caseSensitive)
-setDefault(stopWords -> StopWords.EnglishStopWords, caseSensitive -> false)
+setDefault(stopWords -> StopWords.English, caseSensitive -> false)
override def transform(dataset: DataFrame): DataFrame = {
val outputSchema = transformSchema(dataset.schema)

mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala

@@ -65,7 +65,7 @@ class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext {
}
test("StopWordsRemover with additional words") {
-val stopWords = StopWords.EnglishStopWords ++ Array("python", "scala")
+val stopWords = StopWords.English ++ Array("python", "scala")
val remover = new StopWordsRemover()
.setInputCol("raw")
.setOutputCol("filtered")

python/pyspark/ml/feature.py

@@ -22,7 +22,7 @@ if sys.version > '3':
from pyspark.rdd import ignore_unicode_prefix
from pyspark.ml.param.shared import *
from pyspark.ml.util import keyword_only
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm
from pyspark.mllib.common import inherit_doc
from pyspark.mllib.linalg import _convert_to_vector
@@ -30,7 +30,7 @@ __all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF',
'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer',
'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer',
'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec',
-'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel']
+'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel', 'StopWordsRemover']
@inherit_doc
@@ -933,6 +933,75 @@ class StringIndexerModel(JavaModel):
"""
class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol):
"""
.. note:: Experimental
A feature transformer that filters out stop words from input.
Note: null values in the input array are preserved unless null is explicitly added to stopWords.
"""
# a placeholder to make the stopwords show up in generated doc
stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out")
caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " +
"comparison over the stop words")
@keyword_only
def __init__(self, inputCol=None, outputCol=None, stopWords=None,
caseSensitive=False):
"""
__init__(self, inputCol=None, outputCol=None, stopWords=None,\
caseSensitive=False)
"""
super(StopWordsRemover, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
self.uid)
self.stopWords = Param(self, "stopWords", "The words to be filtered out")
self.caseSensitive = Param(self, "caseSensitive", "whether to do a case " +
"sensitive comparison over the stop words")
stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords
defaultStopWords = stopWordsObj.English()
self._setDefault(stopWords=defaultStopWords)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@keyword_only
def setParams(self, inputCol=None, outputCol=None, stopWords=None,
caseSensitive=False):
"""
setParams(self, inputCol=None, outputCol=None, stopWords=None,\
caseSensitive=False)
Sets params for this StopWordsRemover.
"""
kwargs = self.setParams._input_kwargs
return self._set(**kwargs)
def setStopWords(self, value):
"""
Specify the stopwords to be filtered.
"""
self._paramMap[self.stopWords] = value
return self
def getStopWords(self):
"""
Get the stopwords.
"""
return self.getOrDefault(self.stopWords)
def setCaseSensitive(self, value):
"""
Set whether to do a case sensitive comparison over the stop words.
"""
self._paramMap[self.caseSensitive] = value
return self
def getCaseSensitive(self):
"""
Get whether to do a case sensitive comparison over the stop words.
"""
return self.getOrDefault(self.caseSensitive)
@inherit_doc
@ignore_unicode_prefix
class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
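
Since both setters above return self, custom stop words and case sensitivity can be chained; a short sketch with hypothetical values, reusing the Row/sqlContext setup from the sketch near the top:

    # Replace the default list and make matching case sensitive.
    remover = StopWordsRemover(inputCol="raw", outputCol="filtered")
    remover.setStopWords(["PANDA", "bamboo"]).setCaseSensitive(True)
    remover.getStopWords()      # ['PANDA', 'bamboo']
    remover.getCaseSensitive()  # True

    df = sqlContext.createDataFrame([Row(raw=["panda", "PANDA", "bamboo"])])
    remover.transform(df).head().filtered
    # ['panda'] -- only exact-case matches are removed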

python/pyspark/ml/tests.py

@@ -31,7 +31,7 @@ else:
import unittest
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
-from pyspark.sql import DataFrame, SQLContext
+from pyspark.sql import DataFrame, SQLContext, Row
from pyspark.sql.functions import rand
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.param import Param, Params
@@ -258,7 +258,7 @@ class FeatureTests(PySparkTestCase):
def test_ngram(self):
sqlContext = SQLContext(self.sc)
dataset = sqlContext.createDataFrame([
([["a", "b", "c", "d", "e"]])], ["input"])
Row(input=["a", "b", "c", "d", "e"])])
ngram0 = NGram(n=4, inputCol="input", outputCol="output")
self.assertEqual(ngram0.getN(), 4)
self.assertEqual(ngram0.getInputCol(), "input")
@@ -266,6 +266,22 @@
transformedDF = ngram0.transform(dataset)
self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"])
def test_stopwordsremover(self):
sqlContext = SQLContext(self.sc)
dataset = sqlContext.createDataFrame([Row(input=["a", "panda"])])
stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output")
# Default
self.assertEquals(stopWordRemover.getInputCol(), "input")
transformedDF = stopWordRemover.transform(dataset)
self.assertEquals(transformedDF.head().output, ["panda"])
# Custom
stopwords = ["panda"]
stopWordRemover.setStopWords(stopwords)
self.assertEquals(stopWordRemover.getInputCol(), "input")
self.assertEquals(stopWordRemover.getStopWords(), stopwords)
transformedDF = stopWordRemover.transform(dataset)
self.assertEquals(transformedDF.head().output, ["a"])
class HasInducedError(Params):