[SPARK-6615][MLLIB] Python API for Word2Vec
This is a sub-task of SPARK-6254: wrap the missing methods for `Word2Vec` and `Word2VecModel` in the Python API.

Author: lewuathe <lewuathe@me.com>

Closes #5296 from Lewuathe/SPARK-6615 and squashes the following commits:

f14c304 [lewuathe] Reorder tests
1d326b9 [lewuathe] Merge master
e2bedfb [lewuathe] Modify test cases
afb866d [lewuathe] [SPARK-6615] Python API for Word2Vec
commit 512a2f191a
parent b52c7f9fc8
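In PySpark terms, the patch adds a `minCount` parameter to `Word2Vec` and a `getVectors` method to `Word2VecModel`. A minimal usage sketch of the new API (not part of the diff; it assumes a live SparkContext `sc`, and the toy corpus and parameter values are illustrative):

    from pyspark.mllib.feature import Word2Vec

    corpus = sc.parallelize([
        ["I", "have", "a", "pen"],
        ["I", "like", "soccer", "very", "much"],
        ["I", "live", "in", "Tokyo"],
    ])

    # minCount controls vocabulary pruning; lower it for a toy corpus,
    # otherwise the default of 5 would drop every token.
    model = Word2Vec().setVectorSize(10).setSeed(42).setMinCount(1).fit(corpus)

    vectors = model.getVectors()   # map-like object: word -> vector of floats
    print(len(vectors))            # one entry per vocabulary word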
@@ -476,13 +476,15 @@ private[python] class PythonMLLibAPI extends Serializable {
       learningRate: Double,
       numPartitions: Int,
       numIterations: Int,
-      seed: Long): Word2VecModelWrapper = {
+      seed: Long,
+      minCount: Int): Word2VecModelWrapper = {
     val word2vec = new Word2Vec()
       .setVectorSize(vectorSize)
       .setLearningRate(learningRate)
       .setNumPartitions(numPartitions)
       .setNumIterations(numIterations)
       .setSeed(seed)
+      .setMinCount(minCount)
     try {
       val model = word2vec.fit(dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER))
       new Word2VecModelWrapper(model)
@@ -516,6 +518,10 @@ private[python] class PythonMLLibAPI extends Serializable {
       val words = result.map(_._1)
       List(words, similarity).map(_.asInstanceOf[Object]).asJava
     }
+
+    def getVectors: JMap[String, JList[Float]] = {
+      model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava
+    }
   }

   /**
@@ -337,6 +337,12 @@ class Word2VecModel(JavaVectorTransformer):
         words, similarity = self.call("findSynonyms", word, num)
         return zip(words, similarity)

+    def getVectors(self):
+        """
+        Returns a map of words to their vector representations.
+        """
+        return self.call("getVectors")
+

 class Word2Vec(object):
     """
@@ -379,6 +385,7 @@ class Word2Vec(object):
         self.numPartitions = 1
         self.numIterations = 1
         self.seed = random.randint(0, sys.maxint)
+        self.minCount = 5

     def setVectorSize(self, vectorSize):
         """
@@ -417,6 +424,14 @@ class Word2Vec(object):
         self.seed = seed
         return self

+    def setMinCount(self, minCount):
+        """
+        Sets minCount, the minimum number of times a token must appear
+        to be included in the word2vec model's vocabulary (default: 5).
+        """
+        self.minCount = minCount
+        return self
+
     def fit(self, data):
         """
         Computes the vector representation of each word in vocabulary.
@@ -428,7 +443,8 @@ class Word2Vec(object):
             raise TypeError("data should be an RDD of list of string")
         jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize),
                                float(self.learningRate), int(self.numPartitions),
-                               int(self.numIterations), long(self.seed))
+                               int(self.numIterations), long(self.seed),
+                               int(self.minCount))
         return Word2VecModel(jmodel)

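On the Python side, `fit` now forwards `minCount` as an extra trailing argument of `trainWord2Vec`, keeping the Python default of 5 in sync with the Scala builder above. A hedged sketch of the resulting setter chain (parameter values and comments are illustrative, based on the MLlib defaults):

    from pyspark.mllib.feature import Word2Vec

    w2v = (Word2Vec()
           .setVectorSize(100)       # dimensionality of the word vectors
           .setLearningRate(0.025)   # initial learning rate
           .setNumPartitions(1)      # keep small for accuracy
           .setNumIterations(1)      # typically <= numPartitions
           .setSeed(1024)
           .setMinCount(5))          # new in this patch: vocabulary pruning threshold
    # w2v.fit(rdd_of_token_lists) passes the RDD plus these six values
    # positionally to the JVM's trainWord2Vec shown earlier.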
@@ -42,6 +42,7 @@ from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _
 from pyspark.mllib.regression import LabeledPoint
 from pyspark.mllib.random import RandomRDDs
 from pyspark.mllib.stat import Statistics
+from pyspark.mllib.feature import Word2Vec
 from pyspark.mllib.feature import IDF
 from pyspark.serializers import PickleSerializer
 from pyspark.sql import SQLContext
@@ -630,6 +631,12 @@ class ChiSqTestTests(PySparkTestCase):
         self.assertIsNotNone(chi[1000])


+class SerDeTest(PySparkTestCase):
+    def test_to_java_object_rdd(self):  # SPARK-6660
+        data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L)
+        self.assertEqual(_to_java_object_rdd(data).count(), 10)
+
+
 class FeatureTest(PySparkTestCase):
     def test_idf_model(self):
         data = [
@@ -643,11 +650,39 @@ class FeatureTest(PySparkTestCase):
         self.assertEqual(len(idf), 11)


-class SerDeTest(PySparkTestCase):
-    def test_to_java_object_rdd(self):  # SPARK-6660
-        data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L)
-        self.assertEqual(_to_java_object_rdd(data).count(), 10)
-
+class Word2VecTests(PySparkTestCase):
+    def test_word2vec_setters(self):
+        data = [
+            ["I", "have", "a", "pen"],
+            ["I", "like", "soccer", "very", "much"],
+            ["I", "live", "in", "Tokyo"]
+        ]
+        model = Word2Vec() \
+            .setVectorSize(2) \
+            .setLearningRate(0.01) \
+            .setNumPartitions(2) \
+            .setNumIterations(10) \
+            .setSeed(1024) \
+            .setMinCount(3)
+        self.assertEquals(model.vectorSize, 2)
+        self.assertTrue(model.learningRate < 0.02)
+        self.assertEquals(model.numPartitions, 2)
+        self.assertEquals(model.numIterations, 10)
+        self.assertEquals(model.seed, 1024)
+        self.assertEquals(model.minCount, 3)
+
+    def test_word2vec_get_vectors(self):
+        data = [
+            ["a", "b", "c", "d", "e", "f", "g"],
+            ["a", "b", "c", "d", "e", "f"],
+            ["a", "b", "c", "d", "e"],
+            ["a", "b", "c", "d"],
+            ["a", "b", "c"],
+            ["a", "b"],
+            ["a"]
+        ]
+        model = Word2Vec().fit(self.sc.parallelize(data))
+        self.assertEquals(len(model.getVectors()), 3)
+

 if __name__ == "__main__":
     if not _have_scipy:
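Why `test_word2vec_get_vectors` expects exactly three vectors: with the default `minCount` of 5, only "a" (7 occurrences), "b" (6) and "c" (5) survive vocabulary pruning. A quick sanity check of the token counts (illustrative, not part of the test suite):

    from collections import Counter

    data = [["a", "b", "c", "d", "e", "f", "g"],
            ["a", "b", "c", "d", "e", "f"],
            ["a", "b", "c", "d", "e"],
            ["a", "b", "c", "d"],
            ["a", "b", "c"],
            ["a", "b"],
            ["a"]]
    counts = Counter(word for sentence in data for word in sentence)
    print(sorted(w for w, c in counts.items() if c >= 5))  # ['a', 'b', 'c']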