From 3e9e6380236985ec5b51b459f8c61f964a76ff8b Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 17 Nov 2015 14:04:49 -0800 Subject: [PATCH] [SPARK-11764][ML] make Param.jsonEncode/jsonDecode support Vector This PR makes the default read/write work with simple transformers/estimators that have params of type `Param[Vector]`. jkbradley Author: Xiangrui Meng Closes #9776 from mengxr/SPARK-11764. --- .../org/apache/spark/ml/param/params.scala | 12 ++++++++-- .../apache/spark/ml/param/ParamsSuite.scala | 22 +++++++++++++++---- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index c932570918..d182b0a988 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -29,6 +29,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * :: DeveloperApi :: @@ -88,9 +89,11 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali value match { case x: String => compact(render(JString(x))) + case v: Vector => + v.toJson case _ => throw new NotImplementedError( - "The default jsonEncode only supports string. " + + "The default jsonEncode only supports string and vector. " + s"${this.getClass.getName} must override jsonEncode for ${value.getClass.getName}.") } } @@ -100,9 +103,14 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali parse(json) match { case JString(x) => x.asInstanceOf[T] + case JObject(v) => + val keys = v.map(_._1) + assert(keys.contains("type") && keys.contains("values"), + s"Expect a JSON serialized vector but cannot find fields 'type' and 'values' in $json.") + Vectors.fromJson(json).asInstanceOf[T] case _ => throw new NotImplementedError( - "The default jsonDecode only supports string. " + + "The default jsonDecode only supports string and vector. " + s"${this.getClass.getName} must override jsonDecode to support its value type.") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index eeb03dba2f..a1878be747 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.param import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} class ParamsSuite extends SparkFunSuite { @@ -80,7 +81,7 @@ class ParamsSuite extends SparkFunSuite { } } - { // StringParam + { // Param[String] val param = new Param[String](dummy, "name", "doc") // Currently we do not support null. for (value <- Seq("", "1", "abc", "quote\"", "newline\n")) { @@ -89,6 +90,19 @@ class ParamsSuite extends SparkFunSuite { } } + { // Param[Vector] + val param = new Param[Vector](dummy, "name", "doc") + val values = Seq( + Vectors.dense(Array.empty[Double]), + Vectors.dense(0.0, 2.0), + Vectors.sparse(0, Array.empty, Array.empty), + Vectors.sparse(2, Array(1), Array(2.0))) + for (value <- values) { + val json = param.jsonEncode(value) + assert(param.jsonDecode(json) === value) + } + } + { // IntArrayParam val param = new IntArrayParam(dummy, "name", "doc") val values: Seq[Array[Int]] = Seq( @@ -138,7 +152,7 @@ class ParamsSuite extends SparkFunSuite { test("param") { val solver = new TestParams() val uid = solver.uid - import solver.{maxIter, inputCol} + import solver.{inputCol, maxIter} assert(maxIter.name === "maxIter") assert(maxIter.doc === "maximum number of iterations (>= 0)") @@ -181,7 +195,7 @@ class ParamsSuite extends SparkFunSuite { test("param map") { val solver = new TestParams() - import solver.{maxIter, inputCol} + import solver.{inputCol, maxIter} val map0 = ParamMap.empty @@ -220,7 +234,7 @@ class ParamsSuite extends SparkFunSuite { test("params") { val solver = new TestParams() - import solver.{handleInvalid, maxIter, inputCol} + import solver.{handleInvalid, inputCol, maxIter} val params = solver.params assert(params.length === 3)