[SPARK-2588][SQL] Add some more DSLs.

Author: Takuya UESHIN <ueshin@happy-camper.st>

Closes #1491 from ueshin/issues/SPARK-2588 and squashes the following commits:

43d0a46 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-2588
1023ea0 [Takuya UESHIN] Modify tests to use DSLs.
2310bf1 [Takuya UESHIN] Add some more DSLs.
This commit is contained in:
Takuya UESHIN 2014-07-23 14:47:23 -07:00 committed by Michael Armbrust
parent f776bc9887
commit 1b790cf775
3 changed files with 70 additions and 33 deletions

View file

@ -79,8 +79,24 @@ package object dsl {
/** Builds an `EqualTo` expression from `expr` and `other`, in that argument order. */
def === (other: Expression) = EqualTo(expr, other)
/** Builds the negated equality: `EqualTo(expr, other)` wrapped in `Not`. */
def !== (other: Expression) = Not(EqualTo(expr, other))
/** Builds an `In` expression from `expr` and the varargs `list` of expressions. */
def in(list: Expression*) = In(expr, list)
/**
 * Builds a `Like` expression from `expr` and `other`.
 * NOTE(review): `other` is presumably the pattern operand -- confirm against `Like`.
 */
def like(other: Expression) = Like(expr, other)
/**
 * Builds an `RLike` expression from `expr` and `other`.
 * NOTE(review): `other` is presumably a regular-expression operand -- confirm against `RLike`.
 */
def rlike(other: Expression) = RLike(expr, other)
/** Builds a `Contains` expression from `expr` and `other`. */
def contains(other: Expression) = Contains(expr, other)
/** Builds a `StartsWith` expression from `expr` and `other`. */
def startsWith(other: Expression) = StartsWith(expr, other)
/** Builds an `EndsWith` expression from `expr` and `other`. */
def endsWith(other: Expression) = EndsWith(expr, other)
/**
 * Builds `Substring(expr, pos, len)`.
 * When `len` is omitted it defaults to `Literal(Int.MaxValue)`.
 */
def substr(pos: Expression, len: Expression = Literal(Int.MaxValue)) =
  Substring(expr, pos, len)
/** Longer-named alias of `substr`: forwards both arguments and has the same default `len`. */
def substring(pos: Expression, len: Expression = Literal(Int.MaxValue)) =
  substr(pos, len)
/** Builds an `IsNull` expression over `expr`. */
def isNull = IsNull(expr)
/** Builds an `IsNotNull` expression over `expr`. */
def isNotNull = IsNotNull(expr)
/** Builds a `GetItem` expression from `expr` and the `ordinal` expression. */
def getItem(ordinal: Expression) = GetItem(expr, ordinal)
/** Builds a `GetField` expression from `expr` and the field name given as a plain `String`. */
def getField(fieldName: String) = GetField(expr, fieldName)
/** Builds a `Cast` of `expr` to the data type `to`. */
def cast(to: DataType) = Cast(expr, to)
/** Builds a `SortOrder` over `expr` with the `Ascending` direction. */
def asc = SortOrder(expr, Ascending)
@ -112,6 +128,7 @@ package object dsl {
/** Builds a `SumDistinct` aggregate expression over `e`. */
def sumDistinct(e: Expression) = SumDistinct(e)
/** Builds a `Count` aggregate expression over `e`. */
def count(e: Expression) = Count(e)
/** Builds a `CountDistinct` aggregate expression over the varargs expressions `e`. */
def countDistinct(e: Expression*) = CountDistinct(e)
/**
 * Builds an `ApproxCountDistinct` aggregate expression over `e`; `rsd` defaults to 0.05.
 * NOTE(review): `rsd` is presumably the allowed relative standard deviation -- confirm.
 */
def approxCountDistinct(e: Expression, rsd: Double = 0.05) = ApproxCountDistinct(e, rsd)
/** Builds an `Average` aggregate expression over `e`. */
def avg(e: Expression) = Average(e)
/** Builds a `First` aggregate expression over `e`. */
def first(e: Expression) = First(e)
/** Builds a `Min` aggregate expression over `e`. */
def min(e: Expression) = Min(e)
@ -163,6 +180,18 @@ package object dsl {
/** Creates a new nullable AttributeReference of type binary (`BinaryType`) from `s` */
def binary = AttributeReference(s, BinaryType, nullable = true)()
/** Creates a new nullable AttributeReference of type array with the given element type */
def array(dataType: DataType) = {
  val arrayType = ArrayType(dataType)
  AttributeReference(s, arrayType, nullable = true)()
}
/** Creates a new nullable AttributeReference of type map from separate key and value types */
def map(keyType: DataType, valueType: DataType): AttributeReference =
  AttributeReference(s, MapType(keyType, valueType), nullable = true)()
/** Creates a new nullable AttributeReference of type map from an already-built MapType */
def map(mapType: MapType) = AttributeReference(s, mapType, nullable = true)()
/** Creates a new nullable AttributeReference of type struct from the given fields */
def struct(fields: StructField*): AttributeReference =
  AttributeReference(s, StructType(fields), nullable = true)()
/** Creates a new nullable AttributeReference of type struct from an already-built StructType */
def struct(structType: StructType) = AttributeReference(s, structType, nullable = true)()
}
implicit class DslAttribute(a: AttributeReference) {

View file

@ -301,17 +301,17 @@ class ExpressionEvaluationSuite extends FunSuite {
val c3 = 'a.boolean.at(2)
val c4 = 'a.boolean.at(3)
checkEvaluation(IsNull(c1), false, row)
checkEvaluation(IsNotNull(c1), true, row)
checkEvaluation(c1.isNull, false, row)
checkEvaluation(c1.isNotNull, true, row)
checkEvaluation(IsNull(c2), true, row)
checkEvaluation(IsNotNull(c2), false, row)
checkEvaluation(c2.isNull, true, row)
checkEvaluation(c2.isNotNull, false, row)
checkEvaluation(IsNull(Literal(1, ShortType)), false)
checkEvaluation(IsNotNull(Literal(1, ShortType)), true)
checkEvaluation(Literal(1, ShortType).isNull, false)
checkEvaluation(Literal(1, ShortType).isNotNull, true)
checkEvaluation(IsNull(Literal(null, ShortType)), true)
checkEvaluation(IsNotNull(Literal(null, ShortType)), false)
checkEvaluation(Literal(null, ShortType).isNull, true)
checkEvaluation(Literal(null, ShortType).isNotNull, false)
checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row)
checkEvaluation(Coalesce(Literal(null, StringType) :: Nil), null, row)
@ -326,11 +326,11 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(If(Literal(false, BooleanType),
Literal("a", StringType), Literal("b", StringType)), "b", row)
checkEvaluation(In(c1, c1 :: c2 :: Nil), true, row)
checkEvaluation(In(Literal("^Ba*n", StringType),
Literal("^Ba*n", StringType) :: Nil), true, row)
checkEvaluation(In(Literal("^Ba*n", StringType),
Literal("^Ba*n", StringType) :: c2 :: Nil), true, row)
checkEvaluation(c1 in (c1, c2), true, row)
checkEvaluation(
Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType)), true, row)
checkEvaluation(
Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType), c2), true, row)
}
test("case when") {
@ -420,6 +420,10 @@ class ExpressionEvaluationSuite extends FunSuite {
assert(GetField(Literal(null, typeS), "a").nullable === true)
assert(GetField(Literal(null, typeS_notNullable), "a").nullable === true)
checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row)
checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row)
checkEvaluation('c.struct(typeS).at(2).getField("a"), "aa", row)
}
test("arithmetic") {
@ -472,20 +476,20 @@ class ExpressionEvaluationSuite extends FunSuite {
val c1 = 'a.string.at(0)
val c2 = 'a.string.at(1)
checkEvaluation(Contains(c1, "b"), true, row)
checkEvaluation(Contains(c1, "x"), false, row)
checkEvaluation(Contains(c2, "b"), null, row)
checkEvaluation(Contains(c1, Literal(null, StringType)), null, row)
checkEvaluation(c1 contains "b", true, row)
checkEvaluation(c1 contains "x", false, row)
checkEvaluation(c2 contains "b", null, row)
checkEvaluation(c1 contains Literal(null, StringType), null, row)
checkEvaluation(StartsWith(c1, "a"), true, row)
checkEvaluation(StartsWith(c1, "b"), false, row)
checkEvaluation(StartsWith(c2, "a"), null, row)
checkEvaluation(StartsWith(c1, Literal(null, StringType)), null, row)
checkEvaluation(c1 startsWith "a", true, row)
checkEvaluation(c1 startsWith "b", false, row)
checkEvaluation(c2 startsWith "a", null, row)
checkEvaluation(c1 startsWith Literal(null, StringType), null, row)
checkEvaluation(EndsWith(c1, "c"), true, row)
checkEvaluation(EndsWith(c1, "b"), false, row)
checkEvaluation(EndsWith(c2, "b"), null, row)
checkEvaluation(EndsWith(c1, Literal(null, StringType)), null, row)
checkEvaluation(c1 endsWith "c", true, row)
checkEvaluation(c1 endsWith "b", false, row)
checkEvaluation(c2 endsWith "b", null, row)
checkEvaluation(c1 endsWith Literal(null, StringType), null, row)
}
test("Substring") {
@ -542,5 +546,10 @@ class ExpressionEvaluationSuite extends FunSuite {
assert(Substring(s_notNull, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === false)
assert(Substring(s_notNull, Literal(null, IntegerType), Literal(2, IntegerType)).nullable === true)
assert(Substring(s_notNull, Literal(0, IntegerType), Literal(null, IntegerType)).nullable === true)
checkEvaluation(s.substr(0, 2), "ex", row)
checkEvaluation(s.substr(0), "example", row)
checkEvaluation(s.substring(0, 2), "ex", row)
checkEvaluation(s.substring(0), "example", row)
}
}

View file

@ -18,7 +18,6 @@
package org.apache.spark.sql
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.test._
/* Implicits */
@ -41,15 +40,15 @@ class DslQuerySuite extends QueryTest {
test("agg") {
checkAnswer(
testData2.groupBy('a)('a, Sum('b)),
testData2.groupBy('a)('a, sum('b)),
Seq((1,3),(2,3),(3,3))
)
checkAnswer(
testData2.groupBy('a)('a, Sum('b) as 'totB).aggregate(Sum('totB)),
testData2.groupBy('a)('a, sum('b) as 'totB).aggregate(sum('totB)),
9
)
checkAnswer(
testData2.aggregate(Sum('b)),
testData2.aggregate(sum('b)),
9
)
}
@ -104,19 +103,19 @@ class DslQuerySuite extends QueryTest {
Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
checkAnswer(
arrayData.orderBy(GetItem('data, 0).asc),
arrayData.orderBy('data.getItem(0).asc),
arrayData.collect().sortBy(_.data(0)).toSeq)
checkAnswer(
arrayData.orderBy(GetItem('data, 0).desc),
arrayData.orderBy('data.getItem(0).desc),
arrayData.collect().sortBy(_.data(0)).reverse.toSeq)
checkAnswer(
mapData.orderBy(GetItem('data, 1).asc),
mapData.orderBy('data.getItem(1).asc),
mapData.collect().sortBy(_.data(1)).toSeq)
checkAnswer(
mapData.orderBy(GetItem('data, 1).desc),
mapData.orderBy('data.getItem(1).desc),
mapData.collect().sortBy(_.data(1)).reverse.toSeq)
}