[SPARK-10740] [SQL] handle nondeterministic expressions correctly for set operations

https://issues.apache.org/jira/browse/SPARK-10740

Author: Wenchen Fan <cloud0fan@163.com>

Closes #8858 from cloud-fan/non-deter.
Authored by Wenchen Fan on 2015-09-22 12:14:15 -07:00; committed by Yin Huai.
parent 1ca5e2e0b8
commit 5017c685f4
3 changed files with 92 additions and 19 deletions
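
For context on the fix: a predicate containing a nondeterministic expression such as rand() cannot simply be copied below a set operation, because re-evaluating it against each child's own partition layout can keep a different set of rows than evaluating it once above the operation. A sketch of the new rewrite, in the same constructor notation the rule below uses (the predicate here is illustrative, not taken from the commit):

    // Before this patch, the optimizer rewrote
    //   Filter('i < Rand(7) * 10 && 'a > 1, Union(left, right))
    // into
    //   Union(
    //     Filter('i < Rand(7) * 10 && 'a > 1, left),
    //     Filter('i < Rand(7) * 10 && 'a > 1, right))
    // re-evaluating Rand(7) per child. After this patch, only the deterministic
    // conjunct is pushed down; the nondeterministic remainder stays on top:
    //   Filter('i < Rand(7) * 10,
    //     Union(
    //       Filter('a > 1, left),
    //       Filter('a > 1, right)))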


@@ -95,14 +95,14 @@ object SamplePushDown extends Rule[LogicalPlan] {
  * Intersect:
  * It is not safe to pushdown Projections through it because we need to get the
  * intersect of rows by comparing the entire rows. It is fine to pushdown Filters
- * because we will not have non-deterministic expressions.
+ * with deterministic condition.
  *
  * Except:
  * It is not safe to pushdown Projections through it because we need to get the
  * intersect of rows by comparing the entire rows. It is fine to pushdown Filters
- * because we will not have non-deterministic expressions.
+ * with deterministic condition.
  */
-object SetOperationPushDown extends Rule[LogicalPlan] {
+object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
 
   /**
    * Maps Attributes from the left side to the corresponding Attribute on the right side.
@@ -129,34 +129,65 @@ object SetOperationPushDown extends Rule[LogicalPlan] {
     result.asInstanceOf[A]
   }
 
+  /**
+   * Splits the condition expression into small conditions by `And`, and partition them by
+   * deterministic, and finally recombine them by `And`. It returns an expression containing
+   * all deterministic expressions (the first field of the returned Tuple2) and an expression
+   * containing all non-deterministic expressions (the second field of the returned Tuple2).
+   */
+  private def partitionByDeterministic(condition: Expression): (Expression, Expression) = {
+    val andConditions = splitConjunctivePredicates(condition)
+    andConditions.partition(_.deterministic) match {
+      case (deterministic, nondeterministic) =>
+        deterministic.reduceOption(And).getOrElse(Literal(true)) ->
+          nondeterministic.reduceOption(And).getOrElse(Literal(true))
+    }
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     // Push down filter into union
     case Filter(condition, u @ Union(left, right)) =>
+      val (deterministic, nondeterministic) = partitionByDeterministic(condition)
       val rewrites = buildRewrites(u)
-      Union(
-        Filter(condition, left),
-        Filter(pushToRight(condition, rewrites), right))
+      Filter(nondeterministic,
+        Union(
+          Filter(deterministic, left),
+          Filter(pushToRight(deterministic, rewrites), right)
+        )
+      )
 
-    // Push down projection through UNION ALL
-    case Project(projectList, u @ Union(left, right)) =>
-      val rewrites = buildRewrites(u)
-      Union(
-        Project(projectList, left),
-        Project(projectList.map(pushToRight(_, rewrites)), right))
+    // Push down deterministic projection through UNION ALL
+    case p @ Project(projectList, u @ Union(left, right)) =>
+      if (projectList.forall(_.deterministic)) {
+        val rewrites = buildRewrites(u)
+        Union(
+          Project(projectList, left),
+          Project(projectList.map(pushToRight(_, rewrites)), right))
+      } else {
+        p
+      }
 
     // Push down filter through INTERSECT
     case Filter(condition, i @ Intersect(left, right)) =>
+      val (deterministic, nondeterministic) = partitionByDeterministic(condition)
       val rewrites = buildRewrites(i)
-      Intersect(
-        Filter(condition, left),
-        Filter(pushToRight(condition, rewrites), right))
+      Filter(nondeterministic,
+        Intersect(
+          Filter(deterministic, left),
+          Filter(pushToRight(deterministic, rewrites), right)
+        )
+      )
 
     // Push down filter through EXCEPT
     case Filter(condition, e @ Except(left, right)) =>
+      val (deterministic, nondeterministic) = partitionByDeterministic(condition)
       val rewrites = buildRewrites(e)
-      Except(
-        Filter(condition, left),
-        Filter(pushToRight(condition, rewrites), right))
+      Filter(nondeterministic,
+        Except(
+          Filter(deterministic, left),
+          Filter(pushToRight(deterministic, rewrites), right)
+        )
+      )
   }
 }
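
The partitionByDeterministic helper added above is the heart of the change. Below is a minimal, self-contained sketch of the same split-and-recombine logic over a simplified expression ADT (all types and names here are illustrative stand-ins, not Catalyst's):

    sealed trait Expr { def deterministic: Boolean }
    case class Attr(name: String) extends Expr { val deterministic = true }
    case class Rand(seed: Long) extends Expr { val deterministic = false }
    case class LessThan(left: Expr, right: Expr) extends Expr {
      def deterministic: Boolean = left.deterministic && right.deterministic
    }
    case class And(left: Expr, right: Expr) extends Expr {
      def deterministic: Boolean = left.deterministic && right.deterministic
    }
    case object TrueLiteral extends Expr { val deterministic = true }

    // Mirrors splitConjunctivePredicates: flatten a tree of Ands into its conjuncts.
    def splitConjuncts(e: Expr): Seq[Expr] = e match {
      case And(l, r) => splitConjuncts(l) ++ splitConjuncts(r)
      case other => Seq(other)
    }

    // Mirrors partitionByDeterministic: split the conjuncts by determinism and
    // recombine each side with And, defaulting to a true literal when a side is empty.
    def partitionByDeterministic(condition: Expr): (Expr, Expr) = {
      val (det, nondet) = splitConjuncts(condition).partition(_.deterministic)
      (det.reduceOption(And(_, _)).getOrElse(TrueLiteral),
        nondet.reduceOption(And(_, _)).getOrElse(TrueLiteral))
    }

    // Example:
    // partitionByDeterministic(
    //   And(LessThan(Attr("a"), Attr("b")), LessThan(Attr("i"), Rand(7))))
    //   == (LessThan(Attr("a"), Attr("b")), LessThan(Attr("i"), Rand(7)))

Because partition is stable, each recombined side keeps its conjuncts in their original evaluation order.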


@@ -30,7 +30,8 @@ class SetOperationPushDownSuite extends PlanTest {
       Batch("Subqueries", Once,
         EliminateSubQueries) ::
       Batch("Union Pushdown", Once,
-        SetOperationPushDown) :: Nil
+        SetOperationPushDown,
+        SimplifyFilters) :: Nil
   }
 
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
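
SimplifyFilters joins the batch because partitionByDeterministic returns Literal(true) for a side with no matching conjuncts: when a pushed condition is fully deterministic, the rewrite leaves a vacuous Filter(true, ...) above the set operation, which would make the suite's plan comparisons fail against hand-written expected plans. A toy sketch of that cleanup (illustrative ADT, not Catalyst's rule):

    sealed trait Plan
    case class Relation(name: String) extends Plan
    case class Union(left: Plan, right: Plan) extends Plan
    case class Filter(condition: String, child: Plan) extends Plan

    // Recursively drop filters whose condition is the literal "true".
    def simplifyFilters(plan: Plan): Plan = plan match {
      case Filter("true", child) => simplifyFilters(child)
      case Filter(cond, child) => Filter(cond, simplifyFilters(child))
      case Union(l, r) => Union(simplifyFilters(l), simplifyFilters(r))
      case leaf: Relation => leaf
    }

    // simplifyFilters(Filter("true", Union(Relation("l"), Relation("r"))))
    //   == Union(Relation("l"), Relation("r"))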


@@ -916,4 +916,45 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     assert(intersect.count() === 30)
     assert(except.count() === 70)
   }
+
+  test("SPARK-10740: handle nondeterministic expressions correctly for set operations") {
+    val df1 = (1 to 20).map(Tuple1.apply).toDF("i")
+    val df2 = (1 to 10).map(Tuple1.apply).toDF("i")
+
+    // When generating expected results at here, we need to follow the implementation of
+    // Rand expression.
+    def expected(df: DataFrame): Seq[Row] = {
+      df.rdd.collectPartitions().zipWithIndex.flatMap {
+        case (data, index) =>
+          val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index)
+          data.filter(_.getInt(0) < rng.nextDouble() * 10)
+      }
+    }
+
+    val union = df1.unionAll(df2)
+    checkAnswer(
+      union.filter('i < rand(7) * 10),
+      expected(union)
+    )
+    checkAnswer(
+      union.select(rand(7)),
+      union.rdd.collectPartitions().zipWithIndex.flatMap {
+        case (data, index) =>
+          val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index)
+          data.map(_ => rng.nextDouble()).map(i => Row(i))
+      }
+    )
+
+    val intersect = df1.intersect(df2)
+    checkAnswer(
+      intersect.filter('i < rand(7) * 10),
+      expected(intersect)
+    )
+
+    val except = df1.except(df2)
+    checkAnswer(
+      except.filter('i < rand(7) * 10),
+      expected(except)
+    )
+  }
 }
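
A note on the expected() helper in the test: it replays what rand(7) does at runtime, where each partition draws from an XORShiftRandom seeded with seed + partitionIndex, one value per row. A standalone sketch of that replay (hypothetical helper; XORShiftRandom is private[spark], so this must live under the org.apache.spark package, as the test does):

    package org.apache.spark.sql

    import org.apache.spark.util.random.XORShiftRandom

    object RandReplay {
      // Replays the doubles rand(seed) would produce for a given partition layout,
      // where each inner Seq holds one partition's rows in order.
      def replayRand(partitions: Seq[Seq[Int]], seed: Long): Seq[Double] =
        partitions.zipWithIndex.flatMap { case (rows, index) =>
          val rng = new XORShiftRandom(seed + index)
          rows.map(_ => rng.nextDouble())
        }
    }

This is also why the fix keeps nondeterministic predicates above the set operation: 'i < rand(7) * 10 evaluated above the union consumes the generator in the union's own partition order, whereas pushing it below the union would re-seed and re-consume it against each child's layout instead.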