[SPARK-9117] [SQL] fix BooleanSimplification in case-insensitive
Author: Wenchen Fan <cloud0fan@outlook.com> Closes #7452 from cloud-fan/boolean-simplify and squashes the following commits: 2a6e692 [Wenchen Fan] fix style d3cfd26 [Wenchen Fan] fix BooleanSimplification in case-insensitive
This commit is contained in:
parent
fd6b3101fb
commit
bd903ee89f
|
@ -393,26 +393,26 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
|
||||||
// (a || b) && (a || c) => a || (b && c)
|
// (a || b) && (a || c) => a || (b && c)
|
||||||
case _ =>
|
case _ =>
|
||||||
// 1. Split left and right to get the disjunctive predicates,
|
// 1. Split left and right to get the disjunctive predicates,
|
||||||
// i.e. lhsSet = (a, b), rhsSet = (a, c)
|
// i.e. lhs = (a, b), rhs = (a, c)
|
||||||
// 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
|
// 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
|
||||||
// 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
|
// 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
|
||||||
// 4. Apply the formula, get the optimized predicate: common || (ldiff && rdiff)
|
// 4. Apply the formula, get the optimized predicate: common || (ldiff && rdiff)
|
||||||
val lhsSet = splitDisjunctivePredicates(left).toSet
|
val lhs = splitDisjunctivePredicates(left)
|
||||||
val rhsSet = splitDisjunctivePredicates(right).toSet
|
val rhs = splitDisjunctivePredicates(right)
|
||||||
val common = lhsSet.intersect(rhsSet)
|
val common = lhs.filter(e => rhs.exists(e.semanticEquals(_)))
|
||||||
if (common.isEmpty) {
|
if (common.isEmpty) {
|
||||||
// No common factors, return the original predicate
|
// No common factors, return the original predicate
|
||||||
and
|
and
|
||||||
} else {
|
} else {
|
||||||
val ldiff = lhsSet.diff(common)
|
val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals(_)))
|
||||||
val rdiff = rhsSet.diff(common)
|
val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals(_)))
|
||||||
if (ldiff.isEmpty || rdiff.isEmpty) {
|
if (ldiff.isEmpty || rdiff.isEmpty) {
|
||||||
// (a || b || c || ...) && (a || b) => (a || b)
|
// (a || b || c || ...) && (a || b) => (a || b)
|
||||||
common.reduce(Or)
|
common.reduce(Or)
|
||||||
} else {
|
} else {
|
||||||
// (a || b || c || ...) && (a || b || d || ...) =>
|
// (a || b || c || ...) && (a || b || d || ...) =>
|
||||||
// ((c || ...) && (d || ...)) || a || b
|
// ((c || ...) && (d || ...)) || a || b
|
||||||
(common + And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or)
|
(common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // end of And(left, right)
|
} // end of And(left, right)
|
||||||
|
@ -431,26 +431,26 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
|
||||||
// (a && b) || (a && c) => a && (b || c)
|
// (a && b) || (a && c) => a && (b || c)
|
||||||
case _ =>
|
case _ =>
|
||||||
// 1. Split left and right to get the conjunctive predicates,
|
// 1. Split left and right to get the conjunctive predicates,
|
||||||
// i.e. lhsSet = (a, b), rhsSet = (a, c)
|
// i.e. lhs = (a, b), rhs = (a, c)
|
||||||
// 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
|
// 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
|
||||||
// 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
|
// 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
|
||||||
// 4. Apply the formula, get the optimized predicate: common && (ldiff || rdiff)
|
// 4. Apply the formula, get the optimized predicate: common && (ldiff || rdiff)
|
||||||
val lhsSet = splitConjunctivePredicates(left).toSet
|
val lhs = splitConjunctivePredicates(left)
|
||||||
val rhsSet = splitConjunctivePredicates(right).toSet
|
val rhs = splitConjunctivePredicates(right)
|
||||||
val common = lhsSet.intersect(rhsSet)
|
val common = lhs.filter(e => rhs.exists(e.semanticEquals(_)))
|
||||||
if (common.isEmpty) {
|
if (common.isEmpty) {
|
||||||
// No common factors, return the original predicate
|
// No common factors, return the original predicate
|
||||||
or
|
or
|
||||||
} else {
|
} else {
|
||||||
val ldiff = lhsSet.diff(common)
|
val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals(_)))
|
||||||
val rdiff = rhsSet.diff(common)
|
val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals(_)))
|
||||||
if (ldiff.isEmpty || rdiff.isEmpty) {
|
if (ldiff.isEmpty || rdiff.isEmpty) {
|
||||||
// (a && b) || (a && b && c && ...) => a && b
|
// (a && b) || (a && b && c && ...) => a && b
|
||||||
common.reduce(And)
|
common.reduce(And)
|
||||||
} else {
|
} else {
|
||||||
// (a && b && c && ...) || (a && b && d && ...) =>
|
// (a && b && c && ...) || (a && b && d && ...) =>
|
||||||
// ((c && ...) || (d && ...)) && a && b
|
// ((c && ...) || (d && ...)) && a && b
|
||||||
(common + Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And)
|
(common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // end of Or(left, right)
|
} // end of Or(left, right)
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
|
|
||||||
package org.apache.spark.sql.catalyst.optimizer
|
package org.apache.spark.sql.catalyst.optimizer
|
||||||
|
|
||||||
import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
|
import org.apache.spark.sql.catalyst.analysis.{AnalysisSuite, EliminateSubQueries}
|
||||||
import org.apache.spark.sql.catalyst.expressions._
|
import org.apache.spark.sql.catalyst.expressions._
|
||||||
import org.apache.spark.sql.catalyst.plans.logical._
|
import org.apache.spark.sql.catalyst.plans.logical._
|
||||||
import org.apache.spark.sql.catalyst.plans.PlanTest
|
import org.apache.spark.sql.catalyst.plans.PlanTest
|
||||||
|
@ -40,29 +40,11 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
|
||||||
|
|
||||||
val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string)
|
val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string)
|
||||||
|
|
||||||
// The `foldLeft` is required to handle cases like comparing `a && (b && c)` and `(a && b) && c`
|
private def checkCondition(input: Expression, expected: Expression): Unit = {
|
||||||
def compareConditions(e1: Expression, e2: Expression): Boolean = (e1, e2) match {
|
|
||||||
case (lhs: And, rhs: And) =>
|
|
||||||
val lhsSet = splitConjunctivePredicates(lhs).toSet
|
|
||||||
val rhsSet = splitConjunctivePredicates(rhs).toSet
|
|
||||||
lhsSet.foldLeft(rhsSet) { (set, e) =>
|
|
||||||
set.find(compareConditions(_, e)).map(set - _).getOrElse(set)
|
|
||||||
}.isEmpty
|
|
||||||
|
|
||||||
case (lhs: Or, rhs: Or) =>
|
|
||||||
val lhsSet = splitDisjunctivePredicates(lhs).toSet
|
|
||||||
val rhsSet = splitDisjunctivePredicates(rhs).toSet
|
|
||||||
lhsSet.foldLeft(rhsSet) { (set, e) =>
|
|
||||||
set.find(compareConditions(_, e)).map(set - _).getOrElse(set)
|
|
||||||
}.isEmpty
|
|
||||||
|
|
||||||
case (l, r) => l == r
|
|
||||||
}
|
|
||||||
|
|
||||||
def checkCondition(input: Expression, expected: Expression): Unit = {
|
|
||||||
val plan = testRelation.where(input).analyze
|
val plan = testRelation.where(input).analyze
|
||||||
val actual = Optimize.execute(plan).expressions.head
|
val actual = Optimize.execute(plan)
|
||||||
compareConditions(actual, expected)
|
val correctAnswer = testRelation.where(expected).analyze
|
||||||
|
comparePlans(actual, correctAnswer)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("a && a => a") {
|
test("a && a => a") {
|
||||||
|
@ -86,10 +68,8 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
|
||||||
('a === 'b && 'c < 1 && 'a === 5) ||
|
('a === 'b && 'c < 1 && 'a === 5) ||
|
||||||
('a === 'b && 'b < 5 && 'a > 1)
|
('a === 'b && 'b < 5 && 'a > 1)
|
||||||
|
|
||||||
val expected =
|
val expected = 'a === 'b && (
|
||||||
(((('b > 3) && ('c > 2)) ||
|
('b > 3 && 'c > 2) || ('c < 1 && 'a === 5) || ('b < 5 && 'a > 1))
|
||||||
(('c < 1) && ('a === 5))) ||
|
|
||||||
(('b < 5) && ('a > 1))) && ('a === 'b)
|
|
||||||
|
|
||||||
checkCondition(input, expected)
|
checkCondition(input, expected)
|
||||||
}
|
}
|
||||||
|
@ -101,10 +81,27 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
|
||||||
|
|
||||||
checkCondition('a < 2 && ('a < 2 || 'a > 3 || 'b > 5) , 'a < 2)
|
checkCondition('a < 2 && ('a < 2 || 'a > 3 || 'b > 5) , 'a < 2)
|
||||||
|
|
||||||
checkCondition(('a < 2 || 'b > 3) && ('a < 2 || 'c > 5), ('b > 3 && 'c > 5) || 'a < 2)
|
checkCondition(('a < 2 || 'b > 3) && ('a < 2 || 'c > 5), 'a < 2 || ('b > 3 && 'c > 5))
|
||||||
|
|
||||||
checkCondition(
|
checkCondition(
|
||||||
('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a < 5),
|
('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a < 5),
|
||||||
('b > 3 && 'a > 3 && 'a < 5) || 'a === 'b)
|
('a === 'b || 'b > 3 && 'a > 3 && 'a < 5))
|
||||||
|
}
|
||||||
|
|
||||||
|
private def caseInsensitiveAnalyse(plan: LogicalPlan) =
|
||||||
|
AnalysisSuite.caseInsensitiveAnalyzer.execute(plan)
|
||||||
|
|
||||||
|
test("(a && b) || (a && c) => a && (b || c) when case insensitive") {
|
||||||
|
val plan = caseInsensitiveAnalyse(testRelation.where(('a > 2 && 'b > 3) || ('A > 2 && 'b < 5)))
|
||||||
|
val actual = Optimize.execute(plan)
|
||||||
|
val expected = caseInsensitiveAnalyse(testRelation.where('a > 2 && ('b > 3 || 'b < 5)))
|
||||||
|
comparePlans(actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("(a || b) && (a || c) => a || (b && c) when case insensitive") {
|
||||||
|
val plan = caseInsensitiveAnalyse(testRelation.where(('a > 2 || 'b > 3) && ('A > 2 || 'b < 5)))
|
||||||
|
val actual = Optimize.execute(plan)
|
||||||
|
val expected = caseInsensitiveAnalyse(testRelation.where('a > 2 || ('b > 3 && 'b < 5)))
|
||||||
|
comparePlans(actual, expected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue