[SPARK-2588][SQL] Add some more DSLs.
Author: Takuya UESHIN <ueshin@happy-camper.st> Closes #1491 from ueshin/issues/SPARK-2588 and squashes the following commits: 43d0a46 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-2588 1023ea0 [Takuya UESHIN] Modify tests to use DSLs. 2310bf1 [Takuya UESHIN] Add some more DSLs.
This commit is contained in:
parent
f776bc9887
commit
1b790cf775
|
@ -79,8 +79,24 @@ package object dsl {
|
|||
def === (other: Expression) = EqualTo(expr, other)
|
||||
def !== (other: Expression) = Not(EqualTo(expr, other))
|
||||
|
||||
def in(list: Expression*) = In(expr, list)
|
||||
|
||||
def like(other: Expression) = Like(expr, other)
|
||||
def rlike(other: Expression) = RLike(expr, other)
|
||||
def contains(other: Expression) = Contains(expr, other)
|
||||
def startsWith(other: Expression) = StartsWith(expr, other)
|
||||
def endsWith(other: Expression) = EndsWith(expr, other)
|
||||
def substr(pos: Expression, len: Expression = Literal(Int.MaxValue)) =
|
||||
Substring(expr, pos, len)
|
||||
def substring(pos: Expression, len: Expression = Literal(Int.MaxValue)) =
|
||||
Substring(expr, pos, len)
|
||||
|
||||
def isNull = IsNull(expr)
|
||||
def isNotNull = IsNotNull(expr)
|
||||
|
||||
def getItem(ordinal: Expression) = GetItem(expr, ordinal)
|
||||
def getField(fieldName: String) = GetField(expr, fieldName)
|
||||
|
||||
def cast(to: DataType) = Cast(expr, to)
|
||||
|
||||
def asc = SortOrder(expr, Ascending)
|
||||
|
@ -112,6 +128,7 @@ package object dsl {
|
|||
def sumDistinct(e: Expression) = SumDistinct(e)
|
||||
def count(e: Expression) = Count(e)
|
||||
def countDistinct(e: Expression*) = CountDistinct(e)
|
||||
def approxCountDistinct(e: Expression, rsd: Double = 0.05) = ApproxCountDistinct(e, rsd)
|
||||
def avg(e: Expression) = Average(e)
|
||||
def first(e: Expression) = First(e)
|
||||
def min(e: Expression) = Min(e)
|
||||
|
@ -163,6 +180,18 @@ package object dsl {
|
|||
|
||||
/** Creates a new AttributeReference of type binary */
|
||||
def binary = AttributeReference(s, BinaryType, nullable = true)()
|
||||
|
||||
/** Creates a new AttributeReference of type array */
|
||||
def array(dataType: DataType) = AttributeReference(s, ArrayType(dataType), nullable = true)()
|
||||
|
||||
/** Creates a new AttributeReference of type map */
|
||||
def map(keyType: DataType, valueType: DataType): AttributeReference =
|
||||
map(MapType(keyType, valueType))
|
||||
def map(mapType: MapType) = AttributeReference(s, mapType, nullable = true)()
|
||||
|
||||
/** Creates a new AttributeReference of type struct */
|
||||
def struct(fields: StructField*): AttributeReference = struct(StructType(fields))
|
||||
def struct(structType: StructType) = AttributeReference(s, structType, nullable = true)()
|
||||
}
|
||||
|
||||
implicit class DslAttribute(a: AttributeReference) {
|
||||
|
|
|
@ -301,17 +301,17 @@ class ExpressionEvaluationSuite extends FunSuite {
|
|||
val c3 = 'a.boolean.at(2)
|
||||
val c4 = 'a.boolean.at(3)
|
||||
|
||||
checkEvaluation(IsNull(c1), false, row)
|
||||
checkEvaluation(IsNotNull(c1), true, row)
|
||||
checkEvaluation(c1.isNull, false, row)
|
||||
checkEvaluation(c1.isNotNull, true, row)
|
||||
|
||||
checkEvaluation(IsNull(c2), true, row)
|
||||
checkEvaluation(IsNotNull(c2), false, row)
|
||||
checkEvaluation(c2.isNull, true, row)
|
||||
checkEvaluation(c2.isNotNull, false, row)
|
||||
|
||||
checkEvaluation(IsNull(Literal(1, ShortType)), false)
|
||||
checkEvaluation(IsNotNull(Literal(1, ShortType)), true)
|
||||
checkEvaluation(Literal(1, ShortType).isNull, false)
|
||||
checkEvaluation(Literal(1, ShortType).isNotNull, true)
|
||||
|
||||
checkEvaluation(IsNull(Literal(null, ShortType)), true)
|
||||
checkEvaluation(IsNotNull(Literal(null, ShortType)), false)
|
||||
checkEvaluation(Literal(null, ShortType).isNull, true)
|
||||
checkEvaluation(Literal(null, ShortType).isNotNull, false)
|
||||
|
||||
checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row)
|
||||
checkEvaluation(Coalesce(Literal(null, StringType) :: Nil), null, row)
|
||||
|
@ -326,11 +326,11 @@ class ExpressionEvaluationSuite extends FunSuite {
|
|||
checkEvaluation(If(Literal(false, BooleanType),
|
||||
Literal("a", StringType), Literal("b", StringType)), "b", row)
|
||||
|
||||
checkEvaluation(In(c1, c1 :: c2 :: Nil), true, row)
|
||||
checkEvaluation(In(Literal("^Ba*n", StringType),
|
||||
Literal("^Ba*n", StringType) :: Nil), true, row)
|
||||
checkEvaluation(In(Literal("^Ba*n", StringType),
|
||||
Literal("^Ba*n", StringType) :: c2 :: Nil), true, row)
|
||||
checkEvaluation(c1 in (c1, c2), true, row)
|
||||
checkEvaluation(
|
||||
Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType)), true, row)
|
||||
checkEvaluation(
|
||||
Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType), c2), true, row)
|
||||
}
|
||||
|
||||
test("case when") {
|
||||
|
@ -420,6 +420,10 @@ class ExpressionEvaluationSuite extends FunSuite {
|
|||
|
||||
assert(GetField(Literal(null, typeS), "a").nullable === true)
|
||||
assert(GetField(Literal(null, typeS_notNullable), "a").nullable === true)
|
||||
|
||||
checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row)
|
||||
checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row)
|
||||
checkEvaluation('c.struct(typeS).at(2).getField("a"), "aa", row)
|
||||
}
|
||||
|
||||
test("arithmetic") {
|
||||
|
@ -472,20 +476,20 @@ class ExpressionEvaluationSuite extends FunSuite {
|
|||
val c1 = 'a.string.at(0)
|
||||
val c2 = 'a.string.at(1)
|
||||
|
||||
checkEvaluation(Contains(c1, "b"), true, row)
|
||||
checkEvaluation(Contains(c1, "x"), false, row)
|
||||
checkEvaluation(Contains(c2, "b"), null, row)
|
||||
checkEvaluation(Contains(c1, Literal(null, StringType)), null, row)
|
||||
checkEvaluation(c1 contains "b", true, row)
|
||||
checkEvaluation(c1 contains "x", false, row)
|
||||
checkEvaluation(c2 contains "b", null, row)
|
||||
checkEvaluation(c1 contains Literal(null, StringType), null, row)
|
||||
|
||||
checkEvaluation(StartsWith(c1, "a"), true, row)
|
||||
checkEvaluation(StartsWith(c1, "b"), false, row)
|
||||
checkEvaluation(StartsWith(c2, "a"), null, row)
|
||||
checkEvaluation(StartsWith(c1, Literal(null, StringType)), null, row)
|
||||
checkEvaluation(c1 startsWith "a", true, row)
|
||||
checkEvaluation(c1 startsWith "b", false, row)
|
||||
checkEvaluation(c2 startsWith "a", null, row)
|
||||
checkEvaluation(c1 startsWith Literal(null, StringType), null, row)
|
||||
|
||||
checkEvaluation(EndsWith(c1, "c"), true, row)
|
||||
checkEvaluation(EndsWith(c1, "b"), false, row)
|
||||
checkEvaluation(EndsWith(c2, "b"), null, row)
|
||||
checkEvaluation(EndsWith(c1, Literal(null, StringType)), null, row)
|
||||
checkEvaluation(c1 endsWith "c", true, row)
|
||||
checkEvaluation(c1 endsWith "b", false, row)
|
||||
checkEvaluation(c2 endsWith "b", null, row)
|
||||
checkEvaluation(c1 endsWith Literal(null, StringType), null, row)
|
||||
}
|
||||
|
||||
test("Substring") {
|
||||
|
@ -542,5 +546,10 @@ class ExpressionEvaluationSuite extends FunSuite {
|
|||
assert(Substring(s_notNull, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === false)
|
||||
assert(Substring(s_notNull, Literal(null, IntegerType), Literal(2, IntegerType)).nullable === true)
|
||||
assert(Substring(s_notNull, Literal(0, IntegerType), Literal(null, IntegerType)).nullable === true)
|
||||
|
||||
checkEvaluation(s.substr(0, 2), "ex", row)
|
||||
checkEvaluation(s.substr(0), "example", row)
|
||||
checkEvaluation(s.substring(0, 2), "ex", row)
|
||||
checkEvaluation(s.substring(0), "example", row)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
package org.apache.spark.sql
|
||||
|
||||
import org.apache.spark.sql.catalyst.analysis._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.test._
|
||||
|
||||
/* Implicits */
|
||||
|
@ -41,15 +40,15 @@ class DslQuerySuite extends QueryTest {
|
|||
|
||||
test("agg") {
|
||||
checkAnswer(
|
||||
testData2.groupBy('a)('a, Sum('b)),
|
||||
testData2.groupBy('a)('a, sum('b)),
|
||||
Seq((1,3),(2,3),(3,3))
|
||||
)
|
||||
checkAnswer(
|
||||
testData2.groupBy('a)('a, Sum('b) as 'totB).aggregate(Sum('totB)),
|
||||
testData2.groupBy('a)('a, sum('b) as 'totB).aggregate(sum('totB)),
|
||||
9
|
||||
)
|
||||
checkAnswer(
|
||||
testData2.aggregate(Sum('b)),
|
||||
testData2.aggregate(sum('b)),
|
||||
9
|
||||
)
|
||||
}
|
||||
|
@ -104,19 +103,19 @@ class DslQuerySuite extends QueryTest {
|
|||
Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
|
||||
|
||||
checkAnswer(
|
||||
arrayData.orderBy(GetItem('data, 0).asc),
|
||||
arrayData.orderBy('data.getItem(0).asc),
|
||||
arrayData.collect().sortBy(_.data(0)).toSeq)
|
||||
|
||||
checkAnswer(
|
||||
arrayData.orderBy(GetItem('data, 0).desc),
|
||||
arrayData.orderBy('data.getItem(0).desc),
|
||||
arrayData.collect().sortBy(_.data(0)).reverse.toSeq)
|
||||
|
||||
checkAnswer(
|
||||
mapData.orderBy(GetItem('data, 1).asc),
|
||||
mapData.orderBy('data.getItem(1).asc),
|
||||
mapData.collect().sortBy(_.data(1)).toSeq)
|
||||
|
||||
checkAnswer(
|
||||
mapData.orderBy(GetItem('data, 1).desc),
|
||||
mapData.orderBy('data.getItem(1).desc),
|
||||
mapData.collect().sortBy(_.data(1)).reverse.toSeq)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue