[SPARK-9463] [ML] Expose model coefficients with names in SparkR RFormula
Preview: ``` > summary(m) features coefficients 1 (Intercept) 1.6765001 2 Sepal_Length 0.3498801 3 Species.versicolor -0.9833885 4 Species.virginica -1.0075104 ``` Design doc from umbrella task: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit cc mengxr Author: Eric Liang <ekl@databricks.com> Closes #7771 from ericl/summary and squashes the following commits: ccd54c3 [Eric Liang] second pass a5ca93b [Eric Liang] comments 2772111 [Eric Liang] clean up 70483ef [Eric Liang] fix test 7c247d4 [Eric Liang] Merge branch 'master' into summary 3c55024 [Eric Liang] working 8c539aa [Eric Liang] first pass
This commit is contained in:
parent
be7be6d4c7
commit
e7905a9395
|
@ -12,7 +12,8 @@ export("print.jobj")
|
|||
|
||||
# MLlib integration
|
||||
exportMethods("glm",
|
||||
"predict")
|
||||
"predict",
|
||||
"summary")
|
||||
|
||||
# Job group lifecycle management methods
|
||||
export("setJobGroup",
|
||||
|
|
|
@ -71,3 +71,29 @@ setMethod("predict", signature(object = "PipelineModel"),
|
|||
function(object, newData) {
|
||||
return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
|
||||
})
|
||||
|
||||
#' Get the summary of a model
#'
#' Returns the summary of a model produced by glm(), similarly to R's summary().
#'
#' @param object A fitted MLlib model
#' @return a list with a 'coefficients' component, which is the matrix of coefficients. See
#'         summary.glm for more information.
#' @rdname glm
#' @export
#' @examples
#'\dontrun{
#' model <- glm(y ~ x, trainingData)
#' summary(model)
#'}
setMethod("summary", signature(object = "PipelineModel"),
          function(object) {
            # Names and values are fetched through two separate JVM calls; the rows are
            # labelled purely by position, so both are expected to come back in the same
            # order (intercept first).
            features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
                                    "getModelFeatures", object@model)
            weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
                                   "getModelWeights", object@model)
            # One-column matrix: one row per term, a single "Estimate" column.
            coefficients <- as.matrix(unlist(weights))
            colnames(coefficients) <- c("Estimate")
            rownames(coefficients) <- unlist(features)
            return(list(coefficients = coefficients))
          })
|
||||
|
|
|
@ -48,3 +48,14 @@ test_that("dot minus and intercept vs native glm", {
|
|||
rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris)
|
||||
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
|
||||
})
|
||||
|
||||
test_that("summary coefficients match with native glm", {
  training <- createDataFrame(sqlContext, iris)
  stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training))
  coefs <- as.vector(stats$coefficients)
  rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)))
  expect_true(all(abs(rCoefs - coefs) < 1e-6))
  # summary() returns a list with a single `coefficients` matrix whose row names are the
  # feature names. There is no `features` component: `stats$features` would be NULL and
  # an all(... == ...) check against it is vacuously TRUE, so compare the row names.
  expect_equal(
    rownames(stats$coefficients),
    c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica"))
})
|
||||
|
|
|
@ -66,7 +66,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
|
|||
def setOutputCol(value: String): this.type = set(outputCol, value)
|
||||
|
||||
override def transformSchema(schema: StructType): StructType = {
|
||||
val is = "_is_"
|
||||
val inputColName = $(inputCol)
|
||||
val outputColName = $(outputCol)
|
||||
|
||||
|
@ -79,17 +78,17 @@ class OneHotEncoder(override val uid: String) extends Transformer
|
|||
val outputAttrNames: Option[Array[String]] = inputAttr match {
|
||||
case nominal: NominalAttribute =>
|
||||
if (nominal.values.isDefined) {
|
||||
nominal.values.map(_.map(v => inputColName + is + v))
|
||||
nominal.values
|
||||
} else if (nominal.numValues.isDefined) {
|
||||
nominal.numValues.map(n => Array.tabulate(n)(i => inputColName + is + i))
|
||||
nominal.numValues.map(n => Array.tabulate(n)(_.toString))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
case binary: BinaryAttribute =>
|
||||
if (binary.values.isDefined) {
|
||||
binary.values.map(_.map(v => inputColName + is + v))
|
||||
binary.values
|
||||
} else {
|
||||
Some(Array.tabulate(2)(i => inputColName + is + i))
|
||||
Some(Array.tabulate(2)(_.toString))
|
||||
}
|
||||
case _: NumericAttribute =>
|
||||
throw new RuntimeException(
|
||||
|
@ -123,7 +122,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
|
|||
|
||||
override def transform(dataset: DataFrame): DataFrame = {
|
||||
// schema transformation
|
||||
val is = "_is_"
|
||||
val inputColName = $(inputCol)
|
||||
val outputColName = $(outputCol)
|
||||
val shouldDropLast = $(dropLast)
|
||||
|
@ -142,7 +140,7 @@ class OneHotEncoder(override val uid: String) extends Transformer
|
|||
math.max(m0, m1)
|
||||
}
|
||||
).toInt + 1
|
||||
val outputAttrNames = Array.tabulate(numAttrs)(i => inputColName + is + i)
|
||||
val outputAttrNames = Array.tabulate(numAttrs)(_.toString)
|
||||
val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames
|
||||
val outputAttrs: Array[Attribute] =
|
||||
filtered.map(name => BinaryAttribute.defaultAttr.withName(name))
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
package org.apache.spark.ml.feature
|
||||
|
||||
import scala.collection.mutable
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.util.parsing.combinator.RegexParsers
|
||||
|
||||
|
@ -91,11 +92,20 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
|
|||
// TODO(ekl) add support for feature interactions
|
||||
val encoderStages = ArrayBuffer[PipelineStage]()
|
||||
val tempColumns = ArrayBuffer[String]()
|
||||
val takenNames = mutable.Set(dataset.columns: _*)
|
||||
val encodedTerms = resolvedFormula.terms.map { term =>
|
||||
dataset.schema(term) match {
|
||||
case column if column.dataType == StringType =>
|
||||
val indexCol = term + "_idx_" + uid
|
||||
val encodedCol = term + "_onehot_" + uid
|
||||
val encodedCol = {
|
||||
var tmp = term
|
||||
while (takenNames.contains(tmp)) {
|
||||
tmp += "_"
|
||||
}
|
||||
tmp
|
||||
}
|
||||
takenNames.add(indexCol)
|
||||
takenNames.add(encodedCol)
|
||||
encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol)
|
||||
encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol)
|
||||
tempColumns += indexCol
|
||||
|
|
|
@ -17,9 +17,10 @@
|
|||
|
||||
package org.apache.spark.ml.api.r
|
||||
|
||||
import org.apache.spark.ml.attribute._
|
||||
import org.apache.spark.ml.feature.RFormula
|
||||
import org.apache.spark.ml.classification.LogisticRegression
|
||||
import org.apache.spark.ml.regression.LinearRegression
|
||||
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
|
||||
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
|
||||
import org.apache.spark.ml.{Pipeline, PipelineModel}
|
||||
import org.apache.spark.sql.DataFrame
|
||||
|
||||
|
@ -44,4 +45,26 @@ private[r] object SparkRWrappers {
|
|||
val pipeline = new Pipeline().setStages(Array(formula, estimator))
|
||||
pipeline.fit(df)
|
||||
}
|
||||
|
||||
/**
 * Returns the fitted coefficients of the last stage of a pipeline, intercept first.
 *
 * @param model fitted pipeline; only its last stage is inspected
 * @return for a LinearRegressionModel, the intercept followed by the model weights
 * @throws UnsupportedOperationException if the last stage is a LogisticRegressionModel
 *         (SPARK-9492) or any other kind of stage
 */
def getModelWeights(model: PipelineModel): Array[Double] = {
  model.stages.last match {
    case m: LinearRegressionModel =>
      // Intercept goes first, matching the "(Intercept)" entry that getModelFeatures prepends.
      Array(m.intercept) ++ m.weights.toArray
    case _: LogisticRegressionModel =>
      throw new UnsupportedOperationException(
        "No weights available for LogisticRegressionModel")  // SPARK-9492
    case other =>
      // Report the offending stage type instead of failing with an opaque scala.MatchError.
      throw new UnsupportedOperationException(
        s"No weights available for ${other.getClass.getName}")
  }
}
|
||||
|
||||
/**
 * Returns the names of the terms of the last stage of a pipeline, intercept first.
 *
 * The names are read from the ML attribute metadata attached to the features column of
 * the training summary's predictions.
 *
 * @param model fitted pipeline; only its last stage is inspected
 * @return for a LinearRegressionModel, "(Intercept)" followed by the feature attribute names
 * @throws UnsupportedOperationException if the last stage is a LogisticRegressionModel
 *         (SPARK-9492) or any other kind of stage
 * @throws NoSuchElementException if the features column has no attributes, or an attribute
 *         has no name (both are unwrapped with Option.get)
 */
def getModelFeatures(model: PipelineModel): Array[String] = {
  model.stages.last match {
    case m: LinearRegressionModel =>
      val attrs = AttributeGroup.fromStructField(
        m.summary.predictions.schema(m.summary.featuresCol))
      // "(Intercept)" goes first, matching the intercept that getModelWeights prepends.
      Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
    case _: LogisticRegressionModel =>
      throw new UnsupportedOperationException(
        "No features names available for LogisticRegressionModel")  // SPARK-9492
    case other =>
      // Report the offending stage type instead of failing with an opaque scala.MatchError.
      throw new UnsupportedOperationException(
        s"No features names available for ${other.getClass.getName}")
  }
}
|
||||
}
|
||||
|
|
|
@ -36,6 +36,7 @@ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
|
|||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.{DataFrame, Row}
|
||||
import org.apache.spark.sql.functions.{col, udf}
|
||||
import org.apache.spark.sql.types.StructField
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
import org.apache.spark.util.StatCounter
|
||||
|
||||
|
@ -146,9 +147,10 @@ class LinearRegression(override val uid: String)
|
|||
|
||||
val model = new LinearRegressionModel(uid, weights, intercept)
|
||||
val trainingSummary = new LinearRegressionTrainingSummary(
|
||||
model.transform(dataset).select($(predictionCol), $(labelCol)),
|
||||
model.transform(dataset),
|
||||
$(predictionCol),
|
||||
$(labelCol),
|
||||
$(featuresCol),
|
||||
Array(0D))
|
||||
return copyValues(model.setSummary(trainingSummary))
|
||||
}
|
||||
|
@ -221,9 +223,10 @@ class LinearRegression(override val uid: String)
|
|||
|
||||
val model = copyValues(new LinearRegressionModel(uid, weights, intercept))
|
||||
val trainingSummary = new LinearRegressionTrainingSummary(
|
||||
model.transform(dataset).select($(predictionCol), $(labelCol)),
|
||||
model.transform(dataset),
|
||||
$(predictionCol),
|
||||
$(labelCol),
|
||||
$(featuresCol),
|
||||
objectiveHistory)
|
||||
model.setSummary(trainingSummary)
|
||||
}
|
||||
|
@ -300,6 +303,7 @@ class LinearRegressionTrainingSummary private[regression] (
|
|||
predictions: DataFrame,
|
||||
predictionCol: String,
|
||||
labelCol: String,
|
||||
val featuresCol: String,
|
||||
val objectiveHistory: Array[Double])
|
||||
extends LinearRegressionSummary(predictions, predictionCol, labelCol) {
|
||||
|
||||
|
|
|
@ -86,8 +86,8 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
val output = encoder.transform(df)
|
||||
val group = AttributeGroup.fromStructField(output.schema("encoded"))
|
||||
assert(group.size === 2)
|
||||
assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0))
|
||||
assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1))
|
||||
assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
|
||||
assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
|
||||
}
|
||||
|
||||
test("input column without ML attribute") {
|
||||
|
@ -98,7 +98,7 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
val output = encoder.transform(df)
|
||||
val group = AttributeGroup.fromStructField(output.schema("encoded"))
|
||||
assert(group.size === 2)
|
||||
assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0))
|
||||
assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1))
|
||||
assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0))
|
||||
assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.ml.feature
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.attribute._
|
||||
import org.apache.spark.ml.param.ParamsSuite
|
||||
import org.apache.spark.mllib.linalg.Vectors
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
|
@ -105,4 +106,21 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
assert(result.schema.toString == resultSchema.toString)
|
||||
assert(result.collect() === expected.collect())
|
||||
}
|
||||
|
||||
test("attribute generation") {
  // Input with one string column ("a") and one numeric column ("b") as formula terms.
  val input = sqlContext.createDataFrame(
    Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
  ).toDF("id", "a", "b")

  val fitted = new RFormula().setFormula("id ~ a + b").fit(input)
  val featuresField = fitted.transform(input).schema("features")
  val actualAttrs = AttributeGroup.fromStructField(featuresField)

  // Expected metadata on the assembled "features" column: two binary attributes derived
  // from "a", then the numeric attribute for "b".
  val wantedAttrs = new AttributeGroup(
    "features",
    Array(
      new BinaryAttribute(Some("a__bar"), Some(1)),
      new BinaryAttribute(Some("a__foo"), Some(2)),
      new NumericAttribute(Some("b"), Some(3))))

  assert(actualAttrs === wantedAttrs)
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue