[SPARK-3142][MLLIB] output shuffle data directly in Word2Vec
Sorry, I didn't realize this in #2043. Ishiihara. Author: Xiangrui Meng <meng@databricks.com>. Closes #2049 from mengxr/more-w2v and squashes the following commits: 050b1c5 [Xiangrui Meng] output shuffle data directly.
This commit is contained in:
parent
8adfbc2b6b
commit
0a984aa155
|
@ -347,19 +347,20 @@ class Word2Vec extends Serializable with Logging {
|
|||
}
|
||||
val syn0Local = model._1
|
||||
val syn1Local = model._2
|
||||
val synOut = mutable.ListBuffer.empty[(Int, Array[Float])]
|
||||
var index = 0
|
||||
while(index < vocabSize) {
|
||||
if (syn0Modify(index) != 0) {
|
||||
synOut += ((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
|
||||
// Only output modified vectors.
|
||||
Iterator.tabulate(vocabSize) { index =>
|
||||
if (syn0Modify(index) > 0) {
|
||||
Some((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
if (syn1Modify(index) != 0) {
|
||||
synOut += ((index + vocabSize,
|
||||
syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
|
||||
}.flatten ++ Iterator.tabulate(vocabSize) { index =>
|
||||
if (syn1Modify(index) > 0) {
|
||||
Some((index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
index += 1
|
||||
}
|
||||
synOut.toIterator
|
||||
}.flatten
|
||||
}
|
||||
val synAgg = partial.reduceByKey { case (v1, v2) =>
|
||||
blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
|
||||
|
|
Loading…
Reference in a new issue