[SPARK-10740] [SQL] handle nondeterministic expressions correctly for set operations
https://issues.apache.org/jira/browse/SPARK-10740 Author: Wenchen Fan <cloud0fan@163.com> Closes #8858 from cloud-fan/non-deter.
This commit is contained in:
parent
1ca5e2e0b8
commit
5017c685f4
|
@@ -95,14 +95,14 @@ object SamplePushDown extends Rule[LogicalPlan] {
|
|||
* Intersect:
|
||||
* It is not safe to pushdown Projections through it because we need to get the
|
||||
* intersect of rows by comparing the entire rows. It is fine to pushdown Filters
|
||||
* because we will not have non-deterministic expressions.
|
||||
* with deterministic condition.
|
||||
*
|
||||
* Except:
|
||||
* It is not safe to pushdown Projections through it because we need to get the
|
||||
* intersect of rows by comparing the entire rows. It is fine to pushdown Filters
|
||||
* because we will not have non-deterministic expressions.
|
||||
* with deterministic condition.
|
||||
*/
|
||||
object SetOperationPushDown extends Rule[LogicalPlan] {
|
||||
object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
|
||||
|
||||
/**
|
||||
* Maps Attributes from the left side to the corresponding Attribute on the right side.
|
||||
|
@@ -129,34 +129,65 @@ object SetOperationPushDown extends Rule[LogicalPlan] {
|
|||
result.asInstanceOf[A]
|
||||
}
|
||||
|
||||
/**
|
||||
* Splits the condition expression into small conditions by `And`, and partition them by
|
||||
* deterministic, and finally recombine them by `And`. It returns an expression containing
|
||||
* all deterministic expressions (the first field of the returned Tuple2) and an expression
|
||||
* containing all non-deterministic expressions (the second field of the returned Tuple2).
|
||||
*/
|
||||
private def partitionByDeterministic(condition: Expression): (Expression, Expression) = {
|
||||
val andConditions = splitConjunctivePredicates(condition)
|
||||
andConditions.partition(_.deterministic) match {
|
||||
case (deterministic, nondeterministic) =>
|
||||
deterministic.reduceOption(And).getOrElse(Literal(true)) ->
|
||||
nondeterministic.reduceOption(And).getOrElse(Literal(true))
|
||||
}
|
||||
}
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
// Push down filter into union
|
||||
case Filter(condition, u @ Union(left, right)) =>
|
||||
val (deterministic, nondeterministic) = partitionByDeterministic(condition)
|
||||
val rewrites = buildRewrites(u)
|
||||
Union(
|
||||
Filter(condition, left),
|
||||
Filter(pushToRight(condition, rewrites), right))
|
||||
Filter(nondeterministic,
|
||||
Union(
|
||||
Filter(deterministic, left),
|
||||
Filter(pushToRight(deterministic, rewrites), right)
|
||||
)
|
||||
)
|
||||
|
||||
// Push down projection through UNION ALL
|
||||
case Project(projectList, u @ Union(left, right)) =>
|
||||
val rewrites = buildRewrites(u)
|
||||
Union(
|
||||
Project(projectList, left),
|
||||
Project(projectList.map(pushToRight(_, rewrites)), right))
|
||||
// Push down deterministic projection through UNION ALL
|
||||
case p @ Project(projectList, u @ Union(left, right)) =>
|
||||
if (projectList.forall(_.deterministic)) {
|
||||
val rewrites = buildRewrites(u)
|
||||
Union(
|
||||
Project(projectList, left),
|
||||
Project(projectList.map(pushToRight(_, rewrites)), right))
|
||||
} else {
|
||||
p
|
||||
}
|
||||
|
||||
// Push down filter through INTERSECT
|
||||
case Filter(condition, i @ Intersect(left, right)) =>
|
||||
val (deterministic, nondeterministic) = partitionByDeterministic(condition)
|
||||
val rewrites = buildRewrites(i)
|
||||
Intersect(
|
||||
Filter(condition, left),
|
||||
Filter(pushToRight(condition, rewrites), right))
|
||||
Filter(nondeterministic,
|
||||
Intersect(
|
||||
Filter(deterministic, left),
|
||||
Filter(pushToRight(deterministic, rewrites), right)
|
||||
)
|
||||
)
|
||||
|
||||
// Push down filter through EXCEPT
|
||||
case Filter(condition, e @ Except(left, right)) =>
|
||||
val (deterministic, nondeterministic) = partitionByDeterministic(condition)
|
||||
val rewrites = buildRewrites(e)
|
||||
Except(
|
||||
Filter(condition, left),
|
||||
Filter(pushToRight(condition, rewrites), right))
|
||||
Filter(nondeterministic,
|
||||
Except(
|
||||
Filter(deterministic, left),
|
||||
Filter(pushToRight(deterministic, rewrites), right)
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@@ -30,7 +30,8 @@ class SetOperationPushDownSuite extends PlanTest {
|
|||
Batch("Subqueries", Once,
|
||||
EliminateSubQueries) ::
|
||||
Batch("Union Pushdown", Once,
|
||||
SetOperationPushDown) :: Nil
|
||||
SetOperationPushDown,
|
||||
SimplifyFilters) :: Nil
|
||||
}
|
||||
|
||||
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
|
||||
|
|
|
@@ -916,4 +916,45 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
|
|||
assert(intersect.count() === 30)
|
||||
assert(except.count() === 70)
|
||||
}
|
||||
|
||||
test("SPARK-10740: handle nondeterministic expressions correctly for set operations") {
|
||||
val df1 = (1 to 20).map(Tuple1.apply).toDF("i")
|
||||
val df2 = (1 to 10).map(Tuple1.apply).toDF("i")
|
||||
|
||||
// When generating expected results at here, we need to follow the implementation of
|
||||
// Rand expression.
|
||||
def expected(df: DataFrame): Seq[Row] = {
|
||||
df.rdd.collectPartitions().zipWithIndex.flatMap {
|
||||
case (data, index) =>
|
||||
val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index)
|
||||
data.filter(_.getInt(0) < rng.nextDouble() * 10)
|
||||
}
|
||||
}
|
||||
|
||||
val union = df1.unionAll(df2)
|
||||
checkAnswer(
|
||||
union.filter('i < rand(7) * 10),
|
||||
expected(union)
|
||||
)
|
||||
checkAnswer(
|
||||
union.select(rand(7)),
|
||||
union.rdd.collectPartitions().zipWithIndex.flatMap {
|
||||
case (data, index) =>
|
||||
val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index)
|
||||
data.map(_ => rng.nextDouble()).map(i => Row(i))
|
||||
}
|
||||
)
|
||||
|
||||
val intersect = df1.intersect(df2)
|
||||
checkAnswer(
|
||||
intersect.filter('i < rand(7) * 10),
|
||||
expected(intersect)
|
||||
)
|
||||
|
||||
val except = df1.except(df2)
|
||||
checkAnswer(
|
||||
except.filter('i < rand(7) * 10),
|
||||
expected(except)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue