[SPARK-9679] [ML] [PYSPARK] Add Python API for Stop Words Remover
Add a Python API for the Stop Words Remover.

Author: Holden Karau <holden@pigscanfly.ca>

Closes #8118 from holdenk/SPARK-9679-python-StopWordsRemover.
commit e6e483cc4d (parent 391e6be0ae)
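For context (not part of the diff itself), a minimal usage sketch of the new Python API, mirroring the test added below; it assumes an existing SparkContext `sc`:

    from pyspark.sql import Row, SQLContext
    from pyspark.ml.feature import StopWordsRemover

    sqlContext = SQLContext(sc)
    dataset = sqlContext.createDataFrame([Row(input=["a", "panda"])])

    # Default: the shared English stop word list, case-insensitive matching
    remover = StopWordsRemover(inputCol="input", outputCol="output")
    print(remover.transform(dataset).head().output)  # ['panda']

    # Custom: override the stop word list
    remover.setStopWords(["panda"])
    print(remover.transform(dataset).head().output)  # ['a']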
@@ -29,14 +29,14 @@ import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
 /**
  * stop words list
  */
-private object StopWords {
+private[spark] object StopWords {

   /**
    * Use the same default stopwords list as scikit-learn.
    * The original list can be found from "Glasgow Information Retrieval Group"
    * [[http://ir.dcs.gla.ac.uk/resources/linguistic_utils/stop_words]]
    */
-  val EnglishStopWords = Array( "a", "about", "above", "across", "after", "afterwards", "again",
+  val English = Array( "a", "about", "above", "across", "after", "afterwards", "again",
     "against", "all", "almost", "alone", "along", "already", "also", "although", "always",
     "am", "among", "amongst", "amoungst", "amount", "an", "and", "another",
     "any", "anyhow", "anyone", "anything", "anyway", "anywhere", "are",
@@ -121,7 +121,7 @@ class StopWordsRemover(override val uid: String)
   /** @group getParam */
   def getCaseSensitive: Boolean = $(caseSensitive)

-  setDefault(stopWords -> StopWords.EnglishStopWords, caseSensitive -> false)
+  setDefault(stopWords -> StopWords.English, caseSensitive -> false)

   override def transform(dataset: DataFrame): DataFrame = {
     val outputSchema = transformSchema(dataset.schema)
@@ -65,7 +65,7 @@ class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext {
   }

   test("StopWordsRemover with additional words") {
-    val stopWords = StopWords.EnglishStopWords ++ Array("python", "scala")
+    val stopWords = StopWords.English ++ Array("python", "scala")
     val remover = new StopWordsRemover()
       .setInputCol("raw")
       .setOutputCol("filtered")
@@ -22,7 +22,7 @@ if sys.version > '3':
 from pyspark.rdd import ignore_unicode_prefix
 from pyspark.ml.param.shared import *
 from pyspark.ml.util import keyword_only
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm
 from pyspark.mllib.common import inherit_doc
 from pyspark.mllib.linalg import _convert_to_vector
@@ -30,7 +30,7 @@ __all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF',
            'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer',
            'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer',
            'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec',
-           'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel']
+           'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel', 'StopWordsRemover']


 @inherit_doc
@@ -933,6 +933,75 @@ class StringIndexerModel(JavaModel):
     """


+class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol):
+    """
+    .. note:: Experimental
+
+    A feature transformer that filters out stop words from input.
+    Note: null values from the input array are preserved unless adding null to stopWords
+    explicitly.
+    """
+    # a placeholder to make the stopwords show up in generated doc
+    stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out")
+    caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " +
+                          "comparison over the stop words")
+
+    @keyword_only
+    def __init__(self, inputCol=None, outputCol=None, stopWords=None,
+                 caseSensitive=False):
+        """
+        __init__(self, inputCol=None, outputCol=None, stopWords=None,\
+                 caseSensitive=False)
+        """
+        super(StopWordsRemover, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
+                                            self.uid)
+        self.stopWords = Param(self, "stopWords", "The words to be filtered out")
+        self.caseSensitive = Param(self, "caseSensitive", "whether to do a case " +
+                                   "sensitive comparison over the stop words")
+        stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords
+        defaultStopWords = stopWordsObj.English()
+        self._setDefault(stopWords=defaultStopWords)
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    def setParams(self, inputCol=None, outputCol=None, stopWords=None,
+                  caseSensitive=False):
+        """
+        setParams(self, inputCol=None, outputCol=None, stopWords=None,\
+                  caseSensitive=False)
+        Sets params for this StopWordsRemover.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
+    def setStopWords(self, value):
+        """
+        Specify the stopwords to be filtered.
+        """
+        self._paramMap[self.stopWords] = value
+        return self
+
+    def getStopWords(self):
+        """
+        Get the stopwords.
+        """
+        return self.getOrDefault(self.stopWords)
+
+    def setCaseSensitive(self, value):
+        """
+        Set whether to do a case sensitive comparison over the stop words.
+        """
+        self._paramMap[self.caseSensitive] = value
+        return self
+
+    def getCaseSensitive(self):
+        """
+        Get whether to do a case sensitive comparison over the stop words.
+        """
+        return self.getOrDefault(self.caseSensitive)
+
+
 @inherit_doc
 @ignore_unicode_prefix
 class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
@@ -31,7 +31,7 @@ else:
 import unittest

 from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
-from pyspark.sql import DataFrame, SQLContext
+from pyspark.sql import DataFrame, SQLContext, Row
 from pyspark.sql.functions import rand
 from pyspark.ml.evaluation import RegressionEvaluator
 from pyspark.ml.param import Param, Params
@@ -258,7 +258,7 @@ class FeatureTests(PySparkTestCase):
     def test_ngram(self):
         sqlContext = SQLContext(self.sc)
         dataset = sqlContext.createDataFrame([
-            ([["a", "b", "c", "d", "e"]])], ["input"])
+            Row(input=["a", "b", "c", "d", "e"])])
         ngram0 = NGram(n=4, inputCol="input", outputCol="output")
         self.assertEqual(ngram0.getN(), 4)
         self.assertEqual(ngram0.getInputCol(), "input")
@@ -266,6 +266,22 @@ class FeatureTests(PySparkTestCase):
         transformedDF = ngram0.transform(dataset)
         self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"])

+    def test_stopwordsremover(self):
+        sqlContext = SQLContext(self.sc)
+        dataset = sqlContext.createDataFrame([Row(input=["a", "panda"])])
+        stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output")
+        # Default
+        self.assertEquals(stopWordRemover.getInputCol(), "input")
+        transformedDF = stopWordRemover.transform(dataset)
+        self.assertEquals(transformedDF.head().output, ["panda"])
+        # Custom
+        stopwords = ["panda"]
+        stopWordRemover.setStopWords(stopwords)
+        self.assertEquals(stopWordRemover.getInputCol(), "input")
+        self.assertEquals(stopWordRemover.getStopWords(), stopwords)
+        transformedDF = stopWordRemover.transform(dataset)
+        self.assertEquals(transformedDF.head().output, ["a"])
+
+
 class HasInducedError(Params):
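Design note: the Scala StopWords object is widened from private to private[spark] (and EnglishStopWords renamed to English) so the Python wrapper can reuse the same default list through the Py4J gateway rather than duplicating it. A minimal sketch of that lookup in isolation (assumes an active SparkContext, since _jvm() needs a live gateway):

    from pyspark.ml.wrapper import _jvm

    # Reach the JVM-side object that holds the shared default list.
    stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords
    # English is a Scala val, exposed by Py4J as a no-argument call;
    # list() materializes the returned Java array as a Python list.
    english = list(stopWordsObj.English())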