From b3e89802ae760c196612dd94e37eafeafd059e26 Mon Sep 17 00:00:00 2001
From: Asher Krim
Date: Sun, 5 Feb 2017 16:14:07 -0800
Subject: [PATCH] [SPARK-19247][ML] Save large word2vec models

## What changes were proposed in this pull request?

* save word2vec models as distributed files rather than as one large datum. Backwards compatibility with the previous save format is maintained by checking for the "wordIndex" column
* migrate the fix for loading large models (SPARK-11994) to ml word2vec

## How was this patch tested?

Tested loading the new and old formats locally

srowen yanboliang MLnick

Author: Asher Krim

Closes #16607 from Krimit/saveLargeModels.
---
 .../apache/spark/ml/feature/Word2Vec.scala         | 56 +++++++++++++++----
 1 file changed, 46 insertions(+), 10 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 3ed08c983d..42e8a66a62 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -30,6 +30,7 @@ import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
+import org.apache.spark.util.{Utils, VersionUtils}
 
 /**
  * Params for [[Word2Vec]] and [[Word2VecModel]].
@@ -302,16 +303,36 @@ class Word2VecModel private[ml] (
 @Since("1.6.0")
 object Word2VecModel extends MLReadable[Word2VecModel] {
 
+  private case class Data(word: String, vector: Array[Float])
+
   private[Word2VecModel]
   class Word2VecModelWriter(instance: Word2VecModel) extends MLWriter {
 
-    private case class Data(wordIndex: Map[String, Int], wordVectors: Seq[Float])
-
     override protected def saveImpl(path: String): Unit = {
       DefaultParamsWriter.saveMetadata(instance, path, sc)
-      val data = Data(instance.wordVectors.wordIndex, instance.wordVectors.wordVectors.toSeq)
+
+      val wordVectors = instance.wordVectors.getVectors
+      val dataSeq = wordVectors.toSeq.map { case (word, vector) => Data(word, vector) }
       val dataPath = new Path(path, "data").toString
-      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+      sparkSession.createDataFrame(dataSeq)
+        .repartition(calculateNumberOfPartitions)
+        .write
+        .parquet(dataPath)
+    }
+
+    def calculateNumberOfPartitions(): Int = {
+      val floatSize = 4
+      val averageWordSize = 15
+      // [SPARK-11994] - We want to partition the model in partitions smaller than
+      // spark.kryoserializer.buffer.max
+      val bufferSizeInBytes = Utils.byteStringAsBytes(
+        sc.conf.get("spark.kryoserializer.buffer.max", "64m"))
+      // Calculate the approximate size of the model.
+      // Assuming an average word size of 15 bytes, the formula is:
+      // (floatSize * vectorSize + 15) * numWords
+      val numWords = instance.wordVectors.wordIndex.size
+      val approximateSizeInBytes = (floatSize * instance.getVectorSize + averageWordSize) * numWords
+      ((approximateSizeInBytes / bufferSizeInBytes) + 1).toInt
     }
   }
 
@@ -320,14 +341,29 @@ object Word2VecModel extends MLReadable[Word2VecModel] {
     private val className = classOf[Word2VecModel].getName
 
     override def load(path: String): Word2VecModel = {
+      val spark = sparkSession
+      import spark.implicits._
+
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
+
       val dataPath = new Path(path, "data").toString
-      val data = sparkSession.read.parquet(dataPath)
-        .select("wordIndex", "wordVectors")
-        .head()
-      val wordIndex = data.getAs[Map[String, Int]](0)
-      val wordVectors = data.getAs[Seq[Float]](1).toArray
-      val oldModel = new feature.Word2VecModel(wordIndex, wordVectors)
+
+      val oldModel = if (major < 2 || (major == 2 && minor < 2)) {
+        val data = spark.read.parquet(dataPath)
+          .select("wordIndex", "wordVectors")
+          .head()
+        val wordIndex = data.getAs[Map[String, Int]](0)
+        val wordVectors = data.getAs[Seq[Float]](1).toArray
+        new feature.Word2VecModel(wordIndex, wordVectors)
+      } else {
+        val wordVectorsMap = spark.read.parquet(dataPath).as[Data]
+          .collect()
+          .map(wordVector => (wordVector.word, wordVector.vector))
+          .toMap
+        new feature.Word2VecModel(wordVectorsMap)
+      }
+
      val model = new Word2VecModel(metadata.uid, oldModel)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
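For reference, here is a back-of-the-envelope check of the partitioning arithmetic in `calculateNumberOfPartitions`. The sketch below is illustrative only: the object and helper names, the vocabulary size, and the vector size are made up for the example, and the default `spark.kryoserializer.buffer.max` of 64m is assumed.

```scala
// Illustrative sketch of the partitioning arithmetic used by the writer.
// The names and example sizes here are hypothetical, not part of the patch.
object PartitionEstimate {
  def estimateNumberOfPartitions(vectorSize: Int, numWords: Long, bufferSizeInBytes: Long): Int = {
    val floatSize = 4        // bytes per Float element of a vector
    val averageWordSize = 15 // assumed average size of a word string, in bytes
    // Approximate model size: (floatSize * vectorSize + averageWordSize) * numWords
    val approximateSizeInBytes = (floatSize * vectorSize + averageWordSize) * numWords
    // Integer division rounds down, so the + 1 keeps every partition under the buffer limit
    ((approximateSizeInBytes / bufferSizeInBytes) + 1).toInt
  }

  def main(args: Array[String]): Unit = {
    val bufferSizeInBytes = 64L * 1024 * 1024 // default spark.kryoserializer.buffer.max = 64m
    // A 1M-word vocabulary of 100-dimensional vectors is roughly
    // (4 * 100 + 15) * 1,000,000 = 415 MB, giving 415 MB / 64 MB + 1 = 7 partitions.
    println(estimateNumberOfPartitions(vectorSize = 100, numWords = 1000000L, bufferSizeInBytes = bufferSizeInBytes))
  }
}
```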
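And a minimal end-to-end sketch of the save path and the version-aware load path this patch introduces. The corpus and the save path are placeholders; only the `Word2Vec`/`Word2VecModel` API calls are from Spark ML.

```scala
import org.apache.spark.ml.feature.{Word2Vec, Word2VecModel}
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("w2v-save-load").getOrCreate()
import spark.implicits._

// Toy corpus; each row is a sequence of tokens
val docs = Seq(
  "spark ml word2vec".split(" ").toSeq,
  "save large models as distributed files".split(" ").toSeq
).toDF("text")

val model = new Word2Vec()
  .setInputCol("text")
  .setOutputCol("features")
  .setVectorSize(3)
  .setMinCount(0)
  .fit(docs)

// Writes metadata plus a word/vector parquet dataset split across partitions
model.write.overwrite().save("/tmp/w2v-model")

// load() reads the saved Spark version from metadata and picks the
// pre-2.2 (wordIndex/wordVectors row) or new (word, vector) read path
val loaded = Word2VecModel.load("/tmp/w2v-model")
```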