[SPARK-18570][ML][R] RFormula support * and ^ operators

## What changes were proposed in this pull request?

Added support for `*` and `^` operators, along with expressions within parentheses. New operators just expand to already supported terms, such as;

 - y ~ a * b = y ~ a + b + a : b
 - y ~ (a+b+c)^3 = y ~ a + b + c + a : b + a : c + a :b : c

## How was this patch tested?

Added new unit tests to RFormulaParserSuite

mengxr yanboliang

Closes #24764 from ozancicek/rformula.

Authored-by: ozan <ozancancicekci@gmail.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
This commit is contained in:
ozan 2019-06-04 08:59:30 -05:00 committed by Sean Owen
parent 3ddc26ddd8
commit a38d605d0d
8 changed files with 174 additions and 25 deletions

View file

@ -50,7 +50,7 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
#'
#' @param data SparkDataFrame for training.
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', '.', ':', '+', and '-'.
#' operators are supported, including '~', '.', ':', '+', '-', '*', and '^'.
#' @param regParam The regularization parameter. Only supports L2 regularization currently.
#' @param maxIter Maximum iteration number.
#' @param tol Convergence tolerance of iterations.

View file

@ -55,7 +55,7 @@ setClass("PowerIterationClustering", slots = list(jobj = "jobj"))
#'
#' @param data a SparkDataFrame for training.
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', '.', ':', '+', and '-'.
#' operators are supported, including '~', '.', ':', '+', '-', '*', and '^'.
#' Note that the response variable of formula is empty in spark.bisectingKmeans.
#' @param k the desired number of leaf clusters. Must be > 1.
#' The actual number could be smaller if there are no divisible leaf clusters.

View file

@ -44,7 +44,7 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj"))
#'
#' @param data a SparkDataFrame for training.
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', '.', ':', '+', and '-'.
#' operators are supported, including '~', '.', ':', '+', '-', '*', and '^'.
#' @param family a description of the error distribution and link function to be used in the model.
#' This can be a character string naming a family function, a family function or
#' the result of a call to a family function. Refer R family at

View file

@ -135,7 +135,7 @@ print.summary.decisionTree <- function(x) {
#'
#' @param data a SparkDataFrame for training.
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', ':', '+', and '-'.
#' operators are supported, including '~', ':', '+', '-', '*', and '^'.
#' @param type type of model, one of "regression" or "classification", to fit
#' @param maxDepth Maximum depth of the tree (>= 0).
#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing

View file

@ -126,8 +126,9 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol with
/**
* :: Experimental ::
* Implements the transforms required for fitting a dataset against an R model formula. Currently
* we support a limited subset of the R operators, including '~', '.', ':', '+', and '-'. Also see
* the R formula docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
* we support a limited subset of the R operators, including '~', '.', ':', '+', '-', '*' and '^'.
* Also see the R formula docs here:
* http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
*
* The basic operators are:
* - `~` separate target and terms
@ -135,6 +136,8 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol with
* - `-` remove a term, "- 1" means removing intercept
* - `:` interaction (multiplication for numeric values, or binarized categorical values)
* - `.` all columns except target
* - `*` factor crossing, includes the terms and interactions between them
* - `^` factor crossing to a specified degree
*
* Suppose `a` and `b` are double columns, we use the following simple examples
* to illustrate the effect of `RFormula`:
@ -142,6 +145,10 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol with
* are coefficients.
* - `y ~ a + b + a:b - 1` means model `y ~ w1 * a + w2 * b + w3 * a * b` where `w1, w2, w3`
* are coefficients.
* - `y ~ a * b` means model `y ~ w0 + w1 * a + w2 * b + w3 * a * b` where `w0` is the
* intercept and `w1, w2, w3` are coefficients
* - `y ~ (a + b)^2` means model `y ~ w0 + w1 * a + w2 * b + w3 * a * b` where `w0` is the
* intercept and `w1, w2, w3` are coefficients
*
* RFormula produces a vector column of features and a double or string column of label.
* Like when formulas are used in R for linear regression, string input columns will be one-hot

View file

@ -34,11 +34,18 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
def resolve(schema: StructType): ResolvedRFormula = {
val dotTerms = expandDot(schema)
var includedTerms = Seq[Seq[String]]()
val seen = mutable.Set[Set[String]]()
terms.foreach {
case col: ColumnRef =>
includedTerms :+= Seq(col.value)
case ColumnInteraction(cols) =>
includedTerms ++= expandInteraction(schema, cols)
expandInteraction(schema, cols) foreach { t =>
// add equivalent interaction terms only once
if (!seen.contains(t.toSet)) {
includedTerms :+= t
seen += t.toSet
}
}
case Dot =>
includedTerms ++= dotTerms.map(Seq(_))
case Deletion(term: Term) =>
@ -57,8 +64,12 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
case _: Deletion =>
throw new RuntimeException("Deletion terms cannot be nested")
case _: Intercept =>
case _: Terms =>
case EmptyTerm =>
}
case _: Intercept =>
case _: Terms =>
case EmptyTerm =>
}
ResolvedRFormula(label.value, includedTerms.distinct, hasIntercept)
}
@ -144,10 +155,50 @@ private[ml] case class ResolvedRFormula(
* R formula terms. See the R formula docs here for more information:
* http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
*/
private[ml] sealed trait Term
private[ml] sealed trait Term {
/** Default representation of a single Term as a part of summed terms. */
def asTerms: Terms = Terms(Seq(this))
/** Creates a summation term by concatenation of terms. */
def add(other: Term): Term = Terms(this.asTerms.terms ++ other.asTerms.terms)
/**
* Fold by adding deletion terms to the left. Double negation
* doesn't cancel deletion in order not to add extra terms, e.g.
* a - (b - c) = a - Deletion(b) - Deletion(c) = a
*/
def subtract(other: Term): Term = {
other.asTerms.terms.foldLeft(this) {
case (left, right) =>
right match {
case t: Deletion => left.add(t)
case t: Term => left.add(Deletion(t))
}
}
}
/** Default interactions of a Term */
def interact(other: Term): Term = EmptyTerm
}
/** Placeholder term for the result of undefined interactions, e.g. '1:1' or 'a:1' */
private[ml] case object EmptyTerm extends Term
/** A term that may be part of an interaction, e.g. 'x' in 'x:y' */
private[ml] sealed trait InteractableTerm extends Term
private[ml] sealed trait InteractableTerm extends Term {
/** Convert to ColumnInteraction to wrap all interactions. */
def asInteraction: ColumnInteraction = ColumnInteraction(Seq(this))
/** Interactions of interactable terms. */
override def interact(other: Term): Term = other match {
case t: InteractableTerm => this.asInteraction.interact(t.asInteraction)
case t: ColumnInteraction => this.asInteraction.interact(t)
case t: Terms => this.asTerms.interact(t)
case t: Term => t.interact(this) // Deletion or non-interactable term
}
}
/* R formula reference to all available columns, e.g. "." in a formula */
private[ml] case object Dot extends InteractableTerm
@ -156,19 +207,68 @@ private[ml] case object Dot extends InteractableTerm
private[ml] case class ColumnRef(value: String) extends InteractableTerm
/* R formula interaction of several columns, e.g. "Sepal_Length:Species" in a formula */
private[ml] case class ColumnInteraction(terms: Seq[InteractableTerm]) extends Term
private[ml] case class ColumnInteraction(terms: Seq[InteractableTerm]) extends Term {
// Convert to ColumnInteraction and concat terms
override def interact(other: Term): Term = other match {
case t: InteractableTerm => this.interact(t.asInteraction)
case t: ColumnInteraction => ColumnInteraction(terms ++ t.terms)
case t: Terms => this.asTerms.interact(t)
case t: Term => t.interact(this)
}
}
/* R formula intercept toggle, e.g. "+ 0" in a formula */
private[ml] case class Intercept(enabled: Boolean) extends Term
/* R formula deletion of a variable, e.g. "- Species" in a formula */
private[ml] case class Deletion(term: Term) extends Term
private[ml] case class Deletion(term: Term) extends Term {
// Unnest the deletion and interact
override def interact(other: Term): Term = other match {
case Deletion(t) => Deletion(term.interact(t))
case t: Term => Deletion(term.interact(t))
}
}
/* Wrapper for multiple terms in a formula. */
private[ml] case class Terms(terms: Seq[Term]) extends Term {
override def asTerms: Terms = this
override def interact(other: Term): Term = {
val interactions = for {
left <- terms
right <- other.asTerms.terms
} yield left.interact(right)
Terms(interactions)
}
}
/**
* Limited implementation of R formula parsing. Currently supports: '~', '+', '-', '.', ':'.
* Limited implementation of R formula parsing. Currently supports: '~', '+', '-', '.', ':',
* '*', '^'.
*/
private[ml] object RFormulaParser extends RegexParsers {
private val intercept: Parser[Intercept] =
private def add(left: Term, right: Term) = left.add(right)
private def subtract(left: Term, right: Term) = left.subtract(right)
private def interact(left: Term, right: Term) = left.interact(right)
private def cross(left: Term, right: Term) = left.add(right).add(left.interact(right))
private def power(base: Term, degree: Int): Term = {
val exprs = List.fill(degree)(base)
exprs match {
case Nil => EmptyTerm
case x :: Nil => x
case x :: xs => xs.foldLeft(x)(cross _)
}
}
private val intercept: Parser[Term] =
"([01])".r ^^ { case a => Intercept(a == "1") }
private val columnRef: Parser[ColumnRef] =
@ -178,22 +278,27 @@ private[ml] object RFormulaParser extends RegexParsers {
private val label: Parser[ColumnRef] = columnRef | empty
private val dot: Parser[InteractableTerm] = "\\.".r ^^ { case _ => Dot }
private val dot: Parser[Term] = "\\.".r ^^ { case _ => Dot }
private val interaction: Parser[List[InteractableTerm]] = rep1sep(columnRef | dot, ":")
private val parens: Parser[Term] = "(" ~> expr <~ ")"
private val term: Parser[Term] = intercept |
interaction ^^ { case terms => ColumnInteraction(terms) } | dot | columnRef
private val term: Parser[Term] = parens | intercept | columnRef | dot
private val terms: Parser[List[Term]] = (term ~ rep("+" ~ term | "-" ~ term)) ^^ {
case op ~ list => list.foldLeft(List(op)) {
case (left, "+" ~ right) => left ++ Seq(right)
case (left, "-" ~ right) => left ++ Seq(Deletion(right))
}
}
private val pow: Parser[Term] = term ~ "^" ~ "^[1-9]\\d*".r ^^ {
case base ~ "^" ~ degree => power(base, degree.toInt)
} | term
private val interaction: Parser[Term] = pow * (":" ^^^ { interact _ })
private val factor = interaction * ("*" ^^^ { cross _ })
private val sum = factor * ("+" ^^^ { add _ } |
"-" ^^^ { subtract _ })
private val expr = (sum | term)
private val formula: Parser[ParsedRFormula] =
(label ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) }
(label ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t.asTerms.terms) }
def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
case Success(result, _) => result

View file

@ -90,12 +90,48 @@ class RFormulaParserSuite extends SparkFunSuite {
test("parse interactions") {
checkParse("y ~ a:b", "y", Seq("a:b"))
checkParse("y ~ a:b + b:a", "y", Seq("a:b"))
checkParse("y ~ ._a:._x", "y", Seq("._a:._x"))
checkParse("y ~ foo:bar", "y", Seq("foo:bar"))
checkParse("y ~ a : b : c", "y", Seq("a:b:c"))
checkParse("y ~ q + a:b:c + b:c + c:d + z", "y", Seq("q", "a:b:c", "b:c", "c:d", "z"))
}
test("parse factor cross") {
checkParse("y ~ a*b", "y", Seq("a", "b", "a:b"))
checkParse("y ~ a*b + b*a", "y", Seq("a", "b", "a:b"))
checkParse("y ~ ._a*._x", "y", Seq("._a", "._x", "._a:._x"))
checkParse("y ~ foo*bar", "y", Seq("foo", "bar", "foo:bar"))
checkParse("y ~ a * b * c", "y", Seq("a", "b", "a:b", "c", "a:c", "b:c", "a:b:c"))
}
test("interaction distributive") {
checkParse("y ~ (a + b):c", "y", Seq("a:c", "b:c"))
checkParse("y ~ c:(a + b)", "y", Seq("c:a", "c:b"))
}
test("factor cross distributive") {
checkParse("y ~ (a + b)*c", "y", Seq("a", "b", "c", "a:c", "b:c"))
checkParse("y ~ c*(a + b)", "y", Seq("c", "a", "b", "c:a", "c:b"))
}
test("parse power") {
val schema = (new StructType)
.add("a", "int", true)
.add("b", "long", false)
.add("c", "string", true)
.add("d", "string", true)
checkParse("a ~ (a + b)^2", "a", Seq("a", "b", "a:b"))
checkParse("a ~ .^2", "a", Seq("b", "c", "d", "b:c", "b:d", "c:d"), schema)
checkParse("a ~ .^3", "a", Seq("b", "c", "d", "b:c", "b:d", "c:d", "b:c:d"), schema)
checkParse("a ~ .^3-.", "a", Seq("b:c", "b:d", "c:d", "b:c:d"), schema)
}
test("operator precedence") {
checkParse("y ~ a*b:c", "y", Seq("a", "b:c", "a:b:c"))
checkParse("y ~ (a*b):c", "y", Seq("a:c", "b:c", "a:b:c"))
}
test("parse basic interactions with dot") {
val schema = (new StructType)
.add("a", "int", true)

View file

@ -3423,7 +3423,8 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, HasHandleInvalid,
Implements the transforms required for fitting a dataset against an
R model formula. Currently we support a limited subset of the R
operators, including '~', '.', ':', '+', and '-'. Also see the `R formula docs
operators, including '~', '.', ':', '+', '-', '*', and '^'.
Also see the `R formula docs
<http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html>`_.
>>> df = spark.createDataFrame([