[SPARK-17848][ML] Move LabelCol datatype cast into Predictor.fit

## What changes were proposed in this pull request?

1. Move the label-column cast to `DoubleType` into `Predictor.fit`, preserving the column's metadata.
2. Remove the now-unnecessary casts from the individual learners (see the usage sketch below).
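As a usage-level sketch (mine, not part of the patch; the estimator and data are arbitrary examples), the contract after this change is that `fit()` accepts any `NumericType` label column and casts it to `DoubleType` itself, so callers do not need an explicit cast:

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("label-cast-sketch").getOrCreate()

// The label column is IntegerType here; Predictor.fit casts it to DoubleType before train().
val training = spark.createDataFrame(Seq(
  (1, Vectors.dense(0.0, 1.1, 0.1)),
  (0, Vectors.dense(2.0, 1.0, -1.0)),
  (1, Vectors.dense(0.0, 1.2, -0.5))
)).toDF("label", "features")

val model = new LogisticRegression().setMaxIter(5).fit(training)
```
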
## How was this patch tested?

Existing tests, plus a new `PredictorSuite` that checks all `NumericType` label columns are accepted and non-numeric label columns are rejected.

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #15414 from zhengruifeng/move_cast.
Authored by Zheng RuiFeng on 2016-11-01 10:46:36 -07:00; committed by Joseph K. Bradley
parent 0cba535af3
commit 8ac09108fc
9 changed files with 98 additions and 11 deletions

mllib/src/main/scala/org/apache/spark/ml/Predictor.scala

@@ -58,7 +58,8 @@ private[ml] trait PredictorParams extends Params
/**
* :: DeveloperApi ::
- * Abstraction for prediction problems (regression and classification).
+ * Abstraction for prediction problems (regression and classification). It accepts all NumericType
+ * labels and will automatically cast it to DoubleType in [[fit()]].
*
* @tparam FeaturesType Type of features.
* E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
@@ -87,7 +88,12 @@ abstract class Predictor[
// This handles a few items such as schema validation.
// Developers only need to implement train().
transformSchema(dataset.schema, logging = true)
- copyValues(train(dataset).setParent(this))
+ // Cast LabelCol to DoubleType and keep the metadata.
+ val labelMeta = dataset.schema($(labelCol)).metadata
+ val casted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta)
+ copyValues(train(casted).setParent(this))
}
override def copy(extra: ParamMap): Learner
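Keeping `labelMeta` matters because classifiers read information such as the number of classes from the label column's attribute metadata when it is present. As a reference-only sketch (the helper name is mine, not part of the patch), a metadata-preserving cast can also be written against the public DataFrame API:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType

// Hypothetical helper: cast the label column to DoubleType and re-attach its original metadata.
def castLabelKeepingMetadata(df: DataFrame, labelCol: String): DataFrame = {
  val labelMeta = df.schema(labelCol).metadata
  val cols = df.columns.map {
    case c if c == labelCol => col(c).cast(DoubleType).as(c, labelMeta)
    case c => col(c)
  }
  df.select(cols: _*)
}
```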
@@ -121,7 +127,7 @@ abstract class Predictor[
* and put it in an RDD with strong types.
*/
protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = {
- dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+ dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
case Row(label: Double, features: Vector) => LabeledPoint(label, features)
}
}
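A small illustrative snippet (mine, not from the patch) of why the explicit `.cast(DoubleType)` can be dropped in `extractLabeledPoints`: by the time it runs, `fit()` has already cast the label column, and the `Row(label: Double, ...)` pattern above only matches values that are actually `Double`:

```scala
import org.apache.spark.sql.Row

val rows = Seq(Row(1.0, "double label"), Row(1, "integer label"))
rows.foreach {
  case Row(label: Double, note: String) => println(s"matched: $label ($note)")
  case other => println(s"not matched without a prior cast: $other")
}
// matched: 1.0 (double label)
// not matched without a prior cast: [1,integer label]
```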

mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala

@@ -71,7 +71,7 @@ abstract class Classifier[
* and put it in an RDD with strong types.
*
* @param dataset DataFrame with columns for labels ([[org.apache.spark.sql.types.NumericType]])
- * and features ([[Vector]]). Labels are cast to [[DoubleType]].
+ * and features ([[Vector]]).
* @param numClasses Number of classes label can take. Labels must be integers in the range
* [0, numClasses).
* @throws SparkException if any label is not an integer >= 0
@@ -79,7 +79,7 @@ abstract class Classifier[
protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
s" $numClasses, but requires numClasses > 0.")
- dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+ dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
case Row(label: Double, features: Vector) =>
require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
s" dataset with invalid label $label. Labels must be integers in range" +

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala

@@ -128,7 +128,7 @@ class GBTClassifier @Since("1.4.0") (
// We copy and modify this from Classifier.extractLabeledPoints since GBT only supports
// 2 classes now. This lets us provide a more precise error message.
val oldDataset: RDD[LabeledPoint] =
- dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+ dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
case Row(label: Double, features: Vector) =>
require(label == 0 || label == 1, s"GBTClassifier was given" +
s" dataset with invalid label $label. Labels must be in {0,1}; note that" +

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

@@ -322,7 +322,7 @@ class LogisticRegression @Since("1.2.0") (
LogisticRegressionModel = {
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] =
- dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+ dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
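As an aside, the `lit(1.0)` idiom above (repeated in `GeneralizedLinearRegression` and `LinearRegression` further down) assigns every row a default weight of 1.0 when no weight column is configured. A standalone sketch of the same idea (the helper is hypothetical, not part of the patch):

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{col, lit}

// Use the configured weight column if one is set and non-empty, otherwise a constant weight of 1.0.
def weightColumn(weightCol: Option[String]): Column =
  weightCol.filter(_.nonEmpty).map(col).getOrElse(lit(1.0))
```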

mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala

@@ -171,7 +171,7 @@ class NaiveBayes @Since("1.5.0") (
// Aggregates term frequencies per label.
// TODO: Calling aggregateByKey and collect creates two stages, we can implement something
// TODO: similar to reduceByKeyLocally to save one stage.
- val aggregated = dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd
+ val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
.map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2)))
}.aggregateByKey[(Double, DenseVector)]((0.0, Vectors.zeros(numFeatures).toDense))(
seqOp = {
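A simplified sketch (mine; plain arrays instead of the `Vector`/BLAS types used in the real code) of the aggregation this hunk belongs to: per label, `aggregateByKey` accumulates the total weight and the weighted feature sums, and the result is collected to the driver — the aggregateByKey-plus-collect combination behind the two-stage note in the TODO:

```scala
import org.apache.spark.rdd.RDD

// (label, (weight, features)) -> per-label (weightSum, featureSum)
def sumByLabel(
    data: RDD[(Double, (Double, Array[Double]))],
    numFeatures: Int): Array[(Double, (Double, Array[Double]))] = {
  data.aggregateByKey((0.0, Array.fill(numFeatures)(0.0)))(
    seqOp = { case ((weightSum, featureSum), (weight, features)) =>
      var i = 0
      while (i < numFeatures) { featureSum(i) += weight * features(i); i += 1 }
      (weightSum + weight, featureSum)
    },
    combOp = { case ((w1, f1), (w2, f2)) =>
      var i = 0
      while (i < numFeatures) { f1(i) += f2(i); i += 1 }
      (w1 + w2, f1)
    }
  ).collect()
}
```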

mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala

@@ -255,7 +255,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] =
- dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+ dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

@@ -190,7 +190,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] = dataset.select(
- col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+ col($(labelCol)), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}

mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala (new file)

@@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext {

  import PredictorSuite._

  test("should support all NumericType labels and not support other types") {
    val df = spark.createDataFrame(Seq(
      (0, Vectors.dense(0, 2, 3)),
      (1, Vectors.dense(0, 3, 9)),
      (0, Vectors.dense(0, 2, 6))
    )).toDF("label", "features")

    val types =
      Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))

    val predictor = new MockPredictor()

    types.foreach { t =>
      predictor.fit(df.select(col("label").cast(t), col("features")))
    }

    intercept[IllegalArgumentException] {
      predictor.fit(df.select(col("label").cast(StringType), col("features")))
    }
  }
}

object PredictorSuite {

  class MockPredictor(override val uid: String)
    extends Predictor[Vector, MockPredictor, MockPredictionModel] {

    def this() = this(Identifiable.randomUID("mockpredictor"))

    override def train(dataset: Dataset[_]): MockPredictionModel = {
      require(dataset.schema("label").dataType == DoubleType)
      new MockPredictionModel(uid)
    }

    override def copy(extra: ParamMap): MockPredictor =
      throw new NotImplementedError()
  }

  class MockPredictionModel(override val uid: String)
    extends PredictionModel[Vector, MockPredictionModel] {

    def this() = this(Identifiable.randomUID("mockpredictormodel"))

    override def predict(features: Vector): Double =
      throw new NotImplementedError()

    override def copy(extra: ParamMap): MockPredictionModel =
      throw new NotImplementedError()
  }
}

mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

@@ -1807,7 +1807,6 @@ class LogisticRegressionSuite
.objectiveHistory
.sliding(2)
.forall(x => x(0) >= x(1)))
}
test("binary logistic regression with weighted data") {