[SPARK-11766][MLLIB] add toJson/fromJson to Vector/Vectors
This is to support JSON serialization of Param[Vector] in the pipeline API. It could be used for other purposes too. The schema is the same as `VectorUDT`. jkbradley Author: Xiangrui Meng <meng@databricks.com> Closes #9751 from mengxr/SPARK-11766.
This commit is contained in:
parent
cc567b6634
commit
21fac54341
|
@ -24,6 +24,9 @@ import scala.annotation.varargs
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
|
|
||||||
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
|
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
|
||||||
|
import org.json4s.DefaultFormats
|
||||||
|
import org.json4s.JsonDSL._
|
||||||
|
import org.json4s.jackson.JsonMethods.{compact, render, parse => parseJson}
|
||||||
|
|
||||||
import org.apache.spark.SparkException
|
import org.apache.spark.SparkException
|
||||||
import org.apache.spark.annotation.{AlphaComponent, Since}
|
import org.apache.spark.annotation.{AlphaComponent, Since}
|
||||||
|
@ -171,6 +174,12 @@ sealed trait Vector extends Serializable {
|
||||||
*/
|
*/
|
||||||
@Since("1.5.0")
|
@Since("1.5.0")
|
||||||
def argmax: Int
|
def argmax: Int
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts the vector to a compact JSON string that [[Vectors.fromJson]] can parse back into an equal vector.
|
||||||
|
*/
|
||||||
|
@Since("1.6.0")
|
||||||
|
def toJson: String
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -339,6 +348,27 @@ object Vectors {
|
||||||
parseNumeric(NumericParser.parse(s))
|
parseNumeric(NumericParser.parse(s))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Parses the JSON representation of a vector into a [[Vector]].
 *
 * The JSON object must carry an integer `type` field: 0 selects a sparse vector, which is
 * read from the `size`, `indices` and `values` fields; 1 selects a dense vector, which is
 * read from the `values` field only.
 *
 * @param json JSON string describing the vector
 * @return the sparse or dense vector described by `json`
 * @throws IllegalArgumentException if `type` is an integer other than 0 or 1
 */
@Since("1.6.0")
def fromJson(json: String): Vector = {
  implicit val formats = DefaultFormats
  val root = parseJson(json)

  // Both encodings store their entries under "values". Reading is deferred so that an
  // unknown type code still reaches the IllegalArgumentException below.
  def readValues: Array[Double] = (root \ "values").extract[Seq[Double]].toArray

  val typeCode = (root \ "type").extract[Int]
  if (typeCode == 0) {
    // sparse
    val size = (root \ "size").extract[Int]
    val indices = (root \ "indices").extract[Seq[Int]].toArray
    sparse(size, indices, readValues)
  } else if (typeCode == 1) {
    // dense
    dense(readValues)
  } else {
    throw new IllegalArgumentException(s"Cannot parse $json into a vector.")
  }
}
|
||||||
|
|
||||||
private[mllib] def parseNumeric(any: Any): Vector = {
|
private[mllib] def parseNumeric(any: Any): Vector = {
|
||||||
any match {
|
any match {
|
||||||
case values: Array[Double] =>
|
case values: Array[Double] =>
|
||||||
|
@ -650,6 +680,12 @@ class DenseVector @Since("1.0.0") (
|
||||||
maxIdx
|
maxIdx
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Serializes this dense vector as a compact JSON object with `type` set to 1, followed by
 * a `values` array holding every entry.
 */
@Since("1.6.0")
override def toJson: String = {
  val typeField = "type" -> 1
  val valuesField = "values" -> values.toSeq
  compact(render(typeField ~ valuesField))
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.3.0")
|
@Since("1.3.0")
|
||||||
|
@ -837,6 +873,15 @@ class SparseVector @Since("1.0.0") (
|
||||||
}.unzip
|
}.unzip
|
||||||
new SparseVector(selectedIndices.length, sliceInds.toArray, sliceVals.toArray)
|
new SparseVector(selectedIndices.length, sliceInds.toArray, sliceVals.toArray)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Serializes this sparse vector as a compact JSON object with `type` set to 0, followed by
 * the `size`, `indices` and `values` fields in that order.
 */
@Since("1.6.0")
override def toJson: String = {
  val typeField = "type" -> 0
  val sizeField = "size" -> size
  val indicesField = "indices" -> indices.toSeq
  val valuesField = "values" -> values.toSeq
  compact(render(typeField ~ sizeField ~ indicesField ~ valuesField))
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.3.0")
|
@Since("1.3.0")
|
||||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.mllib.linalg
|
||||||
import scala.util.Random
|
import scala.util.Random
|
||||||
|
|
||||||
import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance}
|
import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance}
|
||||||
|
import org.json4s.jackson.JsonMethods.{parse => parseJson}
|
||||||
|
|
||||||
import org.apache.spark.{Logging, SparkException, SparkFunSuite}
|
import org.apache.spark.{Logging, SparkException, SparkFunSuite}
|
||||||
import org.apache.spark.mllib.util.TestingUtils._
|
import org.apache.spark.mllib.util.TestingUtils._
|
||||||
|
@ -374,4 +375,20 @@ class VectorsSuite extends SparkFunSuite with Logging {
|
||||||
assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2)))
|
assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2)))
|
||||||
assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4)))
|
assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("toJson/fromJson") {
  // Empty, entry-free and populated vectors for the sparse encoding...
  val sparseCases = Seq(
    Vectors.sparse(0, Array.empty[Int], Array.empty[Double]),
    Vectors.sparse(1, Array.empty[Int], Array.empty[Double]),
    Vectors.sparse(2, Array(1), Array(2.0)))
  // ...and empty, single-entry and two-entry vectors for the dense encoding.
  val denseCases = Seq(
    Vectors.dense(Array.empty[Double]),
    Vectors.dense(1.0),
    Vectors.dense(0.0, 2.0))

  (sparseCases ++ denseCases).foreach { original =>
    val serialized = original.toJson
    parseJson(serialized) // the serialized form should be a valid JSON string
    val restored = Vectors.fromJson(serialized)
    assert(restored.getClass === original.getClass,
      "toJson/fromJson should preserve vector types.")
    assert(restored === original, "toJson/fromJson should preserve vector values.")
  }
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -137,6 +137,10 @@ object MimaExcludes {
|
||||||
) ++ Seq (
|
) ++ Seq (
|
||||||
ProblemFilters.exclude[MissingMethodProblem](
|
ProblemFilters.exclude[MissingMethodProblem](
|
||||||
"org.apache.spark.status.api.v1.ApplicationInfo.this")
|
"org.apache.spark.status.api.v1.ApplicationInfo.this")
|
||||||
|
) ++ Seq(
|
||||||
|
// SPARK-11766 add toJson to Vector
|
||||||
|
ProblemFilters.exclude[MissingMethodProblem](
|
||||||
|
"org.apache.spark.mllib.linalg.Vector.toJson")
|
||||||
)
|
)
|
||||||
case v if v.startsWith("1.5") =>
|
case v if v.startsWith("1.5") =>
|
||||||
Seq(
|
Seq(
|
||||||
|
|
Loading…
Reference in a new issue