[SPARK-22289][ML] Add JSON support for Matrix parameters (LR with coefficients bound)
## What changes were proposed in this pull request? jira: https://issues.apache.org/jira/browse/SPARK-22289 add JSON encoding/decoding for Param[Matrix]. The issue was reported by Nic Eggert when saving an LR model with LowerBoundsOnCoefficients. There are two ways to resolve this, as I see it: 1. Support save/load on LogisticRegressionParams, and also adjust the save/load in LogisticRegression and LogisticRegressionModel. 2. Directly support Matrix in Param.jsonEncode, similar to what we have done for Vector. After some discussion in jira, we prefer the fix that supports Matrix as a valid Param type, for simplicity and convenience for other classes. Note that in the implementation, I added a "class" field in the JSON object to match different JSON converters when loading, which is for preciseness and future extension. ## How was this patch tested? new unit test to cover the LR case and JsonMatrixConverter Author: Yuhao Yang <yuhao.yang@intel.com> Closes #19525 from hhbyyh/lrsave.
This commit is contained in:
parent
e6dc5f2807
commit
10c27a6559
|
@ -476,6 +476,9 @@ class DenseMatrix @Since("2.0.0") (
|
|||
@Since("2.0.0")
|
||||
object DenseMatrix {
|
||||
|
||||
private[ml] def unapply(dm: DenseMatrix): Option[(Int, Int, Array[Double], Boolean)] =
|
||||
Some((dm.numRows, dm.numCols, dm.values, dm.isTransposed))
|
||||
|
||||
/**
|
||||
* Generate a `DenseMatrix` consisting of zeros.
|
||||
* @param numRows number of rows of the matrix
|
||||
|
@ -827,6 +830,10 @@ class SparseMatrix @Since("2.0.0") (
|
|||
@Since("2.0.0")
|
||||
object SparseMatrix {
|
||||
|
||||
private[ml] def unapply(
|
||||
sm: SparseMatrix): Option[(Int, Int, Array[Int], Array[Int], Array[Double], Boolean)] =
|
||||
Some((sm.numRows, sm.numCols, sm.colPtrs, sm.rowIndices, sm.values, sm.isTransposed))
|
||||
|
||||
/**
|
||||
* Generate a `SparseMatrix` from Coordinate List (COO) format. Input must be an array of
|
||||
* (i, j, value) tuples. Entries that have duplicate values of i and j are
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.spark.ml.linalg
|
||||
|
||||
import org.json4s.DefaultFormats
|
||||
import org.json4s.JsonDSL._
|
||||
import org.json4s.jackson.JsonMethods.{compact, parse => parseJson, render}
|
||||
|
||||
private[ml] object JsonMatrixConverter {
|
||||
|
||||
/** Unique class name for identifying JSON object encoded by this class. */
|
||||
val className = "matrix"
|
||||
|
||||
/**
|
||||
* Parses the JSON representation of a Matrix into a [[Matrix]].
|
||||
*/
|
||||
def fromJson(json: String): Matrix = {
|
||||
implicit val formats = DefaultFormats
|
||||
val jValue = parseJson(json)
|
||||
(jValue \ "type").extract[Int] match {
|
||||
case 0 => // sparse
|
||||
val numRows = (jValue \ "numRows").extract[Int]
|
||||
val numCols = (jValue \ "numCols").extract[Int]
|
||||
val colPtrs = (jValue \ "colPtrs").extract[Seq[Int]].toArray
|
||||
val rowIndices = (jValue \ "rowIndices").extract[Seq[Int]].toArray
|
||||
val values = (jValue \ "values").extract[Seq[Double]].toArray
|
||||
val isTransposed = (jValue \ "isTransposed").extract[Boolean]
|
||||
new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
|
||||
case 1 => // dense
|
||||
val numRows = (jValue \ "numRows").extract[Int]
|
||||
val numCols = (jValue \ "numCols").extract[Int]
|
||||
val values = (jValue \ "values").extract[Seq[Double]].toArray
|
||||
val isTransposed = (jValue \ "isTransposed").extract[Boolean]
|
||||
new DenseMatrix(numRows, numCols, values, isTransposed)
|
||||
case _ =>
|
||||
throw new IllegalArgumentException(s"Cannot parse $json into a Matrix.")
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Coverts the Matrix to a JSON string.
|
||||
*/
|
||||
def toJson(m: Matrix): String = {
|
||||
m match {
|
||||
case SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed) =>
|
||||
val jValue = ("class" -> className) ~
|
||||
("type" -> 0) ~
|
||||
("numRows" -> numRows) ~
|
||||
("numCols" -> numCols) ~
|
||||
("colPtrs" -> colPtrs.toSeq) ~
|
||||
("rowIndices" -> rowIndices.toSeq) ~
|
||||
("values" -> values.toSeq) ~
|
||||
("isTransposed" -> isTransposed)
|
||||
compact(render(jValue))
|
||||
case DenseMatrix(numRows, numCols, values, isTransposed) =>
|
||||
val jValue = ("class" -> className) ~
|
||||
("type" -> 1) ~
|
||||
("numRows" -> numRows) ~
|
||||
("numCols" -> numCols) ~
|
||||
("values" -> values.toSeq) ~
|
||||
("isTransposed" -> isTransposed)
|
||||
compact(render(jValue))
|
||||
}
|
||||
}
|
||||
}
|
|
@ -28,9 +28,9 @@ import scala.collection.mutable
|
|||
import org.json4s._
|
||||
import org.json4s.jackson.JsonMethods._
|
||||
|
||||
import org.apache.spark.SparkException
|
||||
import org.apache.spark.annotation.{DeveloperApi, Since}
|
||||
import org.apache.spark.ml.linalg.JsonVectorConverter
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.ml.linalg.{JsonMatrixConverter, JsonVectorConverter, Matrix, Vector}
|
||||
import org.apache.spark.ml.util.Identifiable
|
||||
|
||||
/**
|
||||
|
@ -94,9 +94,11 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
|
|||
compact(render(JString(x)))
|
||||
case v: Vector =>
|
||||
JsonVectorConverter.toJson(v)
|
||||
case m: Matrix =>
|
||||
JsonMatrixConverter.toJson(m)
|
||||
case _ =>
|
||||
throw new NotImplementedError(
|
||||
"The default jsonEncode only supports string and vector. " +
|
||||
"The default jsonEncode only supports string, vector and matrix. " +
|
||||
s"${this.getClass.getName} must override jsonEncode for ${value.getClass.getName}.")
|
||||
}
|
||||
}
|
||||
|
@ -122,17 +124,35 @@ private[ml] object Param {
|
|||
|
||||
/** Decodes a param value from JSON. */
|
||||
def jsonDecode[T](json: String): T = {
|
||||
parse(json) match {
|
||||
val jValue = parse(json)
|
||||
jValue match {
|
||||
case JString(x) =>
|
||||
x.asInstanceOf[T]
|
||||
case JObject(v) =>
|
||||
val keys = v.map(_._1)
|
||||
assert(keys.contains("type") && keys.contains("values"),
|
||||
s"Expect a JSON serialized vector but cannot find fields 'type' and 'values' in $json.")
|
||||
if (keys.contains("class")) {
|
||||
implicit val formats = DefaultFormats
|
||||
val className = (jValue \ "class").extract[String]
|
||||
className match {
|
||||
case JsonMatrixConverter.className =>
|
||||
val checkFields = Array("numRows", "numCols", "values", "isTransposed", "type")
|
||||
require(checkFields.forall(keys.contains), s"Expect a JSON serialized Matrix" +
|
||||
s" but cannot find fields ${checkFields.mkString(", ")} in $json.")
|
||||
JsonMatrixConverter.fromJson(json).asInstanceOf[T]
|
||||
|
||||
case s => throw new SparkException(s"unrecognized class $s in $json")
|
||||
}
|
||||
} else {
|
||||
// "class" info in JSON was added in Spark 2.3(SPARK-22289). JSON support for Vector was
|
||||
// implemented before that and does not have "class" attribute.
|
||||
require(keys.contains("type") && keys.contains("values"), s"Expect a JSON serialized" +
|
||||
s" vector/matrix but cannot find fields 'type' and 'values' in $json.")
|
||||
JsonVectorConverter.fromJson(json).asInstanceOf[T]
|
||||
}
|
||||
|
||||
case _ =>
|
||||
throw new NotImplementedError(
|
||||
"The default jsonDecode only supports string and vector. " +
|
||||
"The default jsonDecode only supports string, vector and matrix. " +
|
||||
s"${this.getClass.getName} must override jsonDecode to support its value type.")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2767,6 +2767,17 @@ class LogisticRegressionSuite
|
|||
val lr = new LogisticRegression()
|
||||
testEstimatorAndModelReadWrite(lr, smallBinaryDataset, LogisticRegressionSuite.allParamSettings,
|
||||
LogisticRegressionSuite.allParamSettings, checkModelData)
|
||||
|
||||
// test lr with bounds on coefficients, need to set elasticNetParam to 0.
|
||||
val numFeatures = smallBinaryDataset.select("features").head().getAs[Vector](0).size
|
||||
val lowerBounds = new DenseMatrix(1, numFeatures, (1 to numFeatures).map(_ / 1000.0).toArray)
|
||||
val upperBounds = new DenseMatrix(1, numFeatures, (1 to numFeatures).map(_ * 1000.0).toArray)
|
||||
val paramSettings = Map("lowerBoundsOnCoefficients" -> lowerBounds,
|
||||
"upperBoundsOnCoefficients" -> upperBounds,
|
||||
"elasticNetParam" -> 0.0
|
||||
)
|
||||
testEstimatorAndModelReadWrite(lr, smallBinaryDataset, paramSettings,
|
||||
paramSettings, checkModelData)
|
||||
}
|
||||
|
||||
test("should support all NumericType labels and weights, and not support other types") {
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.linalg
|
||||
|
||||
import org.json4s.jackson.JsonMethods.parse
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
|
||||
class JsonMatrixConverterSuite extends SparkFunSuite {
|
||||
|
||||
test("toJson/fromJson") {
|
||||
val denseMatrices = Seq(
|
||||
Matrices.dense(0, 0, Array.empty[Double]),
|
||||
Matrices.dense(1, 1, Array(0.1)),
|
||||
new DenseMatrix(3, 2, Array(0.0, 1.21, 2.3, 9.8, 9.0, 0.0), true)
|
||||
)
|
||||
|
||||
val sparseMatrices = denseMatrices.map(_.toSparse) ++ Seq(
|
||||
Matrices.sparse(3, 2, Array(0, 1, 2), Array(1, 2), Array(3.1, 3.5))
|
||||
)
|
||||
|
||||
for (m <- sparseMatrices ++ denseMatrices) {
|
||||
val json = JsonMatrixConverter.toJson(m)
|
||||
parse(json) // `json` should be a valid JSON string
|
||||
val u = JsonMatrixConverter.fromJson(json)
|
||||
assert(u.getClass === m.getClass, "toJson/fromJson should preserve Matrix types.")
|
||||
assert(u === m, "toJson/fromJson should preserve Matrix values.")
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue