[SPARK-5579][SQL][DataFrame] Support for project/filter using SQL expressions
```scala
df.selectExpr("abs(colA)", "colB")
df.filter("age > 21")
```

Author: Reynold Xin <rxin@databricks.com>

Closes #4348 from rxin/SPARK-5579 and squashes the following commits:

2baeef2 [Reynold Xin] Fix Python.
b416372 [Reynold Xin] [SPARK-5579][SQL][DataFrame] Support for project/filter using SQL expressions.
This commit is contained in:
parent
eb15631854
commit
40c4cb2fe7
|
@ -2126,10 +2126,9 @@ class DataFrame(object):
|
|||
"""
|
||||
if not cols:
|
||||
raise ValueError("should sort by at least one column")
|
||||
jcols = ListConverter().convert([_to_java_column(c) for c in cols[1:]],
|
||||
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
|
||||
self._sc._gateway._gateway_client)
|
||||
jdf = self._jdf.sort(_to_java_column(cols[0]),
|
||||
self._sc._jvm.Dsl.toColumns(jcols))
|
||||
jdf = self._jdf.sort(self._sc._jvm.Dsl.toColumns(jcols))
|
||||
return DataFrame(jdf, self.sql_ctx)
|
||||
|
||||
sortBy = sort
|
||||
|
|
|
@ -36,6 +36,16 @@ import org.apache.spark.sql.types._
|
|||
* for a SQL-like language should check out the HiveQL support in the sql/hive sub-project.
|
||||
*/
|
||||
class SqlParser extends AbstractSparkSQLParser {
|
||||
|
||||
def parseExpression(input: String): Expression = {
|
||||
// Initialize the Keywords.
|
||||
lexical.initialize(reservedWords)
|
||||
phrase(expression)(new lexical.Scanner(input)) match {
|
||||
case Success(plan, _) => plan
|
||||
case failureOrError => sys.error(failureOrError.toString)
|
||||
}
|
||||
}
|
||||
|
||||
// Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword`
|
||||
// properties of the class via reflection at runtime to construct the SqlLexical object
|
||||
protected val ABS = Keyword("ABS")
|
||||
|
|
|
@ -173,7 +173,7 @@ trait DataFrame extends RDDApi[Row] {
|
|||
* }}}
|
||||
*/
|
||||
@scala.annotation.varargs
|
||||
def sort(sortExpr: Column, sortExprs: Column*): DataFrame
|
||||
def sort(sortExprs: Column*): DataFrame
|
||||
|
||||
/**
|
||||
* Returns a new [[DataFrame]] sorted by the given expressions.
|
||||
|
@ -187,7 +187,7 @@ trait DataFrame extends RDDApi[Row] {
|
|||
* This is an alias of the `sort` function.
|
||||
*/
|
||||
@scala.annotation.varargs
|
||||
def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame
|
||||
def orderBy(sortExprs: Column*): DataFrame
|
||||
|
||||
/**
|
||||
* Selects a column based on the column name and returns it as a [[Column]].
|
||||
|
@ -236,6 +236,17 @@ trait DataFrame extends RDDApi[Row] {
|
|||
@scala.annotation.varargs
|
||||
def select(col: String, cols: String*): DataFrame
|
||||
|
||||
/**
|
||||
* Selects a set of SQL expressions. This is a variant of `select` that accepts
|
||||
* SQL expressions.
|
||||
*
|
||||
* {{{
|
||||
* df.selectExpr("colA", "colB as newName", "abs(colC)")
|
||||
* }}}
|
||||
*/
|
||||
@scala.annotation.varargs
|
||||
def selectExpr(exprs: String*): DataFrame
|
||||
|
||||
/**
|
||||
* Filters rows using the given condition.
|
||||
* {{{
|
||||
|
@ -247,6 +258,14 @@ trait DataFrame extends RDDApi[Row] {
|
|||
*/
|
||||
def filter(condition: Column): DataFrame
|
||||
|
||||
/**
|
||||
* Filters rows using the given SQL expression.
|
||||
* {{{
|
||||
* peopleDf.filter("age > 15")
|
||||
* }}}
|
||||
*/
|
||||
def filter(conditionExpr: String): DataFrame
|
||||
|
||||
/**
|
||||
* Filters rows using the given condition. This is an alias for `filter`.
|
||||
* {{{
|
||||
|
|
|
@ -27,7 +27,7 @@ import org.apache.spark.api.java.JavaRDD
|
|||
import org.apache.spark.api.python.SerDeUtil
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
import org.apache.spark.sql.catalyst.ScalaReflection
|
||||
import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
|
||||
import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation}
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
|
||||
|
@ -124,11 +124,11 @@ private[sql] class DataFrameImpl protected[sql](
|
|||
}
|
||||
|
||||
override def sort(sortCol: String, sortCols: String*): DataFrame = {
|
||||
orderBy(apply(sortCol), sortCols.map(apply) :_*)
|
||||
sort((sortCol +: sortCols).map(apply) :_*)
|
||||
}
|
||||
|
||||
override def sort(sortExpr: Column, sortExprs: Column*): DataFrame = {
|
||||
val sortOrder: Seq[SortOrder] = (sortExpr +: sortExprs).map { col =>
|
||||
override def sort(sortExprs: Column*): DataFrame = {
|
||||
val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
|
||||
col.expr match {
|
||||
case expr: SortOrder =>
|
||||
expr
|
||||
|
@ -143,8 +143,8 @@ private[sql] class DataFrameImpl protected[sql](
|
|||
sort(sortCol, sortCols :_*)
|
||||
}
|
||||
|
||||
override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = {
|
||||
sort(sortExpr, sortExprs :_*)
|
||||
override def orderBy(sortExprs: Column*): DataFrame = {
|
||||
sort(sortExprs :_*)
|
||||
}
|
||||
|
||||
override def col(colName: String): Column = colName match {
|
||||
|
@ -179,10 +179,20 @@ private[sql] class DataFrameImpl protected[sql](
|
|||
select((col +: cols).map(Column(_)) :_*)
|
||||
}
|
||||
|
||||
override def selectExpr(exprs: String*): DataFrame = {
|
||||
select(exprs.map { expr =>
|
||||
Column(new SqlParser().parseExpression(expr))
|
||||
} :_*)
|
||||
}
|
||||
|
||||
override def filter(condition: Column): DataFrame = {
|
||||
Filter(condition.expr, logicalPlan)
|
||||
}
|
||||
|
||||
override def filter(conditionExpr: String): DataFrame = {
|
||||
filter(Column(new SqlParser().parseExpression(conditionExpr)))
|
||||
}
|
||||
|
||||
override def where(condition: Column): DataFrame = {
|
||||
filter(condition)
|
||||
}
|
||||
|
|
|
@ -66,11 +66,11 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
|
|||
|
||||
override def sort(sortCol: String, sortCols: String*): DataFrame = err()
|
||||
|
||||
override def sort(sortExpr: Column, sortExprs: Column*): DataFrame = err()
|
||||
override def sort(sortExprs: Column*): DataFrame = err()
|
||||
|
||||
override def orderBy(sortCol: String, sortCols: String*): DataFrame = err()
|
||||
|
||||
override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = err()
|
||||
override def orderBy(sortExprs: Column*): DataFrame = err()
|
||||
|
||||
override def col(colName: String): Column = err()
|
||||
|
||||
|
@ -80,8 +80,12 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
|
|||
|
||||
override def select(col: String, cols: String*): DataFrame = err()
|
||||
|
||||
override def selectExpr(exprs: String*): DataFrame = err()
|
||||
|
||||
override def filter(condition: Column): DataFrame = err()
|
||||
|
||||
override def filter(conditionExpr: String): DataFrame = err()
|
||||
|
||||
override def where(condition: Column): DataFrame = err()
|
||||
|
||||
override def apply(condition: Column): DataFrame = err()
|
||||
|
|
|
@ -47,6 +47,18 @@ class DataFrameSuite extends QueryTest {
|
|||
testData.collect().toSeq)
|
||||
}
|
||||
|
||||
test("selectExpr") {
|
||||
checkAnswer(
|
||||
testData.selectExpr("abs(key)", "value"),
|
||||
testData.collect().map(row => Row(math.abs(row.getInt(0)), row.getString(1))).toSeq)
|
||||
}
|
||||
|
||||
test("filterExpr") {
|
||||
checkAnswer(
|
||||
testData.filter("key > 90"),
|
||||
testData.collect().filter(_.getInt(0) > 90).toSeq)
|
||||
}
|
||||
|
||||
test("repartition") {
|
||||
checkAnswer(
|
||||
testData.select('key).repartition(10).select('key),
|
||||
|
|
Loading…
Reference in a new issue