[SPARK-17848][ML] Move LabelCol datatype cast into Predictor.fit
## What changes were proposed in this pull request? 1, move cast to `Predictor` 2, and then, remove unnecessary cast ## How was this patch tested? existing tests Author: Zheng RuiFeng <ruifengz@foxmail.com> Closes #15414 from zhengruifeng/move_cast.
This commit is contained in:
parent
0cba535af3
commit
8ac09108fc
|
@ -58,7 +58,8 @@ private[ml] trait PredictorParams extends Params
|
|||
|
||||
/**
|
||||
* :: DeveloperApi ::
|
||||
* Abstraction for prediction problems (regression and classification).
|
||||
* Abstraction for prediction problems (regression and classification). It accepts all NumericType
|
||||
* labels and will automatically cast it to DoubleType in [[fit()]].
|
||||
*
|
||||
* @tparam FeaturesType Type of features.
|
||||
* E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
|
||||
|
@ -87,7 +88,12 @@ abstract class Predictor[
|
|||
// This handles a few items such as schema validation.
|
||||
// Developers only need to implement train().
|
||||
transformSchema(dataset.schema, logging = true)
|
||||
copyValues(train(dataset).setParent(this))
|
||||
|
||||
// Cast LabelCol to DoubleType and keep the metadata.
|
||||
val labelMeta = dataset.schema($(labelCol)).metadata
|
||||
val casted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta)
|
||||
|
||||
copyValues(train(casted).setParent(this))
|
||||
}
|
||||
|
||||
override def copy(extra: ParamMap): Learner
|
||||
|
@ -121,7 +127,7 @@ abstract class Predictor[
|
|||
* and put it in an RDD with strong types.
|
||||
*/
|
||||
protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = {
|
||||
dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
|
||||
dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
|
||||
case Row(label: Double, features: Vector) => LabeledPoint(label, features)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -71,7 +71,7 @@ abstract class Classifier[
|
|||
* and put it in an RDD with strong types.
|
||||
*
|
||||
* @param dataset DataFrame with columns for labels ([[org.apache.spark.sql.types.NumericType]])
|
||||
* and features ([[Vector]]). Labels are cast to [[DoubleType]].
|
||||
* and features ([[Vector]]).
|
||||
* @param numClasses Number of classes label can take. Labels must be integers in the range
|
||||
* [0, numClasses).
|
||||
* @throws SparkException if any label is not an integer >= 0
|
||||
|
@ -79,7 +79,7 @@ abstract class Classifier[
|
|||
protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
|
||||
require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
|
||||
s" $numClasses, but requires numClasses > 0.")
|
||||
dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
|
||||
dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
|
||||
case Row(label: Double, features: Vector) =>
|
||||
require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
|
||||
s" dataset with invalid label $label. Labels must be integers in range" +
|
||||
|
|
|
@ -128,7 +128,7 @@ class GBTClassifier @Since("1.4.0") (
|
|||
// We copy and modify this from Classifier.extractLabeledPoints since GBT only supports
|
||||
// 2 classes now. This lets us provide a more precise error message.
|
||||
val oldDataset: RDD[LabeledPoint] =
|
||||
dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
|
||||
dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
|
||||
case Row(label: Double, features: Vector) =>
|
||||
require(label == 0 || label == 1, s"GBTClassifier was given" +
|
||||
s" dataset with invalid label $label. Labels must be in {0,1}; note that" +
|
||||
|
|
|
@ -322,7 +322,7 @@ class LogisticRegression @Since("1.2.0") (
|
|||
LogisticRegressionModel = {
|
||||
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
|
||||
val instances: RDD[Instance] =
|
||||
dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
|
||||
dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
|
||||
case Row(label: Double, weight: Double, features: Vector) =>
|
||||
Instance(label, weight, features)
|
||||
}
|
||||
|
|
|
@ -171,7 +171,7 @@ class NaiveBayes @Since("1.5.0") (
|
|||
// Aggregates term frequencies per label.
|
||||
// TODO: Calling aggregateByKey and collect creates two stages, we can implement something
|
||||
// TODO: similar to reduceByKeyLocally to save one stage.
|
||||
val aggregated = dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd
|
||||
val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
|
||||
.map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2)))
|
||||
}.aggregateByKey[(Double, DenseVector)]((0.0, Vectors.zeros(numFeatures).toDense))(
|
||||
seqOp = {
|
||||
|
|
|
@ -255,7 +255,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
|
|||
|
||||
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
|
||||
val instances: RDD[Instance] =
|
||||
dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
|
||||
dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
|
||||
case Row(label: Double, weight: Double, features: Vector) =>
|
||||
Instance(label, weight, features)
|
||||
}
|
||||
|
|
|
@ -190,7 +190,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
|
|||
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
|
||||
|
||||
val instances: RDD[Instance] = dataset.select(
|
||||
col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
|
||||
col($(labelCol)), w, col($(featuresCol))).rdd.map {
|
||||
case Row(label: Double, weight: Double, features: Vector) =>
|
||||
Instance(label, weight, features)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,82 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.linalg._
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.ml.util._
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.sql.Dataset
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||
|
||||
import PredictorSuite._
|
||||
|
||||
test("should support all NumericType labels and not support other types") {
|
||||
val df = spark.createDataFrame(Seq(
|
||||
(0, Vectors.dense(0, 2, 3)),
|
||||
(1, Vectors.dense(0, 3, 9)),
|
||||
(0, Vectors.dense(0, 2, 6))
|
||||
)).toDF("label", "features")
|
||||
|
||||
val types =
|
||||
Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
|
||||
|
||||
val predictor = new MockPredictor()
|
||||
|
||||
types.foreach { t =>
|
||||
predictor.fit(df.select(col("label").cast(t), col("features")))
|
||||
}
|
||||
|
||||
intercept[IllegalArgumentException] {
|
||||
predictor.fit(df.select(col("label").cast(StringType), col("features")))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
object PredictorSuite {
|
||||
|
||||
class MockPredictor(override val uid: String)
|
||||
extends Predictor[Vector, MockPredictor, MockPredictionModel] {
|
||||
|
||||
def this() = this(Identifiable.randomUID("mockpredictor"))
|
||||
|
||||
override def train(dataset: Dataset[_]): MockPredictionModel = {
|
||||
require(dataset.schema("label").dataType == DoubleType)
|
||||
new MockPredictionModel(uid)
|
||||
}
|
||||
|
||||
override def copy(extra: ParamMap): MockPredictor =
|
||||
throw new NotImplementedError()
|
||||
}
|
||||
|
||||
class MockPredictionModel(override val uid: String)
|
||||
extends PredictionModel[Vector, MockPredictionModel] {
|
||||
|
||||
def this() = this(Identifiable.randomUID("mockpredictormodel"))
|
||||
|
||||
override def predict(features: Vector): Double =
|
||||
throw new NotImplementedError()
|
||||
|
||||
override def copy(extra: ParamMap): MockPredictionModel =
|
||||
throw new NotImplementedError()
|
||||
}
|
||||
}
|
|
@ -1807,7 +1807,6 @@ class LogisticRegressionSuite
|
|||
.objectiveHistory
|
||||
.sliding(2)
|
||||
.forall(x => x(0) >= x(1)))
|
||||
|
||||
}
|
||||
|
||||
test("binary logistic regression with weighted data") {
|
||||
|
|
Loading…
Reference in a new issue