[SPARK-7404] [ML] Add RegressionEvaluator to spark.ml
Author: Ram Sriharsha <rsriharsha@hw11853.local> Closes #6344 from harsha2010/SPARK-7404 and squashes the following commits: 16b9d77 [Ram Sriharsha] consistent naming 7f100b6 [Ram Sriharsha] cleanup c46044d [Ram Sriharsha] Merge with Master + Code Review Fixes 188fa0a [Ram Sriharsha] Merge branch 'master' into SPARK-7404 f5b6a4c [Ram Sriharsha] cleanup doc 97beca5 [Ram Sriharsha] update test to use R packages 32dd310 [Ram Sriharsha] fix indentation f93b812 [Ram Sriharsha] fix test 1b6ebb3 [Ram Sriharsha] [SPARK-7404][ml] Add RegressionEvaluator to spark.ml
This commit is contained in:
parent
3b68cb0430
commit
f490b3b4c7
|
@ -0,0 +1,84 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.evaluation
|
||||
|
||||
import org.apache.spark.annotation.AlphaComponent
|
||||
import org.apache.spark.ml.param.{Param, ParamValidators}
|
||||
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
|
||||
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
|
||||
import org.apache.spark.mllib.evaluation.RegressionMetrics
|
||||
import org.apache.spark.sql.{DataFrame, Row}
|
||||
import org.apache.spark.sql.types.DoubleType
|
||||
|
||||
/**
 * :: AlphaComponent ::
 *
 * Evaluator for regression. It reads two columns of the input DataFrame, the prediction
 * column and the label column, and reduces them to a single metric value.
 *
 * The metric is chosen through [[metricName]]; the accepted names are "mse", "rmse", "r2"
 * and "mae", and "rmse" is used when nothing is set.
 */
@AlphaComponent
class RegressionEvaluator(override val uid: String)
  extends Evaluator with HasPredictionCol with HasLabelCol {

  /** Builds an evaluator whose uid comes from `Identifiable.randomUID("regEval")`. */
  def this() = this(Identifiable.randomUID("regEval"))

  /**
   * Name of the metric returned by [[evaluate]]; restricted to mse, rmse, r2 or mae.
   * @group param
   */
  val metricName: Param[String] = {
    // The validator only accepts the names that evaluate() knows how to dispatch on.
    val isSupported: String => Boolean =
      ParamValidators.inArray(Array("mse", "rmse", "r2", "mae"))
    new Param[String](
      this, "metricName", "metric name in evaluation (mse|rmse|r2|mae)", isSupported)
  }

  // Must stay below the metricName definition so the param exists when the default is set.
  setDefault(metricName -> "rmse")

  /** @group getParam */
  def getMetricName: String = $(metricName)

  /** @group setParam */
  def setMetricName(value: String): this.type = set(metricName, value)

  /** @group setParam */
  def setLabelCol(value: String): this.type = set(labelCol, value)

  /** @group setParam */
  def setPredictionCol(value: String): this.type = set(predictionCol, value)

  /**
   * Computes the configured regression metric over `dataset`.
   *
   * Both the prediction column and the label column are passed through
   * `SchemaUtils.checkColumnType` with `DoubleType` before any data is read
   * (NOTE(review): presumably this throws on a type mismatch; confirm in SchemaUtils).
   *
   * @param dataset DataFrame holding the prediction and label columns
   * @return the value of the metric selected by [[metricName]]
   */
  override def evaluate(dataset: DataFrame): Double = {
    val inputSchema = dataset.schema
    SchemaUtils.checkColumnType(inputSchema, $(predictionCol), DoubleType)
    SchemaUtils.checkColumnType(inputSchema, $(labelCol), DoubleType)

    // Pair order is (prediction, label), which is what RegressionMetrics is handed below.
    val selected = dataset.select($(predictionCol), $(labelCol))
    val predictionAndLabels = selected.map {
      case Row(predicted: Double, actual: Double) => predicted -> actual
    }
    val regressionMetrics = new RegressionMetrics(predictionAndLabels)

    $(metricName) match {
      case "mse" => regressionMetrics.meanSquaredError
      case "rmse" => regressionMetrics.rootMeanSquaredError
      case "mae" => regressionMetrics.meanAbsoluteError
      case "r2" => regressionMetrics.r2
    }
  }
}
|
|
@ -0,0 +1,71 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.evaluation
|
||||
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
import org.apache.spark.ml.regression.LinearRegression
|
||||
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
|
||||
import org.apache.spark.mllib.util.TestingUtils._
|
||||
|
||||
/**
 * Exercises [[RegressionEvaluator]] on predictions of a [[LinearRegression]] model fitted to
 * generated linear data. The expected values are the ones recorded by the original author
 * from the R session described inside the test.
 */
class RegressionEvaluatorSuite extends FunSuite with MLlibTestSparkContext {

  test("Regression Evaluator: default params") {
    /*
     * How to export the test data as CSV so the metrics can be cross-checked with the
     * mmetric function in R:
     *
     * import org.apache.spark.mllib.util.LinearDataGenerator
     * val data = sc.parallelize(LinearDataGenerator.generateLinearInput(6.3,
     *   Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1))
     * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1))
     *   .saveAsTextFile("path")
     */
    val samples = LinearDataGenerator.generateLinearInput(
      6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1)
    val dataset = sqlContext.createDataFrame(sc.parallelize(samples, 2))

    /*
     * R code used to load the exported data, train the model and evaluate the metrics:
     *
     * > library("glmnet")
     * > library("rminer")
     * > data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
     * > features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3)))
     * > label <- as.numeric(data$V1)
     * > model <- glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)
     * > rmse <- mmetric(label, predict(model, features), metric='RMSE')
     * > mae <- mmetric(label, predict(model, features), metric='MAE')
     * > r2 <- mmetric(label, predict(model, features), metric='R2')
     */
    val fitted = new LinearRegression().fit(dataset)
    val predictions = fitted.transform(dataset)

    // No metric name is set here, so this checks the default metric (rmse).
    val evaluator = new RegressionEvaluator()
    val rmse = evaluator.evaluate(predictions)
    assert(rmse ~== 0.1019382 absTol 0.001)

    // Switch the same evaluator instance to the r2 score.
    evaluator.setMetricName("r2")
    val r2 = evaluator.evaluate(predictions)
    assert(r2 ~== 0.9998196 absTol 0.001)

    // And finally to the mean absolute error.
    evaluator.setMetricName("mae")
    val mae = evaluator.evaluate(predictions)
    assert(mae ~== 0.08036075 absTol 0.001)
  }
}
|
Loading…
Reference in a new issue