[SPARK-18401][SPARKR][ML] SparkR random forest should support outputting the original label.
## What changes were proposed in this pull request? SparkR ```spark.randomForest``` classification prediction should output the original label rather than the indexed label. This issue is very similar to [SPARK-18291](https://issues.apache.org/jira/browse/SPARK-18291). ## How was this patch tested? Added unit tests. Author: Yanbo Liang <ybliang8@gmail.com> Closes #15842 from yanboliang/spark-18401.
This commit is contained in:
parent
a3356343cb
commit
5ddf69470b
|
@ -935,6 +935,10 @@ test_that("spark.randomForest Classification", {
|
||||||
expect_equal(stats$numTrees, 20)
|
expect_equal(stats$numTrees, 20)
|
||||||
expect_error(capture.output(stats), NA)
|
expect_error(capture.output(stats), NA)
|
||||||
expect_true(length(capture.output(stats)) > 6)
|
expect_true(length(capture.output(stats)) > 6)
|
||||||
|
# Test string prediction values
|
||||||
|
predictions <- collect(predict(model, data))$prediction
|
||||||
|
expect_equal(length(grep("setosa", predictions)), 50)
|
||||||
|
expect_equal(length(grep("versicolor", predictions)), 50)
|
||||||
|
|
||||||
modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp")
|
modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp")
|
||||||
write.ml(model, modelPath)
|
write.ml(model, modelPath)
|
||||||
|
@ -947,6 +951,26 @@ test_that("spark.randomForest Classification", {
|
||||||
expect_equal(stats$numClasses, stats2$numClasses)
|
expect_equal(stats$numClasses, stats2$numClasses)
|
||||||
|
|
||||||
unlink(modelPath)
|
unlink(modelPath)
|
||||||
|
|
||||||
|
# Test numeric response variable
|
||||||
|
labelToIndex <- function(species) {
|
||||||
|
switch(as.character(species),
|
||||||
|
setosa = 0.0,
|
||||||
|
versicolor = 1.0,
|
||||||
|
virginica = 2.0
|
||||||
|
)
|
||||||
|
}
|
||||||
|
iris$NumericSpecies <- lapply(iris$Species, labelToIndex)
|
||||||
|
data <- suppressWarnings(createDataFrame(iris[-5]))
|
||||||
|
model <- spark.randomForest(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification",
|
||||||
|
maxDepth = 5, maxBins = 16)
|
||||||
|
stats <- summary(model)
|
||||||
|
expect_equal(stats$numFeatures, 2)
|
||||||
|
expect_equal(stats$numTrees, 20)
|
||||||
|
# Test numeric prediction values
|
||||||
|
predictions <- collect(predict(model, data))$prediction
|
||||||
|
expect_equal(length(grep("1.0", predictions)), 50)
|
||||||
|
expect_equal(length(grep("2.0", predictions)), 50)
|
||||||
})
|
})
|
||||||
|
|
||||||
test_that("spark.gbt", {
|
test_that("spark.gbt", {
|
||||||
|
|
|
@ -23,9 +23,9 @@ import org.json4s.JsonDSL._
|
||||||
import org.json4s.jackson.JsonMethods._
|
import org.json4s.jackson.JsonMethods._
|
||||||
|
|
||||||
import org.apache.spark.ml.{Pipeline, PipelineModel}
|
import org.apache.spark.ml.{Pipeline, PipelineModel}
|
||||||
import org.apache.spark.ml.attribute.AttributeGroup
|
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
|
||||||
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
|
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
|
||||||
import org.apache.spark.ml.feature.RFormula
|
import org.apache.spark.ml.feature.{IndexToString, RFormula}
|
||||||
import org.apache.spark.ml.linalg.Vector
|
import org.apache.spark.ml.linalg.Vector
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.sql.{DataFrame, Dataset}
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
|
@ -35,6 +35,8 @@ private[r] class RandomForestClassifierWrapper private (
|
||||||
val formula: String,
|
val formula: String,
|
||||||
val features: Array[String]) extends MLWritable {
|
val features: Array[String]) extends MLWritable {
|
||||||
|
|
||||||
|
import RandomForestClassifierWrapper._
|
||||||
|
|
||||||
private val rfcModel: RandomForestClassificationModel =
|
private val rfcModel: RandomForestClassificationModel =
|
||||||
pipeline.stages(1).asInstanceOf[RandomForestClassificationModel]
|
pipeline.stages(1).asInstanceOf[RandomForestClassificationModel]
|
||||||
|
|
||||||
|
@ -46,7 +48,9 @@ private[r] class RandomForestClassifierWrapper private (
|
||||||
def summary: String = rfcModel.toDebugString
|
def summary: String = rfcModel.toDebugString
|
||||||
|
|
||||||
def transform(dataset: Dataset[_]): DataFrame = {
|
def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
pipeline.transform(dataset).drop(rfcModel.getFeaturesCol)
|
pipeline.transform(dataset)
|
||||||
|
.drop(PREDICTED_LABEL_INDEX_COL)
|
||||||
|
.drop(rfcModel.getFeaturesCol)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def write: MLWriter = new
|
override def write: MLWriter = new
|
||||||
|
@ -54,6 +58,10 @@ private[r] class RandomForestClassifierWrapper private (
|
||||||
}
|
}
|
||||||
|
|
||||||
private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestClassifierWrapper] {
|
private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestClassifierWrapper] {
|
||||||
|
|
||||||
|
val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
|
||||||
|
val PREDICTED_LABEL_COL = "prediction"
|
||||||
|
|
||||||
def fit( // scalastyle:ignore
|
def fit( // scalastyle:ignore
|
||||||
data: DataFrame,
|
data: DataFrame,
|
||||||
formula: String,
|
formula: String,
|
||||||
|
@ -73,6 +81,7 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
|
||||||
|
|
||||||
val rFormula = new RFormula()
|
val rFormula = new RFormula()
|
||||||
.setFormula(formula)
|
.setFormula(formula)
|
||||||
|
.setForceIndexLabel(true)
|
||||||
RWrapperUtils.checkDataColumns(rFormula, data)
|
RWrapperUtils.checkDataColumns(rFormula, data)
|
||||||
val rFormulaModel = rFormula.fit(data)
|
val rFormulaModel = rFormula.fit(data)
|
||||||
|
|
||||||
|
@ -82,6 +91,11 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
|
||||||
.attributes.get
|
.attributes.get
|
||||||
val features = featureAttrs.map(_.name.get)
|
val features = featureAttrs.map(_.name.get)
|
||||||
|
|
||||||
|
// get label names from output schema
|
||||||
|
val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol))
|
||||||
|
.asInstanceOf[NominalAttribute]
|
||||||
|
val labels = labelAttr.values.get
|
||||||
|
|
||||||
// assemble and fit the pipeline
|
// assemble and fit the pipeline
|
||||||
val rfc = new RandomForestClassifier()
|
val rfc = new RandomForestClassifier()
|
||||||
.setMaxDepth(maxDepth)
|
.setMaxDepth(maxDepth)
|
||||||
|
@ -97,10 +111,16 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
|
||||||
.setCacheNodeIds(cacheNodeIds)
|
.setCacheNodeIds(cacheNodeIds)
|
||||||
.setProbabilityCol(probabilityCol)
|
.setProbabilityCol(probabilityCol)
|
||||||
.setFeaturesCol(rFormula.getFeaturesCol)
|
.setFeaturesCol(rFormula.getFeaturesCol)
|
||||||
|
.setPredictionCol(PREDICTED_LABEL_INDEX_COL)
|
||||||
if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong)
|
if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong)
|
||||||
|
|
||||||
|
val idxToStr = new IndexToString()
|
||||||
|
.setInputCol(PREDICTED_LABEL_INDEX_COL)
|
||||||
|
.setOutputCol(PREDICTED_LABEL_COL)
|
||||||
|
.setLabels(labels)
|
||||||
|
|
||||||
val pipeline = new Pipeline()
|
val pipeline = new Pipeline()
|
||||||
.setStages(Array(rFormulaModel, rfc))
|
.setStages(Array(rFormulaModel, rfc, idxToStr))
|
||||||
.fit(data)
|
.fit(data)
|
||||||
|
|
||||||
new RandomForestClassifierWrapper(pipeline, formula, features)
|
new RandomForestClassifierWrapper(pipeline, formula, features)
|
||||||
|
|
Loading…
Reference in a new issue