[SPARK-8455] [ML] Implement n-gram feature transformer
Implementation of n-gram feature transformer for ML. Author: Feynman Liang <fliang@databricks.com> Closes #6887 from feynmanliang/ngram-featurizer and squashes the following commits: d2c839f [Feynman Liang] Make n > input length yield empty output 9fadd36 [Feynman Liang] Add empty and corner test cases, fix names and spaces fe93873 [Feynman Liang] Implement n-gram feature transformer
This commit is contained in:
parent
5ab9fcfb01
commit
afe35f0519
69
mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
Normal file
69
mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
Normal file
|
@ -0,0 +1,69 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.feature
|
||||
|
||||
import org.apache.spark.annotation.Experimental
|
||||
import org.apache.spark.ml.UnaryTransformer
|
||||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.ml.util.Identifiable
|
||||
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
|
||||
|
||||
/**
 * :: Experimental ::
 * A feature transformer that converts the input array of strings into an array of n-grams. Null
 * values in the input array are ignored.
 * It returns an array of n-grams where each n-gram is represented by a space-separated string of
 * words.
 *
 * When the input is empty, an empty array is returned.
 * When the input array length is less than n (number of elements per n-gram), no n-grams are
 * returned.
 */
@Experimental
class NGram(override val uid: String)
  extends UnaryTransformer[Seq[String], Seq[String], NGram] {

  def this() = this(Identifiable.randomUID("ngram"))

  /**
   * Number of elements in each n-gram, >= 1.
   * Default: 2 (bigram features).
   * @group param
   */
  val n: IntParam = new IntParam(this, "n", "number elements per n-gram (>=1)",
    ParamValidators.gtEq(1))

  setDefault(n -> 2)

  /** @group setParam */
  def setN(value: Int): this.type = set(n, value)

  /** @group getParam */
  def getN: Int = $(n)

  override protected def createTransformFunc: Seq[String] => Seq[String] = { tokens =>
    // Slide a window of width n over the tokens; withPartial(false) drops any
    // trailing window shorter than n, so inputs with fewer than n tokens yield
    // an empty result rather than a partial n-gram.
    tokens.iterator.sliding($(n)).withPartial(false).map(_.mkString(" ")).toSeq
  }

  override protected def validateInputType(inputType: DataType): Unit = {
    val expected = ArrayType(StringType)
    require(inputType.sameType(expected),
      s"Input type must be ArrayType(StringType) but got $inputType.")
  }

  // containsNull = false: every produced n-gram is a non-null, space-joined string.
  override protected def outputDataType: DataType = ArrayType(StringType, containsNull = false)
}
|
|
@ -0,0 +1,94 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.feature
|
||||
|
||||
import scala.beans.BeanInfo
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.sql.{DataFrame, Row}
|
||||
|
||||
// Test fixture pairing an input token array with the n-grams expected after
// transformation. @BeanInfo generates JavaBean-style getters/setters so that
// Spark SQL's reflection can derive a DataFrame schema from this case class.
@BeanInfo
case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String])
|
||||
|
||||
class NGramSuite extends SparkFunSuite with MLlibTestSparkContext {
  import org.apache.spark.ml.feature.NGramSuite._

  test("default behavior yields bigram features") {
    // No explicit n: the transformer should default to n = 2 (bigrams).
    val transformer = new NGram()
      .setOutputCol("nGrams")
      .setInputCol("inputTokens")
    val df = sqlContext.createDataFrame(Seq(
      NGramTestData(
        Array("Test", "for", "ngram", "."),
        Array("Test for", "for ngram", "ngram ."))))
    testNGram(transformer, df)
  }

  test("NGramLength=4 yields length 4 n-grams") {
    val transformer = new NGram()
      .setOutputCol("nGrams")
      .setInputCol("inputTokens")
      .setN(4)
    val df = sqlContext.createDataFrame(Seq(
      NGramTestData(
        Array("a", "b", "c", "d", "e"),
        Array("a b c d", "b c d e"))))
    testNGram(transformer, df)
  }

  test("empty input yields empty output") {
    val transformer = new NGram()
      .setOutputCol("nGrams")
      .setInputCol("inputTokens")
      .setN(4)
    val df = sqlContext.createDataFrame(Seq(
      NGramTestData(
        Array(),
        Array())))
    testNGram(transformer, df)
  }

  test("input array < n yields empty output") {
    // 5 tokens with n = 6: no full window fits, so no n-grams are produced.
    val transformer = new NGram()
      .setOutputCol("nGrams")
      .setInputCol("inputTokens")
      .setN(6)
    val df = sqlContext.createDataFrame(Seq(
      NGramTestData(
        Array("a", "b", "c", "d", "e"),
        Array())))
    testNGram(transformer, df)
  }
}
|
||||
|
||||
object NGramSuite extends SparkFunSuite {

  /**
   * Applies transformer `t` to `dataset` and asserts that, for every row, the
   * produced "nGrams" column equals the expected "wantedNGrams" column.
   */
  def testNGram(t: NGram, dataset: DataFrame): Unit = {
    val rows = t.transform(dataset)
      .select("nGrams", "wantedNGrams")
      .collect()
    for (Row(actual, expected) <- rows) {
      assert(actual === expected)
    }
  }
}
|
Loading…
Reference in a new issue