[SPARK-19247][ML] Save large word2vec models
## What changes were proposed in this pull request? * save word2vec models as distributed files rather than as one large datum. Backwards compatibility with the previous save format is maintained by checking for the "wordIndex" column * migrate the fix for loading large models (SPARK-11994) to ml word2vec ## How was this patch tested? Tested loading the new and old formats locally srowen yanboliang MLnick Author: Asher Krim <akrim@hubspot.com> Closes #16607 from Krimit/saveLargeModels.
This commit is contained in:
parent
b94f4b6fa6
commit
b3e89802ae
|
@ -30,6 +30,7 @@ import org.apache.spark.mllib.linalg.VectorImplicits._
|
|||
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.util.{Utils, VersionUtils}
|
||||
|
||||
/**
|
||||
* Params for [[Word2Vec]] and [[Word2VecModel]].
|
||||
|
@ -302,16 +303,36 @@ class Word2VecModel private[ml] (
|
|||
@Since("1.6.0")
|
||||
object Word2VecModel extends MLReadable[Word2VecModel] {
|
||||
|
||||
private case class Data(word: String, vector: Array[Float])
|
||||
|
||||
private[Word2VecModel]
|
||||
class Word2VecModelWriter(instance: Word2VecModel) extends MLWriter {
|
||||
|
||||
private case class Data(wordIndex: Map[String, Int], wordVectors: Seq[Float])
|
||||
|
||||
override protected def saveImpl(path: String): Unit = {
|
||||
DefaultParamsWriter.saveMetadata(instance, path, sc)
|
||||
val data = Data(instance.wordVectors.wordIndex, instance.wordVectors.wordVectors.toSeq)
|
||||
|
||||
val wordVectors = instance.wordVectors.getVectors
|
||||
val dataSeq = wordVectors.toSeq.map { case (word, vector) => Data(word, vector) }
|
||||
val dataPath = new Path(path, "data").toString
|
||||
sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
|
||||
sparkSession.createDataFrame(dataSeq)
|
||||
.repartition(calculateNumberOfPartitions)
|
||||
.write
|
||||
.parquet(dataPath)
|
||||
}
|
||||
|
||||
def calculateNumberOfPartitions(): Int = {
|
||||
val floatSize = 4
|
||||
val averageWordSize = 15
|
||||
// [SPARK-11994] - We want to partition the model in partitions smaller than
|
||||
// spark.kryoserializer.buffer.max
|
||||
val bufferSizeInBytes = Utils.byteStringAsBytes(
|
||||
sc.conf.get("spark.kryoserializer.buffer.max", "64m"))
|
||||
// Calculate the approximate size of the model.
|
||||
// Assuming an average word size of 15 bytes, the formula is:
|
||||
// (floatSize * vectorSize + 15) * numWords
|
||||
val numWords = instance.wordVectors.wordIndex.size
|
||||
val approximateSizeInBytes = (floatSize * instance.getVectorSize + averageWordSize) * numWords
|
||||
((approximateSizeInBytes / bufferSizeInBytes) + 1).toInt
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -320,14 +341,29 @@ object Word2VecModel extends MLReadable[Word2VecModel] {
|
|||
private val className = classOf[Word2VecModel].getName
|
||||
|
||||
override def load(path: String): Word2VecModel = {
|
||||
val spark = sparkSession
|
||||
import spark.implicits._
|
||||
|
||||
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
|
||||
val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
|
||||
|
||||
val dataPath = new Path(path, "data").toString
|
||||
val data = sparkSession.read.parquet(dataPath)
|
||||
.select("wordIndex", "wordVectors")
|
||||
.head()
|
||||
val wordIndex = data.getAs[Map[String, Int]](0)
|
||||
val wordVectors = data.getAs[Seq[Float]](1).toArray
|
||||
val oldModel = new feature.Word2VecModel(wordIndex, wordVectors)
|
||||
|
||||
val oldModel = if (major < 2 || (major == 2 && minor < 2)) {
|
||||
val data = spark.read.parquet(dataPath)
|
||||
.select("wordIndex", "wordVectors")
|
||||
.head()
|
||||
val wordIndex = data.getAs[Map[String, Int]](0)
|
||||
val wordVectors = data.getAs[Seq[Float]](1).toArray
|
||||
new feature.Word2VecModel(wordIndex, wordVectors)
|
||||
} else {
|
||||
val wordVectorsMap = spark.read.parquet(dataPath).as[Data]
|
||||
.collect()
|
||||
.map(wordVector => (wordVector.word, wordVector.vector))
|
||||
.toMap
|
||||
new feature.Word2VecModel(wordVectorsMap)
|
||||
}
|
||||
|
||||
val model = new Word2VecModel(metadata.uid, oldModel)
|
||||
DefaultParamsReader.getAndSetParams(model, metadata)
|
||||
model
|
||||
|
|
Loading…
Reference in a new issue