From 4e38defae13b2b13e196b4d172722ef5e6266c66 Mon Sep 17 00:00:00 2001
From: Jayant Shekar
Date: Fri, 23 Oct 2015 08:45:13 -0700
Subject: [PATCH] [SPARK-6723] [MLLIB] Model import/export for ChiSqSelector

This is a PR for Parquet-based model import/export.

* Added save/load for ChiSqSelectorModel
* Updated the test suite ChiSqSelectorSuite

Author: Jayant Shekar

Closes #6785 from jayantshekhar/SPARK-6723.
---
 .../spark/mllib/feature/ChiSqSelector.scala      | 70 ++++++++++++++++++-
 .../mllib/feature/ChiSqSelectorSuite.scala       | 26 +++++++
 2 files changed, 95 insertions(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index b1524cf377..5246faf221 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -19,11 +19,18 @@ package org.apache.spark.mllib.feature
 
 import scala.collection.mutable.ArrayBuilder
 
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.stat.Statistics
+import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.{SQLContext, Row}
 
 /**
  * :: Experimental ::
@@ -34,7 +41,7 @@ import org.apache.spark.rdd.RDD
 @Since("1.3.0")
 @Experimental
 class ChiSqSelectorModel @Since("1.3.0") (
-    @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer {
+    @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable {
 
   require(isSorted(selectedFeatures), "Array has to be sorted asc")
 
@@ -102,6 +109,67 @@ class ChiSqSelectorModel @Since("1.3.0") (
         s"Only sparse and dense vectors are supported but got ${other.getClass}.")
     }
   }
+
+  @Since("1.6.0")
+  override def save(sc: SparkContext, path: String): Unit = {
+    ChiSqSelectorModel.SaveLoadV1_0.save(sc, this, path)
+  }
+
+  override protected def formatVersion: String = "1.0"
+}
+
+object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
+  @Since("1.6.0")
+  override def load(sc: SparkContext, path: String): ChiSqSelectorModel = {
+    ChiSqSelectorModel.SaveLoadV1_0.load(sc, path)
+  }
+
+  private[feature]
+  object SaveLoadV1_0 {
+
+    private val thisFormatVersion = "1.0"
+
+    /** Model data for import/export */
+    case class Data(feature: Int)
+
+    private[feature]
+    val thisClassName = "org.apache.spark.mllib.feature.ChiSqSelectorModel"
+
+    def save(sc: SparkContext, model: ChiSqSelectorModel, path: String): Unit = {
+      val sqlContext = new SQLContext(sc)
+      import sqlContext.implicits._
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+      // Create Parquet data.
+      val dataArray = Array.tabulate(model.selectedFeatures.length) { i =>
+        Data(model.selectedFeatures(i))
+      }
+      sc.parallelize(dataArray, 1).toDF().write.parquet(Loader.dataPath(path))
+
+    }
+
+    def load(sc: SparkContext, path: String): ChiSqSelectorModel = {
+      implicit val formats = DefaultFormats
+      val sqlContext = new SQLContext(sc)
+      val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+      assert(className == thisClassName)
+      assert(formatVersion == thisFormatVersion)
+
+      val dataFrame = sqlContext.read.parquet(Loader.dataPath(path))
+      val dataArray = dataFrame.select("feature")
+
+      // Check schema explicitly since erasure makes it hard to use match-case for checking.
+      Loader.checkSchema[Data](dataFrame.schema)
+
+      val features = dataArray.map {
+        case Row(feature: Int) => feature
+      }.collect()
+
+      new ChiSqSelectorModel(features)
+    }
+  }
 }
 
 /**
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
index 889727fb55..734800a9af 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
 
 class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
 
@@ -63,4 +64,29 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
     }.collect().toSet
     assert(filteredData == preFilteredData)
   }
+
+  test("model load / save") {
+    val model = ChiSqSelectorSuite.createModel()
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+    try {
+      model.save(sc, path)
+      val sameModel = ChiSqSelectorModel.load(sc, path)
+      ChiSqSelectorSuite.checkEqual(model, sameModel)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+}
+
+object ChiSqSelectorSuite extends SparkFunSuite {
+
+  def createModel(): ChiSqSelectorModel = {
+    val arr = Array(1, 2, 3, 4)
+    new ChiSqSelectorModel(arr)
+  }
+
+  def checkEqual(a: ChiSqSelectorModel, b: ChiSqSelectorModel): Unit = {
+    assert(a.selectedFeatures.deep == b.selectedFeatures.deep)
+  }
 }
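
For reference, a minimal usage sketch of the save/load API added by this patch (not part of the commit itself). It assumes an already-running SparkContext named sc; the output path is only an example:

    import org.apache.spark.mllib.feature.ChiSqSelectorModel

    // Assumes an existing SparkContext `sc`; the path below is a placeholder.
    val path = "target/tmp/ChiSqSelectorModel"
    val model = new ChiSqSelectorModel(Array(1, 2, 3, 4))   // selected feature indices, ascending

    model.save(sc, path)                                    // JSON metadata + Parquet data under path
    val sameModel = ChiSqSelectorModel.load(sc, path)
    assert(model.selectedFeatures.sameElements(sameModel.selectedFeatures))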