[SPARK-10393] use ML pipeline in LDA example
JIRA: https://issues.apache.org/jira/browse/SPARK-10393

Since the logic of the text-processing step has been moved into ML estimators/transformers, replace the related code in the LDA example with an ML pipeline.

Author: Yuhao Yang <hhbyyh@gmail.com>
Author: yuhaoyang <yuhao@zhanglipings-iMac.local>

Closes #8551 from hhbyyh/ldaExUpdate.
parent 5d96a710a5
commit 872a2ee281
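Before reading the diff, here is a minimal, self-contained sketch of what the new pipeline-based preprocessing boils down to. The tiny in-memory corpus, the local master, the app name, and the vocabulary size are illustrative assumptions; the actual example wires the same three stages into its preprocess() helper and reads documents from the input paths.

    // Sketch only: a RegexTokenizer -> StopWordsRemover -> CountVectorizer pipeline,
    // mirroring the stages this commit introduces (Spark 1.5-era APIs).
    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover}
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.{SparkConf, SparkContext}

    object PipelineSketch {
      def main(args: Array[String]): Unit = {
        // Local setup for illustration; the real example builds sc from its own config.
        val sc = new SparkContext(new SparkConf().setAppName("PipelineSketch").setMaster("local[2]"))
        val sqlContext = SQLContext.getOrCreate(sc)
        import sqlContext.implicits._

        // One document per line, as in the example's input format.
        val df = sc.parallelize(Seq(
          "spark is a fast cluster computing engine",
          "latent dirichlet allocation models topics in a corpus")).toDF("docs")

        // Tokenize, drop stop words, then build term-count vectors.
        val tokenizer = new RegexTokenizer().setInputCol("docs").setOutputCol("rawTokens")
        val stopWordsRemover = new StopWordsRemover().setInputCol("rawTokens").setOutputCol("tokens")
        val countVectorizer = new CountVectorizer().setVocabSize(1000).setInputCol("tokens").setOutputCol("features")

        val model = new Pipeline().setStages(Array(tokenizer, stopWordsRemover, countVectorizer)).fit(df)
        // Stage 2 is the fitted CountVectorizerModel; it carries the learned vocabulary.
        val vocab = model.stages(2).asInstanceOf[CountVectorizerModel].vocabulary
        println(vocab.mkString(", "))
        sc.stop()
      }
    }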
examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala

@@ -18,19 +18,16 @@
 // scalastyle:off println
 package org.apache.spark.examples.mllib
 
-import java.text.BreakIterator
-
 import scala.collection.mutable
 
 import scopt.OptionParser
 
 import org.apache.log4j.{Level, Logger}
 
-import org.apache.spark.{SparkContext, SparkConf}
-import org.apache.spark.mllib.clustering.{EMLDAOptimizer, OnlineLDAOptimizer, DistributedLDAModel, LDA}
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.ml.Pipeline
+import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover}
+import org.apache.spark.mllib.clustering.{DistributedLDAModel, EMLDAOptimizer, LDA, OnlineLDAOptimizer}
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.{SparkConf, SparkContext}
 
 /**
  * An example Latent Dirichlet Allocation (LDA) app. Run with
@@ -192,115 +189,45 @@ object LDAExample {
       vocabSize: Int,
       stopwordFile: String): (RDD[(Long, Vector)], Array[String], Long) = {
 
+    val sqlContext = SQLContext.getOrCreate(sc)
+    import sqlContext.implicits._
+
     // Get dataset of document texts
     // One document per line in each text file. If the input consists of many small files,
     // this can result in a large number of small partitions, which can degrade performance.
     // In this case, consider using coalesce() to create fewer, larger partitions.
-    val textRDD: RDD[String] = sc.textFile(paths.mkString(","))
-
-    // Split text into words
-    val tokenizer = new SimpleTokenizer(sc, stopwordFile)
-    val tokenized: RDD[(Long, IndexedSeq[String])] = textRDD.zipWithIndex().map { case (text, id) =>
-      id -> tokenizer.getWords(text)
+    val df = sc.textFile(paths.mkString(",")).toDF("docs")
+    val customizedStopWords: Array[String] = if (stopwordFile.isEmpty) {
+      Array.empty[String]
+    } else {
+      val stopWordText = sc.textFile(stopwordFile).collect()
+      stopWordText.flatMap(_.stripMargin.split("\\s+"))
     }
-    tokenized.cache()
+    val tokenizer = new RegexTokenizer()
+      .setInputCol("docs")
+      .setOutputCol("rawTokens")
+    val stopWordsRemover = new StopWordsRemover()
+      .setInputCol("rawTokens")
+      .setOutputCol("tokens")
+    stopWordsRemover.setStopWords(stopWordsRemover.getStopWords ++ customizedStopWords)
+    val countVectorizer = new CountVectorizer()
+      .setVocabSize(vocabSize)
+      .setInputCol("tokens")
+      .setOutputCol("features")
 
-    // Counts words: RDD[(word, wordCount)]
-    val wordCounts: RDD[(String, Long)] = tokenized
-      .flatMap { case (_, tokens) => tokens.map(_ -> 1L) }
-      .reduceByKey(_ + _)
-    wordCounts.cache()
-    val fullVocabSize = wordCounts.count()
-    // Select vocab
-    //   (vocab: Map[word -> id], total tokens after selecting vocab)
-    val (vocab: Map[String, Int], selectedTokenCount: Long) = {
-      val tmpSortedWC: Array[(String, Long)] = if (vocabSize == -1 || fullVocabSize <= vocabSize) {
-        // Use all terms
-        wordCounts.collect().sortBy(-_._2)
-      } else {
-        // Sort terms to select vocab
-        wordCounts.sortBy(_._2, ascending = false).take(vocabSize)
-      }
-      (tmpSortedWC.map(_._1).zipWithIndex.toMap, tmpSortedWC.map(_._2).sum)
-    }
+    val pipeline = new Pipeline()
+      .setStages(Array(tokenizer, stopWordsRemover, countVectorizer))
 
-    val documents = tokenized.map { case (id, tokens) =>
-      // Filter tokens by vocabulary, and create word count vector representation of document.
-      val wc = new mutable.HashMap[Int, Int]()
-      tokens.foreach { term =>
-        if (vocab.contains(term)) {
-          val termIndex = vocab(term)
-          wc(termIndex) = wc.getOrElse(termIndex, 0) + 1
-        }
-      }
-      val indices = wc.keys.toArray.sorted
-      val values = indices.map(i => wc(i).toDouble)
+    val model = pipeline.fit(df)
+    val documents = model.transform(df)
+      .select("features")
+      .map { case Row(features: Vector) => features }
+      .zipWithIndex()
+      .map(_.swap)
 
-      val sb = Vectors.sparse(vocab.size, indices, values)
-      (id, sb)
-    }
-
-    val vocabArray = new Array[String](vocab.size)
-    vocab.foreach { case (term, i) => vocabArray(i) = term }
-
-    (documents, vocabArray, selectedTokenCount)
+    (documents,
+      model.stages(2).asInstanceOf[CountVectorizerModel].vocabulary, // vocabulary
+      documents.map(_._2.numActives).sum().toLong) // total token count
   }
 }
-
-/**
- * Simple Tokenizer.
- *
- * TODO: Formalize the interface, and make this a public class in mllib.feature
- */
-private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Serializable {
-
-  private val stopwords: Set[String] = if (stopwordFile.isEmpty) {
-    Set.empty[String]
-  } else {
-    val stopwordText = sc.textFile(stopwordFile).collect()
-    stopwordText.flatMap(_.stripMargin.split("\\s+")).toSet
-  }
-
-  // Matches sequences of Unicode letters
-  private val allWordRegex = "^(\\p{L}*)$".r
-
-  // Ignore words shorter than this length.
-  private val minWordLength = 3
-
-  def getWords(text: String): IndexedSeq[String] = {
-
-    val words = new mutable.ArrayBuffer[String]()
-
-    // Use Java BreakIterator to tokenize text into words.
-    val wb = BreakIterator.getWordInstance
-    wb.setText(text)
-
-    // current,end index start,end of each word
-    var current = wb.first()
-    var end = wb.next()
-    while (end != BreakIterator.DONE) {
-      // Convert to lowercase
-      val word: String = text.substring(current, end).toLowerCase
-      // Remove short words and strings that aren't only letters
-      word match {
-        case allWordRegex(w) if w.length >= minWordLength && !stopwords.contains(w) =>
-          words += w
-        case _ =>
-      }
-
-      current = end
-      try {
-        end = wb.next()
-      } catch {
-        case e: Exception =>
-          // Ignore remaining text in line.
-          // This is a known bug in BreakIterator (for some Java versions),
-          // which fails when it sees certain characters.
-          end = BreakIterator.DONE
-      }
-    }
-    words
-  }
-
-}
 // scalastyle:on println
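For context on what happens after this preprocessing: preprocess() hands the corpus (RDD[(Long, Vector)]) and the vocabulary to spark.mllib's LDA in the example's run() method. A hedged sketch of that downstream step, with illustrative values for k, the optimizer, and the iteration count rather than the example's actual defaults:

    // Sketch only: train mllib LDA on the preprocessed corpus and print topics.
    import org.apache.spark.mllib.clustering.LDA
    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.rdd.RDD

    def trainAndDescribe(corpus: RDD[(Long, Vector)], vocab: Array[String]): Unit = {
      corpus.cache() // LDA makes several passes over the corpus
      val ldaModel = new LDA()
        .setK(10)               // number of topics (illustrative)
        .setOptimizer("online") // the example picks "em" or "online" via a flag
        .setMaxIterations(20)
        .run(corpus)
      // Map each topic's top term indices back to words via the pipeline's vocabulary.
      ldaModel.describeTopics(maxTermsPerTopic = 5).zipWithIndex.foreach {
        case ((termIndices, termWeights), topic) =>
          val terms = termIndices.zip(termWeights)
            .map { case (i, w) => f"${vocab(i)} ($w%.3f)" }
          println(s"Topic $topic: " + terms.mkString(", "))
      }
    }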