[SPARK-23377][ML] Fixes Bucketizer with multiple columns persistence bug
## What changes were proposed in this pull request? #### Problem: Since 2.3, `Bucketizer` supports multiple input/output columns. We will check if exclusive params are set during transformation. E.g., if `inputCols` and `outputCol` are both set, an error will be thrown. However, when we write `Bucketizer`, looks like the default params and user-supplied params are merged during writing. All saved params are loaded back and set to created model instance. So the default `outputCol` param in `HasOutputCol` trait will be set in `paramMap` and become an user-supplied param. That makes the check of exclusive params failed. #### Fix: This changes the saving logic of Bucketizer to handle this case. This is a quick fix to catch the time of 2.3. We should consider modify the persistence mechanism later. Please see the discussion in the JIRA. Note: The multi-column `QuantileDiscretizer` also has the same issue. ## How was this patch tested? Modified tests. Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #20594 from viirya/SPARK-23377-2.
This commit is contained in:
parent
6968c3cfd7
commit
db45daab90
|
@ -19,6 +19,10 @@ package org.apache.spark.ml.feature
|
|||
|
||||
import java.{util => ju}
|
||||
|
||||
import org.json4s.JsonDSL._
|
||||
import org.json4s.JValue
|
||||
import org.json4s.jackson.JsonMethods._
|
||||
|
||||
import org.apache.spark.SparkException
|
||||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.ml.Model
|
||||
|
@ -213,6 +217,8 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
|
|||
override def copy(extra: ParamMap): Bucketizer = {
|
||||
defaultCopy[Bucketizer](extra).setParent(parent)
|
||||
}
|
||||
|
||||
override def write: MLWriter = new Bucketizer.BucketizerWriter(this)
|
||||
}
|
||||
|
||||
@Since("1.6.0")
|
||||
|
@ -290,6 +296,28 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
private[Bucketizer] class BucketizerWriter(instance: Bucketizer) extends MLWriter {
|
||||
|
||||
override protected def saveImpl(path: String): Unit = {
|
||||
// SPARK-23377: The default params will be saved and loaded as user-supplied params.
|
||||
// Once `inputCols` is set, the default value of `outputCol` param causes the error
|
||||
// when checking exclusive params. As a temporary to fix it, we skip the default value
|
||||
// of `outputCol` if `inputCols` is set when saving the metadata.
|
||||
// TODO: If we modify the persistence mechanism later to better handle default params,
|
||||
// we can get rid of this.
|
||||
var paramWithoutOutputCol: Option[JValue] = None
|
||||
if (instance.isSet(instance.inputCols)) {
|
||||
val params = instance.extractParamMap().toSeq
|
||||
val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) =>
|
||||
p.name -> parse(p.jsonEncode(v))
|
||||
}.toList
|
||||
paramWithoutOutputCol = Some(render(jsonParams))
|
||||
}
|
||||
DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol)
|
||||
}
|
||||
}
|
||||
|
||||
@Since("1.6.0")
|
||||
override def load(path: String): Bucketizer = super.load(path)
|
||||
}
|
||||
|
|
|
@ -17,6 +17,10 @@
|
|||
|
||||
package org.apache.spark.ml.feature
|
||||
|
||||
import org.json4s.JsonDSL._
|
||||
import org.json4s.JValue
|
||||
import org.json4s.jackson.JsonMethods._
|
||||
|
||||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.ml._
|
||||
|
@ -249,11 +253,35 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
|
|||
|
||||
@Since("1.6.0")
|
||||
override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra)
|
||||
|
||||
override def write: MLWriter = new QuantileDiscretizer.QuantileDiscretizerWriter(this)
|
||||
}
|
||||
|
||||
@Since("1.6.0")
|
||||
object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging {
|
||||
|
||||
private[QuantileDiscretizer]
|
||||
class QuantileDiscretizerWriter(instance: QuantileDiscretizer) extends MLWriter {
|
||||
|
||||
override protected def saveImpl(path: String): Unit = {
|
||||
// SPARK-23377: The default params will be saved and loaded as user-supplied params.
|
||||
// Once `inputCols` is set, the default value of `outputCol` param causes the error
|
||||
// when checking exclusive params. As a temporary to fix it, we skip the default value
|
||||
// of `outputCol` if `inputCols` is set when saving the metadata.
|
||||
// TODO: If we modify the persistence mechanism later to better handle default params,
|
||||
// we can get rid of this.
|
||||
var paramWithoutOutputCol: Option[JValue] = None
|
||||
if (instance.isSet(instance.inputCols)) {
|
||||
val params = instance.extractParamMap().toSeq
|
||||
val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) =>
|
||||
p.name -> parse(p.jsonEncode(v))
|
||||
}.toList
|
||||
paramWithoutOutputCol = Some(render(jsonParams))
|
||||
}
|
||||
DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol)
|
||||
}
|
||||
}
|
||||
|
||||
@Since("1.6.0")
|
||||
override def load(path: String): QuantileDiscretizer = super.load(path)
|
||||
}
|
||||
|
|
|
@ -172,7 +172,10 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
|
|||
.setInputCol("myInputCol")
|
||||
.setOutputCol("myOutputCol")
|
||||
.setSplits(Array(0.1, 0.8, 0.9))
|
||||
testDefaultReadWrite(t)
|
||||
|
||||
val bucketizer = testDefaultReadWrite(t)
|
||||
val data = Seq((1.0, 2.0), (10.0, 100.0), (101.0, -1.0)).toDF("myInputCol", "myInputCol2")
|
||||
bucketizer.transform(data)
|
||||
}
|
||||
|
||||
test("Bucket numeric features") {
|
||||
|
@ -327,7 +330,12 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
|
|||
.setInputCols(Array("myInputCol"))
|
||||
.setOutputCols(Array("myOutputCol"))
|
||||
.setSplitsArray(Array(Array(0.1, 0.8, 0.9)))
|
||||
testDefaultReadWrite(t)
|
||||
|
||||
val bucketizer = testDefaultReadWrite(t)
|
||||
val data = Seq((1.0, 2.0), (10.0, 100.0), (101.0, -1.0)).toDF("myInputCol", "myInputCol2")
|
||||
bucketizer.transform(data)
|
||||
assert(t.hasDefault(t.outputCol))
|
||||
assert(bucketizer.hasDefault(bucketizer.outputCol))
|
||||
}
|
||||
|
||||
test("Bucketizer in a pipeline") {
|
||||
|
|
|
@ -27,6 +27,8 @@ import org.apache.spark.sql.functions.udf
|
|||
class QuantileDiscretizerSuite
|
||||
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
|
||||
|
||||
import testImplicits._
|
||||
|
||||
test("Test observed number of buckets and their sizes match expected values") {
|
||||
val spark = this.spark
|
||||
import spark.implicits._
|
||||
|
@ -132,7 +134,10 @@ class QuantileDiscretizerSuite
|
|||
.setInputCol("myInputCol")
|
||||
.setOutputCol("myOutputCol")
|
||||
.setNumBuckets(6)
|
||||
testDefaultReadWrite(t)
|
||||
|
||||
val readDiscretizer = testDefaultReadWrite(t)
|
||||
val data = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("myInputCol")
|
||||
readDiscretizer.fit(data)
|
||||
}
|
||||
|
||||
test("Verify resulting model has parent") {
|
||||
|
@ -379,7 +384,12 @@ class QuantileDiscretizerSuite
|
|||
.setInputCols(Array("input1", "input2"))
|
||||
.setOutputCols(Array("result1", "result2"))
|
||||
.setNumBucketsArray(Array(5, 10))
|
||||
testDefaultReadWrite(discretizer)
|
||||
|
||||
val readDiscretizer = testDefaultReadWrite(discretizer)
|
||||
val data = Seq((1.0, 2.0), (2.0, 3.0), (3.0, 4.0)).toDF("input1", "input2")
|
||||
readDiscretizer.fit(data)
|
||||
assert(discretizer.hasDefault(discretizer.outputCol))
|
||||
assert(readDiscretizer.hasDefault(readDiscretizer.outputCol))
|
||||
}
|
||||
|
||||
test("Multiple Columns: Both inputCol and inputCols are set") {
|
||||
|
|
Loading…
Reference in a new issue