[SPARK-10393] use ML pipeline in LDA example

jira: https://issues.apache.org/jira/browse/SPARK-10393

Since the text-processing logic now lives in ML estimators/transformers (RegexTokenizer, StopWordsRemover, CountVectorizer), replace the corresponding code in the LDA example with an ML Pipeline.

Author: Yuhao Yang <hhbyyh@gmail.com>
Author: yuhaoyang <yuhao@zhanglipings-iMac.local>

Closes #8551 from hhbyyh/ldaExUpdate.
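For context, here is a minimal, self-contained sketch of the pipeline-based preprocessing this change adopts (RegexTokenizer -> StopWordsRemover -> CountVectorizer). The object name, input path, and vocabulary size below are illustrative placeholders, not part of the commit; the actual example code is in the diff that follows.

// Sketch only: tokenize documents, drop stop words, and build term-count vectors
// with a spark.ml Pipeline, mirroring the approach taken in the updated example.
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover}
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}

object PipelinePreprocessSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("PipelinePreprocessSketch"))
    val sqlContext = SQLContext.getOrCreate(sc)
    import sqlContext.implicits._

    // One document per line; "docs/*.txt" is a placeholder input path.
    val df = sc.textFile("docs/*.txt").toDF("docs")

    val tokenizer = new RegexTokenizer().setInputCol("docs").setOutputCol("rawTokens")
    val remover = new StopWordsRemover().setInputCol("rawTokens").setOutputCol("tokens")
    val vectorizer = new CountVectorizer()
      .setVocabSize(10000) // illustrative vocabulary size
      .setInputCol("tokens")
      .setOutputCol("features")

    // Fit the three stages as one Pipeline; stage 2 holds the learned vocabulary.
    val model = new Pipeline().setStages(Array(tokenizer, remover, vectorizer)).fit(df)
    val vocabulary: Array[String] = model.stages(2).asInstanceOf[CountVectorizerModel].vocabulary
    println(s"Vocabulary size: ${vocabulary.length}")
    model.transform(df).select("features").show(5)

    sc.stop()
  }
}

The sparse "features" column can then be zipped with document indices and fed to mllib's LDA, which is exactly what the updated preprocess() in the diff below does.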
Authored by Yuhao Yang on 2015-12-08 10:29:51 -08:00; committed by Joseph K. Bradley
parent 5d96a710a5
commit 872a2ee281


@@ -18,19 +18,16 @@
 // scalastyle:off println
 package org.apache.spark.examples.mllib
-import java.text.BreakIterator
-import scala.collection.mutable
 import scopt.OptionParser
 import org.apache.log4j.{Level, Logger}
-import org.apache.spark.{SparkContext, SparkConf}
-import org.apache.spark.mllib.clustering.{EMLDAOptimizer, OnlineLDAOptimizer, DistributedLDAModel, LDA}
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.ml.Pipeline
+import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover}
+import org.apache.spark.mllib.clustering.{DistributedLDAModel, EMLDAOptimizer, LDA, OnlineLDAOptimizer}
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.{SparkConf, SparkContext}
 /**
  * An example Latent Dirichlet Allocation (LDA) app. Run with
@@ -192,115 +189,45 @@ object LDAExample {
       vocabSize: Int,
       stopwordFile: String): (RDD[(Long, Vector)], Array[String], Long) = {
+    val sqlContext = SQLContext.getOrCreate(sc)
+    import sqlContext.implicits._
     // Get dataset of document texts
     // One document per line in each text file. If the input consists of many small files,
     // this can result in a large number of small partitions, which can degrade performance.
     // In this case, consider using coalesce() to create fewer, larger partitions.
-    val textRDD: RDD[String] = sc.textFile(paths.mkString(","))
-    // Split text into words
-    val tokenizer = new SimpleTokenizer(sc, stopwordFile)
-    val tokenized: RDD[(Long, IndexedSeq[String])] = textRDD.zipWithIndex().map { case (text, id) =>
-      id -> tokenizer.getWords(text)
+    val df = sc.textFile(paths.mkString(",")).toDF("docs")
+    val customizedStopWords: Array[String] = if (stopwordFile.isEmpty) {
+      Array.empty[String]
+    } else {
+      val stopWordText = sc.textFile(stopwordFile).collect()
+      stopWordText.flatMap(_.stripMargin.split("\\s+"))
     }
-    tokenized.cache()
+    val tokenizer = new RegexTokenizer()
+      .setInputCol("docs")
+      .setOutputCol("rawTokens")
+    val stopWordsRemover = new StopWordsRemover()
+      .setInputCol("rawTokens")
+      .setOutputCol("tokens")
+    stopWordsRemover.setStopWords(stopWordsRemover.getStopWords ++ customizedStopWords)
+    val countVectorizer = new CountVectorizer()
+      .setVocabSize(vocabSize)
+      .setInputCol("tokens")
+      .setOutputCol("features")
-    // Counts words: RDD[(word, wordCount)]
-    val wordCounts: RDD[(String, Long)] = tokenized
-      .flatMap { case (_, tokens) => tokens.map(_ -> 1L) }
-      .reduceByKey(_ + _)
-    wordCounts.cache()
-    val fullVocabSize = wordCounts.count()
-    // Select vocab
-    //   (vocab: Map[word -> id], total tokens after selecting vocab)
-    val (vocab: Map[String, Int], selectedTokenCount: Long) = {
-      val tmpSortedWC: Array[(String, Long)] = if (vocabSize == -1 || fullVocabSize <= vocabSize) {
-        // Use all terms
-        wordCounts.collect().sortBy(-_._2)
-      } else {
-        // Sort terms to select vocab
-        wordCounts.sortBy(_._2, ascending = false).take(vocabSize)
-      }
-      (tmpSortedWC.map(_._1).zipWithIndex.toMap, tmpSortedWC.map(_._2).sum)
-    }
+    val pipeline = new Pipeline()
+      .setStages(Array(tokenizer, stopWordsRemover, countVectorizer))
-    val documents = tokenized.map { case (id, tokens) =>
-      // Filter tokens by vocabulary, and create word count vector representation of document.
-      val wc = new mutable.HashMap[Int, Int]()
-      tokens.foreach { term =>
-        if (vocab.contains(term)) {
-          val termIndex = vocab(term)
-          wc(termIndex) = wc.getOrElse(termIndex, 0) + 1
-        }
-      }
-      val indices = wc.keys.toArray.sorted
-      val values = indices.map(i => wc(i).toDouble)
+    val model = pipeline.fit(df)
+    val documents = model.transform(df)
+      .select("features")
+      .map { case Row(features: Vector) => features }
+      .zipWithIndex()
+      .map(_.swap)
-      val sb = Vectors.sparse(vocab.size, indices, values)
-      (id, sb)
-    }
-    val vocabArray = new Array[String](vocab.size)
-    vocab.foreach { case (term, i) => vocabArray(i) = term }
-    (documents, vocabArray, selectedTokenCount)
+    (documents,
+      model.stages(2).asInstanceOf[CountVectorizerModel].vocabulary, // vocabulary
+      documents.map(_._2.numActives).sum().toLong) // total token count
   }
 }
-/**
- * Simple Tokenizer.
- *
- * TODO: Formalize the interface, and make this a public class in mllib.feature
- */
-private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Serializable {
-  private val stopwords: Set[String] = if (stopwordFile.isEmpty) {
-    Set.empty[String]
-  } else {
-    val stopwordText = sc.textFile(stopwordFile).collect()
-    stopwordText.flatMap(_.stripMargin.split("\\s+")).toSet
-  }
-  // Matches sequences of Unicode letters
-  private val allWordRegex = "^(\\p{L}*)$".r
-  // Ignore words shorter than this length.
-  private val minWordLength = 3
-  def getWords(text: String): IndexedSeq[String] = {
-    val words = new mutable.ArrayBuffer[String]()
-    // Use Java BreakIterator to tokenize text into words.
-    val wb = BreakIterator.getWordInstance
-    wb.setText(text)
-    // current,end index start,end of each word
-    var current = wb.first()
-    var end = wb.next()
-    while (end != BreakIterator.DONE) {
-      // Convert to lowercase
-      val word: String = text.substring(current, end).toLowerCase
-      // Remove short words and strings that aren't only letters
-      word match {
-        case allWordRegex(w) if w.length >= minWordLength && !stopwords.contains(w) =>
-          words += w
-        case _ =>
-      }
-      current = end
-      try {
-        end = wb.next()
-      } catch {
-        case e: Exception =>
-          // Ignore remaining text in line.
-          // This is a known bug in BreakIterator (for some Java versions),
-          // which fails when it sees certain characters.
-          end = BreakIterator.DONE
-      }
-    }
-    words
-  }
-}
 // scalastyle:on println