[SPARK-6615][MLLIB] Python API for Word2Vec

This is the sub-task of SPARK-6254.
Wrap missing method for `Word2Vec` and `Word2VecModel`.

Author: lewuathe <lewuathe@me.com>

Closes #5296 from Lewuathe/SPARK-6615 and squashes the following commits:

f14c304 [lewuathe] Reorder tests
1d326b9 [lewuathe] Merge master
e2bedfb [lewuathe] Modify test cases
afb866d [lewuathe] [SPARK-6615] Python API for Word2Vec
This commit is contained in:
lewuathe 2015-04-03 09:49:50 -07:00 committed by Xiangrui Meng
parent b52c7f9fc8
commit 512a2f191a
3 changed files with 63 additions and 6 deletions

View file

@ -476,13 +476,15 @@ private[python] class PythonMLLibAPI extends Serializable {
learningRate: Double,
numPartitions: Int,
numIterations: Int,
seed: Long): Word2VecModelWrapper = {
seed: Long,
minCount: Int): Word2VecModelWrapper = {
val word2vec = new Word2Vec()
.setVectorSize(vectorSize)
.setLearningRate(learningRate)
.setNumPartitions(numPartitions)
.setNumIterations(numIterations)
.setSeed(seed)
.setMinCount(minCount)
try {
val model = word2vec.fit(dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER))
new Word2VecModelWrapper(model)
@ -516,6 +518,10 @@ private[python] class PythonMLLibAPI extends Serializable {
val words = result.map(_._1)
List(words, similarity).map(_.asInstanceOf[Object]).asJava
}
def getVectors: JMap[String, JList[Float]] = {
model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava
}
}
/**

View file

@ -337,6 +337,12 @@ class Word2VecModel(JavaVectorTransformer):
words, similarity = self.call("findSynonyms", word, num)
return zip(words, similarity)
def getVectors(self):
"""
Returns a map of words to their vector representations.
"""
return self.call("getVectors")
class Word2Vec(object):
"""
@ -379,6 +385,7 @@ class Word2Vec(object):
self.numPartitions = 1
self.numIterations = 1
self.seed = random.randint(0, sys.maxint)
self.minCount = 5
def setVectorSize(self, vectorSize):
"""
@ -417,6 +424,14 @@ class Word2Vec(object):
self.seed = seed
return self
def setMinCount(self, minCount):
"""
Sets minCount, the minimum number of times a token must appear
to be included in the word2vec model's vocabulary (default: 5).
"""
self.minCount = minCount
return self
def fit(self, data):
"""
Computes the vector representation of each word in vocabulary.
@ -428,7 +443,8 @@ class Word2Vec(object):
raise TypeError("data should be an RDD of list of string")
jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize),
float(self.learningRate), int(self.numPartitions),
int(self.numIterations), long(self.seed))
int(self.numIterations), long(self.seed),
int(self.minCount))
return Word2VecModel(jmodel)

View file

@ -42,6 +42,7 @@ from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
from pyspark.mllib.feature import Word2Vec
from pyspark.mllib.feature import IDF
from pyspark.serializers import PickleSerializer
from pyspark.sql import SQLContext
@ -630,6 +631,12 @@ class ChiSqTestTests(PySparkTestCase):
self.assertIsNotNone(chi[1000])
class SerDeTest(PySparkTestCase):
def test_to_java_object_rdd(self): # SPARK-6660
data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L)
self.assertEqual(_to_java_object_rdd(data).count(), 10)
class FeatureTest(PySparkTestCase):
def test_idf_model(self):
data = [
@ -643,11 +650,39 @@ class FeatureTest(PySparkTestCase):
self.assertEqual(len(idf), 11)
class SerDeTest(PySparkTestCase):
def test_to_java_object_rdd(self): # SPARK-6660
data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L)
self.assertEqual(_to_java_object_rdd(data).count(), 10)
class Word2VecTests(PySparkTestCase):
def test_word2vec_setters(self):
data = [
["I", "have", "a", "pen"],
["I", "like", "soccer", "very", "much"],
["I", "live", "in", "Tokyo"]
]
model = Word2Vec() \
.setVectorSize(2) \
.setLearningRate(0.01) \
.setNumPartitions(2) \
.setNumIterations(10) \
.setSeed(1024) \
.setMinCount(3)
self.assertEquals(model.vectorSize, 2)
self.assertTrue(model.learningRate < 0.02)
self.assertEquals(model.numPartitions, 2)
self.assertEquals(model.numIterations, 10)
self.assertEquals(model.seed, 1024)
self.assertEquals(model.minCount, 3)
def test_word2vec_get_vectors(self):
data = [
["a", "b", "c", "d", "e", "f", "g"],
["a", "b", "c", "d", "e", "f"],
["a", "b", "c", "d", "e"],
["a", "b", "c", "d"],
["a", "b", "c"],
["a", "b"],
["a"]
]
model = Word2Vec().fit(self.sc.parallelize(data))
self.assertEquals(len(model.getVectors()), 3)
if __name__ == "__main__":
if not _have_scipy: