[SPARK-9463] [ML] Expose model coefficients with names in SparkR RFormula

Preview:

```
> summary(m)
            features coefficients
1        (Intercept)    1.6765001
2       Sepal_Length    0.3498801
3 Species.versicolor   -0.9833885
4  Species.virginica   -1.0075104

```

Design doc from umbrella task: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit

cc mengxr

Author: Eric Liang <ekl@databricks.com>

Closes #7771 from ericl/summary and squashes the following commits:

ccd54c3 [Eric Liang] second pass
a5ca93b [Eric Liang] comments
2772111 [Eric Liang] clean up
70483ef [Eric Liang] fix test
7c247d4 [Eric Liang] Merge branch 'master' into summary
3c55024 [Eric Liang] working
8c539aa [Eric Liang] first pass
Authored by Eric Liang on 2015-07-30 16:15:43 -07:00, committed by Xiangrui Meng
parent be7be6d4c7
commit e7905a9395
9 changed files with 108 additions and 17 deletions

R/pkg/NAMESPACE

@@ -12,7 +12,8 @@ export("print.jobj")
 
 # MLlib integration
 exportMethods("glm",
-              "predict")
+              "predict",
+              "summary")
 
 # Job group lifecycle management methods
 export("setJobGroup",

R/pkg/R/mllib.R

@@ -71,3 +71,29 @@ setMethod("predict", signature(object = "PipelineModel"),
           function(object, newData) {
             return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
           })
+
+#' Get the summary of a model
+#'
+#' Returns the summary of a model produced by glm(), similarly to R's summary().
+#'
+#' @param object A fitted MLlib model
+#' @return a list with a 'coefficients' component, which is the matrix of coefficients. See
+#'         summary.glm for more information.
+#' @rdname glm
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- glm(y ~ x, trainingData)
+#' summary(model)
+#' }
+setMethod("summary", signature(object = "PipelineModel"),
+          function(object) {
+            features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                    "getModelFeatures", object@model)
+            weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                   "getModelWeights", object@model)
+            coefficients <- as.matrix(unlist(weights))
+            colnames(coefficients) <- c("Estimate")
+            rownames(coefficients) <- unlist(features)
+            return(list(coefficients = coefficients))
+          })

R/pkg/inst/tests/test_mllib.R

@@ -48,3 +48,14 @@ test_that("dot minus and intercept vs native glm", {
   rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris)
   expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
 })
+
+test_that("summary coefficients match with native glm", {
+  training <- createDataFrame(sqlContext, iris)
+  stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training))
+  coefs <- as.vector(stats$coefficients)
+  rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)))
+  expect_true(all(abs(rCoefs - coefs) < 1e-6))
+  expect_true(all(
+    as.character(stats$features) ==
+    c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica")))
+})

mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala

@@ -66,7 +66,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   override def transformSchema(schema: StructType): StructType = {
-    val is = "_is_"
     val inputColName = $(inputCol)
     val outputColName = $(outputCol)
 
@@ -79,17 +78,17 @@ class OneHotEncoder(override val uid: String) extends Transformer
     val outputAttrNames: Option[Array[String]] = inputAttr match {
       case nominal: NominalAttribute =>
         if (nominal.values.isDefined) {
-          nominal.values.map(_.map(v => inputColName + is + v))
+          nominal.values
         } else if (nominal.numValues.isDefined) {
-          nominal.numValues.map(n => Array.tabulate(n)(i => inputColName + is + i))
+          nominal.numValues.map(n => Array.tabulate(n)(_.toString))
         } else {
           None
         }
       case binary: BinaryAttribute =>
         if (binary.values.isDefined) {
-          binary.values.map(_.map(v => inputColName + is + v))
+          binary.values
         } else {
-          Some(Array.tabulate(2)(i => inputColName + is + i))
+          Some(Array.tabulate(2)(_.toString))
         }
       case _: NumericAttribute =>
         throw new RuntimeException(
@@ -123,7 +122,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
 
   override def transform(dataset: DataFrame): DataFrame = {
     // schema transformation
-    val is = "_is_"
     val inputColName = $(inputCol)
     val outputColName = $(outputCol)
     val shouldDropLast = $(dropLast)
@@ -142,7 +140,7 @@ class OneHotEncoder(override val uid: String) extends Transformer
         math.max(m0, m1)
       }
     ).toInt + 1
-    val outputAttrNames = Array.tabulate(numAttrs)(i => inputColName + is + i)
+    val outputAttrNames = Array.tabulate(numAttrs)(_.toString)
     val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames
     val outputAttrs: Array[Attribute] =
       filtered.map(name => BinaryAttribute.defaultAttr.withName(name))
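
Taken together, these hunks change how one-hot output attributes are named: the old `<inputCol>_is_<value>` scheme gives way to the raw category value, and RFormula (next file) prefixes the column name back on when assembling features. A minimal sketch of the naming change, using hypothetical category values:

```
// Sketch of the naming change above, for a hypothetical nominal column
// "size" whose categories are small/medium/large.
val inputColName = "size"
val is = "_is_"
val values = Array("small", "medium", "large")

val oldNames = values.map(v => inputColName + is + v) // size_is_small, ...
val newNames = values                                 // small, medium, large
```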

mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala

@@ -17,6 +17,7 @@
 
 package org.apache.spark.ml.feature
 
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.util.parsing.combinator.RegexParsers
 
@@ -91,11 +92,20 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
     // TODO(ekl) add support for feature interactions
     val encoderStages = ArrayBuffer[PipelineStage]()
     val tempColumns = ArrayBuffer[String]()
+    val takenNames = mutable.Set(dataset.columns: _*)
     val encodedTerms = resolvedFormula.terms.map { term =>
       dataset.schema(term) match {
         case column if column.dataType == StringType =>
           val indexCol = term + "_idx_" + uid
-          val encodedCol = term + "_onehot_" + uid
+          val encodedCol = {
+            var tmp = term
+            while (takenNames.contains(tmp)) {
+              tmp += "_"
+            }
+            tmp
+          }
+          takenNames.add(indexCol)
+          takenNames.add(encodedCol)
           encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol)
           encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol)
           tempColumns += indexCol
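
The collision-avoidance block above keeps the encoded column's name as close to the term as possible while guaranteeing uniqueness against existing dataset columns: because a term such as `Species` already names a column, the encoded column becomes `Species_`, and the assembled attribute names come out as `Species__versicolor` (hence the double underscore in the tests). The same idea as a standalone, hypothetical helper:

```
import scala.collection.mutable

// Hypothetical helper mirroring the dedup loop in RFormula.fit: append "_"
// until the candidate name is free, then reserve it.
def uniqueName(base: String, taken: mutable.Set[String]): String = {
  var tmp = base
  while (taken.contains(tmp)) {
    tmp += "_"
  }
  taken.add(tmp)
  tmp
}

val taken = mutable.Set("Species", "Sepal_Length")
assert(uniqueName("Species", taken) == "Species_")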

mllib/src/main/scala/org/apache/spark/ml/api/r/SparkRWrappers.scala

@@ -17,9 +17,10 @@
 
 package org.apache.spark.ml.api.r
 
+import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.feature.RFormula
-import org.apache.spark.ml.classification.LogisticRegression
-import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
+import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.sql.DataFrame
 
@@ -44,4 +45,26 @@ private[r] object SparkRWrappers {
     val pipeline = new Pipeline().setStages(Array(formula, estimator))
     pipeline.fit(df)
   }
+
+  def getModelWeights(model: PipelineModel): Array[Double] = {
+    model.stages.last match {
+      case m: LinearRegressionModel =>
+        Array(m.intercept) ++ m.weights.toArray
+      case _: LogisticRegressionModel =>
+        throw new UnsupportedOperationException(
+          "No weights available for LogisticRegressionModel")  // SPARK-9492
+    }
+  }
+
+  def getModelFeatures(model: PipelineModel): Array[String] = {
+    model.stages.last match {
+      case m: LinearRegressionModel =>
+        val attrs = AttributeGroup.fromStructField(
+          m.summary.predictions.schema(m.summary.featuresCol))
+        Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
+      case _: LogisticRegressionModel =>
+        throw new UnsupportedOperationException(
+          "No feature names available for LogisticRegressionModel")  // SPARK-9492
+    }
+  }
 }
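
`getModelFeatures` and `getModelWeights` return index-aligned arrays, intercept first, which the R `summary` method zips into the named coefficient matrix shown in the preview. A sketch of that pairing on the Scala side, assuming `model` is a fitted `PipelineModel` whose final stage is a `LinearRegressionModel`:

```
// Sketch (assumes a fitted PipelineModel `model`, e.g. from fitRModelFormula):
// the two arrays are index-aligned, with the intercept in position 0.
val names = SparkRWrappers.getModelFeatures(model)   // "(Intercept)", ...
val weights = SparkRWrappers.getModelWeights(model)  // intercept first
names.zip(weights).foreach { case (feature, estimate) =>
  println(f"$feature%-20s $estimate%12.7f")
}
```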

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

@@ -36,6 +36,7 @@ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.StructField
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.StatCounter
 
@@ -146,9 +147,10 @@ class LinearRegression(override val uid: String)
       val model = new LinearRegressionModel(uid, weights, intercept)
       val trainingSummary = new LinearRegressionTrainingSummary(
-        model.transform(dataset).select($(predictionCol), $(labelCol)),
+        model.transform(dataset),
         $(predictionCol),
         $(labelCol),
+        $(featuresCol),
         Array(0D))
       return copyValues(model.setSummary(trainingSummary))
     }
 
@@ -221,9 +223,10 @@ class LinearRegression(override val uid: String)
     val model = copyValues(new LinearRegressionModel(uid, weights, intercept))
     val trainingSummary = new LinearRegressionTrainingSummary(
-      model.transform(dataset).select($(predictionCol), $(labelCol)),
+      model.transform(dataset),
       $(predictionCol),
       $(labelCol),
+      $(featuresCol),
       objectiveHistory)
 
     model.setSummary(trainingSummary)
   }
@@ -300,6 +303,7 @@ class LinearRegressionTrainingSummary private[regression] (
     predictions: DataFrame,
     predictionCol: String,
     labelCol: String,
+    val featuresCol: String,
     val objectiveHistory: Array[Double])
   extends LinearRegressionSummary(predictions, predictionCol, labelCol) {
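
Passing the full `model.transform(dataset)` output instead of a two-column projection, together with the new `featuresCol` field, is what lets callers recover per-feature ML attributes from a training summary. A sketch of that lookup, assuming `summary` is a `LinearRegressionTrainingSummary` whose features column carries attribute metadata:

```
import org.apache.spark.ml.attribute.AttributeGroup

// Sketch: the same lookup SparkRWrappers.getModelFeatures performs.
// `attributes` is only defined when the features column has ML metadata.
val field = summary.predictions.schema(summary.featuresCol)
val attrs = AttributeGroup.fromStructField(field)
val featureNames = attrs.attributes.map(_.map(_.name.getOrElse("")))
```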

mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala

@@ -86,8 +86,8 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
     val output = encoder.transform(df)
     val group = AttributeGroup.fromStructField(output.schema("encoded"))
     assert(group.size === 2)
-    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0))
-    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1))
+    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
+    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
   }
 
   test("input column without ML attribute") {
@@ -98,7 +98,7 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
     val output = encoder.transform(df)
     val group = AttributeGroup.fromStructField(output.schema("encoded"))
     assert(group.size === 2)
-    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0))
-    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1))
+    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0))
+    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1))
   }
 }

mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala

@@ -18,6 +18,7 @@
 package org.apache.spark.ml.feature
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 
@@ -105,4 +106,21 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(result.schema.toString == resultSchema.toString)
     assert(result.collect() === expected.collect())
   }
+
+  test("attribute generation") {
+    val formula = new RFormula().setFormula("id ~ a + b")
+    val original = sqlContext.createDataFrame(
+      Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
+    ).toDF("id", "a", "b")
+    val model = formula.fit(original)
+    val result = model.transform(original)
+    val attrs = AttributeGroup.fromStructField(result.schema("features"))
+    val expectedAttrs = new AttributeGroup(
+      "features",
+      Array(
+        new BinaryAttribute(Some("a__bar"), Some(1)),
+        new BinaryAttribute(Some("a__foo"), Some(2)),
+        new NumericAttribute(Some("b"), Some(3))))
+    assert(attrs === expectedAttrs)
+  }
 }