[SPARK-6723] [MLLIB] Model import/export for ChiSqSelector

This is a PR for Parquet-based model import/export.

* Added save/load for ChiSqSelectorModel
* Updated the test suite ChiSqSelectorSuite

Author: Jayant Shekar <jayant@user-MBPMBA-3.local>

Closes #6785 from jayantshekhar/SPARK-6723.
This commit is contained in:
Jayant Shekar 2015-10-23 08:45:13 -07:00 committed by Xiangrui Meng
parent 282a15f78e
commit 4e38defae1
2 changed files with 95 additions and 1 deletions

View file

@ -19,11 +19,18 @@ package org.apache.spark.mllib.feature
import scala.collection.mutable.ArrayBuilder
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
import org.apache.spark.sql.{SQLContext, Row}
/**
* :: Experimental ::
@ -34,7 +41,7 @@ import org.apache.spark.rdd.RDD
@Since("1.3.0")
@Experimental
class ChiSqSelectorModel @Since("1.3.0") (
@Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer {
@Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable {
require(isSorted(selectedFeatures), "Array has to be sorted asc")
@ -102,6 +109,67 @@ class ChiSqSelectorModel @Since("1.3.0") (
s"Only sparse and dense vectors are supported but got ${other.getClass}.")
}
}
@Since("1.6.0")
override def save(sc: SparkContext, path: String): Unit = {
ChiSqSelectorModel.SaveLoadV1_0.save(sc, this, path)
}
override protected def formatVersion: String = "1.0"
}
object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
@Since("1.6.0")
override def load(sc: SparkContext, path: String): ChiSqSelectorModel = {
ChiSqSelectorModel.SaveLoadV1_0.load(sc, path)
}
private[feature]
object SaveLoadV1_0 {
private val thisFormatVersion = "1.0"
/** Model data for import/export */
case class Data(feature: Int)
private[feature]
val thisClassName = "org.apache.spark.mllib.feature.ChiSqSelectorModel"
def save(sc: SparkContext, model: ChiSqSelectorModel, path: String): Unit = {
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
// Create Parquet data.
val dataArray = Array.tabulate(model.selectedFeatures.length) { i =>
Data(model.selectedFeatures(i))
}
sc.parallelize(dataArray, 1).toDF().write.parquet(Loader.dataPath(path))
}
def load(sc: SparkContext, path: String): ChiSqSelectorModel = {
implicit val formats = DefaultFormats
val sqlContext = new SQLContext(sc)
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
assert(className == thisClassName)
assert(formatVersion == thisFormatVersion)
val dataFrame = sqlContext.read.parquet(Loader.dataPath(path))
val dataArray = dataFrame.select("feature")
// Check schema explicitly since erasure makes it hard to use match-case for checking.
Loader.checkSchema[Data](dataFrame.schema)
val features = dataArray.map {
case Row(feature: Int) => (feature)
}.collect()
return new ChiSqSelectorModel(features)
}
}
}
/**

View file

@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.util.Utils
class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
@ -63,4 +64,29 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
}.collect().toSet
assert(filteredData == preFilteredData)
}
test("model load / save") {
val model = ChiSqSelectorSuite.createModel()
val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString
try {
model.save(sc, path)
val sameModel = ChiSqSelectorModel.load(sc, path)
ChiSqSelectorSuite.checkEqual(model, sameModel)
} finally {
Utils.deleteRecursively(tempDir)
}
}
}
object ChiSqSelectorSuite extends SparkFunSuite {
def createModel(): ChiSqSelectorModel = {
val arr = Array(1, 2, 3, 4)
new ChiSqSelectorModel(arr)
}
def checkEqual(a: ChiSqSelectorModel, b: ChiSqSelectorModel): Unit = {
assert(a.selectedFeatures.deep == b.selectedFeatures.deep)
}
}