diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index 2964fdeff0..3ad824e1e6 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -50,7 +50,7 @@ setClass("NaiveBayesModel", representation(jobj = "jobj")) #' #' @param data SparkDataFrame for training. #' @param formula A symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', '.', ':', '+', and '-'. +#' operators are supported, including '~', '.', ':', '+', '-', '*', and '^'. #' @param regParam The regularization parameter. Only supports L2 regularization currently. #' @param maxIter Maximum iteration number. #' @param tol Convergence tolerance of iterations. diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R index 9b32b71d34..8bc1535346 100644 --- a/R/pkg/R/mllib_clustering.R +++ b/R/pkg/R/mllib_clustering.R @@ -55,7 +55,7 @@ setClass("PowerIterationClustering", slots = list(jobj = "jobj")) #' #' @param data a SparkDataFrame for training. #' @param formula a symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', '.', ':', '+', and '-'. +#' operators are supported, including '~', '.', ':', '+', '-', '*', and '^'. #' Note that the response variable of formula is empty in spark.bisectingKmeans. #' @param k the desired number of leaf clusters. Must be > 1. #' The actual number could be smaller if there are no divisible leaf clusters. diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R index 95c1a29905..4fabe9a006 100644 --- a/R/pkg/R/mllib_regression.R +++ b/R/pkg/R/mllib_regression.R @@ -44,7 +44,7 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' #' @param data a SparkDataFrame for training. #' @param formula a symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', '.', ':', '+', and '-'. +#' operators are supported, including '~', '.', ':', '+', '-', '*', and '^'. #' @param family a description of the error distribution and link function to be used in the model. #' This can be a character string naming a family function, a family function or #' the result of a call to a family function. Refer R family at diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 9844061cfd..ff16b43621 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -135,7 +135,7 @@ print.summary.decisionTree <- function(x) { #' #' @param data a SparkDataFrame for training. #' @param formula a symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', ':', '+', and '-'. +#' operators are supported, including '~', ':', '+', '-', '*', and '^'. #' @param type type of model, one of "regression" or "classification", to fit #' @param maxDepth Maximum depth of the tree (>= 0). #' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index d7eb13772a..ec8f7031ad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -126,8 +126,9 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol with /** * :: Experimental :: * Implements the transforms required for fitting a dataset against an R model formula. Currently - * we support a limited subset of the R operators, including '~', '.', ':', '+', and '-'. Also see - * the R formula docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + * we support a limited subset of the R operators, including '~', '.', ':', '+', '-', '*' and '^'. + * Also see the R formula docs here: + * http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html * * The basic operators are: * - `~` separate target and terms @@ -135,6 +136,8 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol with * - `-` remove a term, "- 1" means removing intercept * - `:` interaction (multiplication for numeric values, or binarized categorical values) * - `.` all columns except target + * - `*` factor crossing, includes the terms and interactions between them + * - `^` factor crossing to a specified degree * * Suppose `a` and `b` are double columns, we use the following simple examples * to illustrate the effect of `RFormula`: @@ -142,6 +145,10 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol with * are coefficients. * - `y ~ a + b + a:b - 1` means model `y ~ w1 * a + w2 * b + w3 * a * b` where `w1, w2, w3` * are coefficients. + * - `y ~ a * b` means model `y ~ w0 + w1 * a + w2 * b + w3 * a * b` where `w0` is the + * intercept and `w1, w2, w3` are coefficients + * - `y ~ (a + b)^2` means model `y ~ w0 + w1 * a + w2 * b + w3 * a * b` where `w0` is the + * intercept and `w1, w2, w3` are coefficients * * RFormula produces a vector column of features and a double or string column of label. * Like when formulas are used in R for linear regression, string input columns will be one-hot diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala index 32835fb3aa..dbbfd8f329 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala @@ -34,11 +34,18 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { def resolve(schema: StructType): ResolvedRFormula = { val dotTerms = expandDot(schema) var includedTerms = Seq[Seq[String]]() + val seen = mutable.Set[Set[String]]() terms.foreach { case col: ColumnRef => includedTerms :+= Seq(col.value) case ColumnInteraction(cols) => - includedTerms ++= expandInteraction(schema, cols) + expandInteraction(schema, cols) foreach { t => + // add equivalent interaction terms only once + if (!seen.contains(t.toSet)) { + includedTerms :+= t + seen += t.toSet + } + } case Dot => includedTerms ++= dotTerms.map(Seq(_)) case Deletion(term: Term) => @@ -57,8 +64,12 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { case _: Deletion => throw new RuntimeException("Deletion terms cannot be nested") case _: Intercept => + case _: Terms => + case EmptyTerm => } case _: Intercept => + case _: Terms => + case EmptyTerm => } ResolvedRFormula(label.value, includedTerms.distinct, hasIntercept) } @@ -144,10 +155,50 @@ private[ml] case class ResolvedRFormula( * R formula terms. See the R formula docs here for more information: * http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html */ -private[ml] sealed trait Term +private[ml] sealed trait Term { + + /** Default representation of a single Term as a part of summed terms. */ + def asTerms: Terms = Terms(Seq(this)) + + /** Creates a summation term by concatenation of terms. */ + def add(other: Term): Term = Terms(this.asTerms.terms ++ other.asTerms.terms) + + /** + * Fold by adding deletion terms to the left. Double negation + * doesn't cancel deletion in order not to add extra terms, e.g. + * a - (b - c) = a - Deletion(b) - Deletion(c) = a + */ + def subtract(other: Term): Term = { + other.asTerms.terms.foldLeft(this) { + case (left, right) => + right match { + case t: Deletion => left.add(t) + case t: Term => left.add(Deletion(t)) + } + } + } + + /** Default interactions of a Term */ + def interact(other: Term): Term = EmptyTerm +} + +/** Placeholder term for the result of undefined interactions, e.g. '1:1' or 'a:1' */ +private[ml] case object EmptyTerm extends Term /** A term that may be part of an interaction, e.g. 'x' in 'x:y' */ -private[ml] sealed trait InteractableTerm extends Term +private[ml] sealed trait InteractableTerm extends Term { + + /** Convert to ColumnInteraction to wrap all interactions. */ + def asInteraction: ColumnInteraction = ColumnInteraction(Seq(this)) + + /** Interactions of interactable terms. */ + override def interact(other: Term): Term = other match { + case t: InteractableTerm => this.asInteraction.interact(t.asInteraction) + case t: ColumnInteraction => this.asInteraction.interact(t) + case t: Terms => this.asTerms.interact(t) + case t: Term => t.interact(this) // Deletion or non-interactable term + } +} /* R formula reference to all available columns, e.g. "." in a formula */ private[ml] case object Dot extends InteractableTerm @@ -156,19 +207,68 @@ private[ml] case object Dot extends InteractableTerm private[ml] case class ColumnRef(value: String) extends InteractableTerm /* R formula interaction of several columns, e.g. "Sepal_Length:Species" in a formula */ -private[ml] case class ColumnInteraction(terms: Seq[InteractableTerm]) extends Term +private[ml] case class ColumnInteraction(terms: Seq[InteractableTerm]) extends Term { + + // Convert to ColumnInteraction and concat terms + override def interact(other: Term): Term = other match { + case t: InteractableTerm => this.interact(t.asInteraction) + case t: ColumnInteraction => ColumnInteraction(terms ++ t.terms) + case t: Terms => this.asTerms.interact(t) + case t: Term => t.interact(this) + } +} /* R formula intercept toggle, e.g. "+ 0" in a formula */ private[ml] case class Intercept(enabled: Boolean) extends Term /* R formula deletion of a variable, e.g. "- Species" in a formula */ -private[ml] case class Deletion(term: Term) extends Term +private[ml] case class Deletion(term: Term) extends Term { + + // Unnest the deletion and interact + override def interact(other: Term): Term = other match { + case Deletion(t) => Deletion(term.interact(t)) + case t: Term => Deletion(term.interact(t)) + } +} + +/* Wrapper for multiple terms in a formula. */ +private[ml] case class Terms(terms: Seq[Term]) extends Term { + + override def asTerms: Terms = this + + override def interact(other: Term): Term = { + val interactions = for { + left <- terms + right <- other.asTerms.terms + } yield left.interact(right) + Terms(interactions) + } +} /** - * Limited implementation of R formula parsing. Currently supports: '~', '+', '-', '.', ':'. + * Limited implementation of R formula parsing. Currently supports: '~', '+', '-', '.', ':', + * '*', '^'. */ private[ml] object RFormulaParser extends RegexParsers { - private val intercept: Parser[Intercept] = + + private def add(left: Term, right: Term) = left.add(right) + + private def subtract(left: Term, right: Term) = left.subtract(right) + + private def interact(left: Term, right: Term) = left.interact(right) + + private def cross(left: Term, right: Term) = left.add(right).add(left.interact(right)) + + private def power(base: Term, degree: Int): Term = { + val exprs = List.fill(degree)(base) + exprs match { + case Nil => EmptyTerm + case x :: Nil => x + case x :: xs => xs.foldLeft(x)(cross _) + } + } + + private val intercept: Parser[Term] = "([01])".r ^^ { case a => Intercept(a == "1") } private val columnRef: Parser[ColumnRef] = @@ -178,22 +278,27 @@ private[ml] object RFormulaParser extends RegexParsers { private val label: Parser[ColumnRef] = columnRef | empty - private val dot: Parser[InteractableTerm] = "\\.".r ^^ { case _ => Dot } + private val dot: Parser[Term] = "\\.".r ^^ { case _ => Dot } - private val interaction: Parser[List[InteractableTerm]] = rep1sep(columnRef | dot, ":") + private val parens: Parser[Term] = "(" ~> expr <~ ")" - private val term: Parser[Term] = intercept | - interaction ^^ { case terms => ColumnInteraction(terms) } | dot | columnRef + private val term: Parser[Term] = parens | intercept | columnRef | dot - private val terms: Parser[List[Term]] = (term ~ rep("+" ~ term | "-" ~ term)) ^^ { - case op ~ list => list.foldLeft(List(op)) { - case (left, "+" ~ right) => left ++ Seq(right) - case (left, "-" ~ right) => left ++ Seq(Deletion(right)) - } - } + private val pow: Parser[Term] = term ~ "^" ~ "^[1-9]\\d*".r ^^ { + case base ~ "^" ~ degree => power(base, degree.toInt) + } | term + + private val interaction: Parser[Term] = pow * (":" ^^^ { interact _ }) + + private val factor = interaction * ("*" ^^^ { cross _ }) + + private val sum = factor * ("+" ^^^ { add _ } | + "-" ^^^ { subtract _ }) + + private val expr = (sum | term) private val formula: Parser[ParsedRFormula] = - (label ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) } + (label ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t.asTerms.terms) } def parse(value: String): ParsedRFormula = parseAll(formula, value) match { case Success(result, _) => result diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala index 53798c659d..add1cc17ea 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala @@ -90,12 +90,48 @@ class RFormulaParserSuite extends SparkFunSuite { test("parse interactions") { checkParse("y ~ a:b", "y", Seq("a:b")) + checkParse("y ~ a:b + b:a", "y", Seq("a:b")) checkParse("y ~ ._a:._x", "y", Seq("._a:._x")) checkParse("y ~ foo:bar", "y", Seq("foo:bar")) checkParse("y ~ a : b : c", "y", Seq("a:b:c")) checkParse("y ~ q + a:b:c + b:c + c:d + z", "y", Seq("q", "a:b:c", "b:c", "c:d", "z")) } + test("parse factor cross") { + checkParse("y ~ a*b", "y", Seq("a", "b", "a:b")) + checkParse("y ~ a*b + b*a", "y", Seq("a", "b", "a:b")) + checkParse("y ~ ._a*._x", "y", Seq("._a", "._x", "._a:._x")) + checkParse("y ~ foo*bar", "y", Seq("foo", "bar", "foo:bar")) + checkParse("y ~ a * b * c", "y", Seq("a", "b", "a:b", "c", "a:c", "b:c", "a:b:c")) + } + + test("interaction distributive") { + checkParse("y ~ (a + b):c", "y", Seq("a:c", "b:c")) + checkParse("y ~ c:(a + b)", "y", Seq("c:a", "c:b")) + } + + test("factor cross distributive") { + checkParse("y ~ (a + b)*c", "y", Seq("a", "b", "c", "a:c", "b:c")) + checkParse("y ~ c*(a + b)", "y", Seq("c", "a", "b", "c:a", "c:b")) + } + + test("parse power") { + val schema = (new StructType) + .add("a", "int", true) + .add("b", "long", false) + .add("c", "string", true) + .add("d", "string", true) + checkParse("a ~ (a + b)^2", "a", Seq("a", "b", "a:b")) + checkParse("a ~ .^2", "a", Seq("b", "c", "d", "b:c", "b:d", "c:d"), schema) + checkParse("a ~ .^3", "a", Seq("b", "c", "d", "b:c", "b:d", "c:d", "b:c:d"), schema) + checkParse("a ~ .^3-.", "a", Seq("b:c", "b:d", "c:d", "b:c:d"), schema) + } + + test("operator precedence") { + checkParse("y ~ a*b:c", "y", Seq("a", "b:c", "a:b:c")) + checkParse("y ~ (a*b):c", "y", Seq("a:c", "b:c", "a:b:c")) + } + test("parse basic interactions with dot") { val schema = (new StructType) .add("a", "int", true) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 4f5809c37f..9827a2af5a 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -3423,7 +3423,8 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, HasHandleInvalid, Implements the transforms required for fitting a dataset against an R model formula. Currently we support a limited subset of the R - operators, including '~', '.', ':', '+', and '-'. Also see the `R formula docs + operators, including '~', '.', ':', '+', '-', '*', and '^'. + Also see the `R formula docs `_. >>> df = spark.createDataFrame([