[SPARK-6723] [MLLIB] Model import/export for ChiSqSelector
This is a PR for Parquet-based model import/export. * Added save/load for ChiSqSelectorModel * Updated the test suite ChiSqSelectorSuite Author: Jayant Shekar <jayant@user-MBPMBA-3.local> Closes #6785 from jayantshekhar/SPARK-6723.
This commit is contained in:
parent
282a15f78e
commit
4e38defae1
|
@@ -19,11 +19,18 @@ package org.apache.spark.mllib.feature
|
|||
|
||||
import scala.collection.mutable.ArrayBuilder
|
||||
|
||||
import org.json4s._
|
||||
import org.json4s.JsonDSL._
|
||||
import org.json4s.jackson.JsonMethods._
|
||||
|
||||
import org.apache.spark.annotation.{Experimental, Since}
|
||||
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.mllib.stat.Statistics
|
||||
import org.apache.spark.mllib.util.{Loader, Saveable}
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.sql.{SQLContext, Row}
|
||||
|
||||
/**
|
||||
* :: Experimental ::
|
||||
|
@@ -34,7 +41,7 @@ import org.apache.spark.rdd.RDD
|
|||
@Since("1.3.0")
|
||||
@Experimental
|
||||
class ChiSqSelectorModel @Since("1.3.0") (
|
||||
@Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer {
|
||||
@Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable {
|
||||
|
||||
require(isSorted(selectedFeatures), "Array has to be sorted asc")
|
||||
|
||||
|
@@ -102,6 +109,67 @@ class ChiSqSelectorModel @Since("1.3.0") (
|
|||
s"Only sparse and dense vectors are supported but got ${other.getClass}.")
|
||||
}
|
||||
}
|
||||
|
||||
@Since("1.6.0")
override def save(sc: SparkContext, path: String): Unit = {
  // Delegate persistence to the versioned v1.0 save/load implementation.
  ChiSqSelectorModel.SaveLoadV1_0.save(sc, this, path)
}

// Version identifier of the on-disk format produced by save().
override protected def formatVersion: String = "1.0"
|
||||
}
|
||||
|
||||
object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {

  /**
   * Loads a [[ChiSqSelectorModel]] previously written with `save`.
   *
   * @param sc   Spark context used to read metadata and Parquet data.
   * @param path directory the model was saved to.
   */
  @Since("1.6.0")
  override def load(sc: SparkContext, path: String): ChiSqSelectorModel = {
    ChiSqSelectorModel.SaveLoadV1_0.load(sc, path)
  }

  private[feature]
  object SaveLoadV1_0 {

    private val thisFormatVersion = "1.0"

    /** Model data for import/export: one selected feature index per row. */
    case class Data(feature: Int)

    private[feature]
    val thisClassName = "org.apache.spark.mllib.feature.ChiSqSelectorModel"

    /**
     * Saves the model to `path`: JSON metadata (class name + format version)
     * alongside Parquet data holding the selected feature indices.
     */
    def save(sc: SparkContext, model: ChiSqSelectorModel, path: String): Unit = {
      val sqlContext = new SQLContext(sc)
      import sqlContext.implicits._
      val metadata = compact(render(
        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))

      // Create Parquet data: one Data row per selected feature index.
      val dataArray = Array.tabulate(model.selectedFeatures.length) { i =>
        Data(model.selectedFeatures(i))
      }
      sc.parallelize(dataArray, 1).toDF().write.parquet(Loader.dataPath(path))
    }

    /**
     * Loads a model written by [[save]], validating the stored class name
     * and format version before reading the Parquet data.
     */
    def load(sc: SparkContext, path: String): ChiSqSelectorModel = {
      implicit val formats = DefaultFormats
      val sqlContext = new SQLContext(sc)
      val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
      assert(className == thisClassName)
      assert(formatVersion == thisFormatVersion)

      val dataFrame = sqlContext.read.parquet(Loader.dataPath(path))
      val dataArray = dataFrame.select("feature")

      // Check schema explicitly since erasure makes it hard to use match-case for checking.
      Loader.checkSchema[Data](dataFrame.schema)

      val features = dataArray.map {
        case Row(feature: Int) => feature
      }.collect()

      // Parquet row order is not guaranteed across partitions/readers, but the
      // ChiSqSelectorModel constructor requires selectedFeatures sorted
      // ascending — sort defensively (a no-op for already-sorted data).
      new ChiSqSelectorModel(features.sorted)
    }
  }
}
|
||||
|
||||
/**
|
||||
|
|
|
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
|
|||
import org.apache.spark.mllib.linalg.Vectors
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||
|
||||
|
@@ -63,4 +64,29 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
}.collect().toSet
|
||||
assert(filteredData == preFilteredData)
|
||||
}
|
||||
|
||||
test("model load / save") {
  // Round-trip a fixture model through save()/load() and verify that the
  // reloaded model selects exactly the same features.
  val model = ChiSqSelectorSuite.createModel()
  val tempDir = Utils.createTempDir()
  val path = tempDir.toURI.toString
  try {
    model.save(sc, path)
    val loadedModel = ChiSqSelectorModel.load(sc, path)
    ChiSqSelectorSuite.checkEqual(model, loadedModel)
  } finally {
    // Always clean up the temporary directory, even on test failure.
    Utils.deleteRecursively(tempDir)
  }
}
|
||||
}
|
||||
|
||||
object ChiSqSelectorSuite extends SparkFunSuite {

  /** Builds a small fixture model with a fixed, sorted set of selected features. */
  def createModel(): ChiSqSelectorModel = {
    new ChiSqSelectorModel(Array(1, 2, 3, 4))
  }

  /** Asserts that two models select exactly the same features, in the same order. */
  def checkEqual(a: ChiSqSelectorModel, b: ChiSqSelectorModel): Unit = {
    assert(a.selectedFeatures.toSeq == b.selectedFeatures.toSeq)
  }
}
|
||||
|
|
Loading…
Reference in a new issue