[SPARK-5579][SQL][DataFrame] Support for project/filter using SQL expressions

```scala
df.selectExpr("abs(colA)", "colB")
df.filter("age > 21")
```

Author: Reynold Xin <rxin@databricks.com>

Closes #4348 from rxin/SPARK-5579 and squashes the following commits:

2baeef2 [Reynold Xin] Fix Python.
b416372 [Reynold Xin] [SPARK-5579][SQL][DataFrame] Support for project/filter using SQL expressions.
Commit 40c4cb2fe7 (parent eb15631854), committed 2015-02-03 22:15:35 -08:00.
6 changed files with 67 additions and 13 deletions.


```diff
@@ -2126,10 +2126,9 @@ class DataFrame(object):
         """
         if not cols:
             raise ValueError("should sort by at least one column")
-        jcols = ListConverter().convert([_to_java_column(c) for c in cols[1:]],
+        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
                                         self._sc._gateway._gateway_client)
-        jdf = self._jdf.sort(_to_java_column(cols[0]),
-                             self._sc._jvm.Dsl.toColumns(jcols))
+        jdf = self._jdf.sort(self._sc._jvm.Dsl.toColumns(jcols))
         return DataFrame(jdf, self.sql_ctx)

     sortBy = sort
```
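
This Python fix tracks the Scala signature change below: `sort` no longer takes a mandatory first column, so PySpark now converts every column and passes the whole list through `Dsl.toColumns` in a single Py4J call. For reference, a sketch of the JVM entry point the call now binds to (shown in full in the `DataFrame` trait diff further down):

```scala
// The single varargs signature replaces sort(sortExpr: Column, sortExprs: Column*),
// which is why the Python side no longer splits off cols[0].
@scala.annotation.varargs
def sort(sortExprs: Column*): DataFrame
```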


```diff
@@ -36,6 +36,16 @@ import org.apache.spark.sql.types._
  * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project.
  */
 class SqlParser extends AbstractSparkSQLParser {
+
+  def parseExpression(input: String): Expression = {
+    // Initialize the Keywords.
+    lexical.initialize(reservedWords)
+    phrase(expression)(new lexical.Scanner(input)) match {
+      case Success(plan, _) => plan
+      case failureOrError => sys.error(failureOrError.toString)
+    }
+  }
+
   // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword`
   // properties via reflection the class in runtime for constructing the SqlLexical object
   protected val ABS = Keyword("ABS")
```
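
The new `parseExpression` entry point is what the DataFrame string variants below funnel through. A minimal sketch of how it behaves, with illustrative names (`parser`, `expr`):

```scala
import org.apache.spark.sql.catalyst.SqlParser
import org.apache.spark.sql.catalyst.expressions.Expression

// parseExpression turns a standalone SQL snippet into an unresolved Catalyst
// Expression; column references like colA are bound to a schema only later,
// by the analyzer.
val parser = new SqlParser()
val expr: Expression = parser.parseExpression("abs(colA) + 1")

// Unparseable input falls through to the failure case and aborts via sys.error:
// parser.parseExpression("abs(")  // would throw
```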


```diff
@@ -173,7 +173,7 @@ trait DataFrame extends RDDApi[Row] {
    * }}}
    */
   @scala.annotation.varargs
-  def sort(sortExpr: Column, sortExprs: Column*): DataFrame
+  def sort(sortExprs: Column*): DataFrame

   /**
    * Returns a new [[DataFrame]] sorted by the given expressions.
@@ -187,7 +187,7 @@ trait DataFrame extends RDDApi[Row] {
    * This is an alias of the `sort` function.
    */
   @scala.annotation.varargs
-  def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame
+  def orderBy(sortExprs: Column*): DataFrame

   /**
    * Selects column based on the column name and return it as a [[Column]].
@@ -236,6 +236,17 @@ trait DataFrame extends RDDApi[Row] {
   @scala.annotation.varargs
   def select(col: String, cols: String*): DataFrame

+  /**
+   * Selects a set of SQL expressions. This is a variant of `select` that accepts
+   * SQL expressions.
+   *
+   * {{{
+   *   df.selectExpr("colA", "colB as newName", "abs(colC)")
+   * }}}
+   */
+  @scala.annotation.varargs
+  def selectExpr(exprs: String*): DataFrame
+
   /**
    * Filters rows using the given condition.
    * {{{
@@ -247,6 +258,14 @@ trait DataFrame extends RDDApi[Row] {
    */
   def filter(condition: Column): DataFrame

+  /**
+   * Filters rows using the given SQL expression.
+   * {{{
+   *   peopleDf.filter("age > 15")
+   * }}}
+   */
+  def filter(conditionExpr: String): DataFrame
+
   /**
    * Filters rows using the given condition. This is an alias for `filter`.
    * {{{
```
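
Taken together, the trait now accepts expression strings wherever only `Column`s were accepted before. A usage sketch, assuming a hypothetical `people` DataFrame with `name` and `age` columns:

```scala
// The String overloads parse their argument with SqlParser.parseExpression and
// wrap it in a Column, so each pair below produces the same logical plan:
val adults  = people.filter("age > 21")          // new String overload
val adults2 = people.filter(people("age") > 21)  // existing Column overload

// selectExpr accepts full SQL expression syntax, including functions and aliases:
val projected = people.selectExpr("name", "abs(age) as absAge")
```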


```diff
@@ -27,7 +27,7 @@ import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
 import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
@@ -124,11 +124,11 @@ private[sql] class DataFrameImpl protected[sql](
   }

   override def sort(sortCol: String, sortCols: String*): DataFrame = {
-    orderBy(apply(sortCol), sortCols.map(apply) :_*)
+    sort((sortCol +: sortCols).map(apply) :_*)
   }

-  override def sort(sortExpr: Column, sortExprs: Column*): DataFrame = {
-    val sortOrder: Seq[SortOrder] = (sortExpr +: sortExprs).map { col =>
+  override def sort(sortExprs: Column*): DataFrame = {
+    val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
       col.expr match {
         case expr: SortOrder =>
           expr
@@ -143,8 +143,8 @@ private[sql] class DataFrameImpl protected[sql](
     sort(sortCol, sortCols :_*)
   }

-  override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = {
-    sort(sortExpr, sortExprs :_*)
+  override def orderBy(sortExprs: Column*): DataFrame = {
+    sort(sortExprs :_*)
   }

   override def col(colName: String): Column = colName match {
@@ -179,10 +179,20 @@ private[sql] class DataFrameImpl protected[sql](
     select((col +: cols).map(Column(_)) :_*)
   }

+  override def selectExpr(exprs: String*): DataFrame = {
+    select(exprs.map { expr =>
+      Column(new SqlParser().parseExpression(expr))
+    } :_*)
+  }
+
   override def filter(condition: Column): DataFrame = {
     Filter(condition.expr, logicalPlan)
   }

+  override def filter(conditionExpr: String): DataFrame = {
+    filter(Column(new SqlParser().parseExpression(conditionExpr)))
+  }
+
   override def where(condition: Column): DataFrame = {
     filter(condition)
   }
```
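
With the two `sort`/`orderBy` overload pairs collapsed into single varargs signatures, explicit sort orderings pass through the pattern match above unchanged, while bare columns are wrapped in a default ordering. A hedged usage sketch, assuming `Column`'s `desc` method (which produces a `SortOrder` expression) and an illustrative `df`:

```scala
// df("age").desc yields a SortOrder expression, which the match above passes
// through as-is; a bare column such as df("name") gets a default ordering.
df.sort(df("age").desc, df("name"))
df.orderBy(df("age"))  // orderBy is now a thin alias forwarding to sort
```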


```diff
@@ -66,11 +66,11 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
   override def sort(sortCol: String, sortCols: String*): DataFrame = err()

-  override def sort(sortExpr: Column, sortExprs: Column*): DataFrame = err()
+  override def sort(sortExprs: Column*): DataFrame = err()

   override def orderBy(sortCol: String, sortCols: String*): DataFrame = err()

-  override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = err()
+  override def orderBy(sortExprs: Column*): DataFrame = err()

   override def col(colName: String): Column = err()
@@ -80,8 +80,12 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
   override def select(col: String, cols: String*): DataFrame = err()

+  override def selectExpr(exprs: String*): DataFrame = err()
+
   override def filter(condition: Column): DataFrame = err()

+  override def filter(conditionExpr: String): DataFrame = err()
+
   override def where(condition: Column): DataFrame = err()

   override def apply(condition: Column): DataFrame = err()
```


```diff
@@ -47,6 +47,18 @@ class DataFrameSuite extends QueryTest {
       testData.collect().toSeq)
   }

+  test("selectExpr") {
+    checkAnswer(
+      testData.selectExpr("abs(key)", "value"),
+      testData.collect().map(row => Row(math.abs(row.getInt(0)), row.getString(1))).toSeq)
+  }
+
+  test("filterExpr") {
+    checkAnswer(
+      testData.filter("key > 90"),
+      testData.collect().filter(_.getInt(0) > 90).toSeq)
+  }
+
   test("repartition") {
     checkAnswer(
       testData.select('key).repartition(10).select('key),
```
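
The two new tests check the string variants against the same projection and predicate applied to the collected rows on the driver. For completeness, a sketch of the two features chained together (the `people` DataFrame is illustrative):

```scala
// selectExpr and filter(String) compose like their Column-based counterparts:
val result = people
  .filter("age > 21")
  .selectExpr("name", "abs(age) as absAge")
```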