[SPARK-3142][MLLIB] output shuffle data directly in Word2Vec
Sorry I didn't realize this in #2043. Ishiihara

Author: Xiangrui Meng <meng@databricks.com>

Closes #2049 from mengxr/more-w2v and squashes the following commits:

050b1c5 [Xiangrui Meng] output shuffle data directly
This commit is contained in:
parent
8adfbc2b6b
commit
0a984aa155
|
@@ -347,19 +347,20 @@ class Word2Vec extends Serializable with Logging {
|
||||||
}
|
}
|
||||||
val syn0Local = model._1
|
val syn0Local = model._1
|
||||||
val syn1Local = model._2
|
val syn1Local = model._2
|
||||||
val synOut = mutable.ListBuffer.empty[(Int, Array[Float])]
|
// Only output modified vectors.
|
||||||
var index = 0
|
Iterator.tabulate(vocabSize) { index =>
|
||||||
while(index < vocabSize) {
|
if (syn0Modify(index) > 0) {
|
||||||
if (syn0Modify(index) != 0) {
|
Some((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
|
||||||
synOut += ((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
|
} else {
|
||||||
|
None
|
||||||
}
|
}
|
||||||
if (syn1Modify(index) != 0) {
|
}.flatten ++ Iterator.tabulate(vocabSize) { index =>
|
||||||
synOut += ((index + vocabSize,
|
if (syn1Modify(index) > 0) {
|
||||||
syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
|
Some((index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
}
|
}
|
||||||
index += 1
|
}.flatten
|
||||||
}
|
|
||||||
synOut.toIterator
|
|
||||||
}
|
}
|
||||||
val synAgg = partial.reduceByKey { case (v1, v2) =>
|
val synAgg = partial.reduceByKey { case (v1, v2) =>
|
||||||
blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
|
blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
|
||||||
|
|
Loading…
Reference in a new issue