[SPARK-26147][SQL] only pull out unevaluable python udf from join condition
## What changes were proposed in this pull request? https://github.com/apache/spark/pull/22326 mistakenly assumed that every Python UDF in a join condition is unevaluable. In fact, only Python UDFs that refer to attributes from both join sides are unevaluable. This PR fixes this mistake. ## How was this patch tested? A new test. Closes #23153 from cloud-fan/join. Authored-by: Wenchen Fan <wenchen@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
438f8fd675
commit
affe80958d
|
@ -209,6 +209,18 @@ class UDFTests(ReusedSQLTestCase):
|
|||
with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
|
||||
self.assertEqual(df.collect(), [Row(a=1, b=1)])
|
||||
|
||||
def test_udf_in_left_outer_join_condition(self):
|
||||
# regression test for SPARK-26147
|
||||
from pyspark.sql.functions import udf, col
|
||||
left = self.spark.createDataFrame([Row(a=1)])
|
||||
right = self.spark.createDataFrame([Row(b=1)])
|
||||
f = udf(lambda a: str(a), StringType())
|
||||
# The join condition can't be pushed down, as it refers to attributes from both sides.
|
||||
# The Python UDF only refers to attributes from one side, so it's evaluable.
|
||||
df = left.join(right, f("a") == col("b").cast("string"), how="left_outer")
|
||||
with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
|
||||
self.assertEqual(df.collect(), [Row(a=1, b=1)])
|
||||
|
||||
def test_udf_in_left_semi_join_condition(self):
|
||||
# regression test for SPARK-25314
|
||||
from pyspark.sql.functions import udf
|
||||
|
|
|
@ -155,19 +155,20 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
|
|||
}
|
||||
|
||||
/**
|
||||
* PythonUDF in join condition can not be evaluated, this rule will detect the PythonUDF
|
||||
* and pull them out from join condition. For python udf accessing attributes from only one side,
|
||||
* they are pushed down by operation push down rules. If not (e.g. user disables filter push
|
||||
* down rules), we need to pull them out in this rule too.
|
||||
* PythonUDF in join condition can't be evaluated if it refers to attributes from both join sides.
|
||||
* See `ExtractPythonUDFs` for details. This rule will detect un-evaluable PythonUDF and pull them
|
||||
* out from join condition.
|
||||
*/
|
||||
object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateHelper {
|
||||
def hasPythonUDF(expression: Expression): Boolean = {
|
||||
expression.collectFirst { case udf: PythonUDF => udf }.isDefined
|
||||
|
||||
private def hasUnevaluablePythonUDF(expr: Expression, j: Join): Boolean = {
|
||||
expr.find { e =>
|
||||
PythonUDF.isScalarPythonUDF(e) && !canEvaluate(e, j.left) && !canEvaluate(e, j.right)
|
||||
}.isDefined
|
||||
}
|
||||
|
||||
override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
|
||||
case j @ Join(_, _, joinType, condition)
|
||||
if condition.isDefined && hasPythonUDF(condition.get) =>
|
||||
case j @ Join(_, _, joinType, Some(cond)) if hasUnevaluablePythonUDF(cond, j) =>
|
||||
if (!joinType.isInstanceOf[InnerLike] && joinType != LeftSemi) {
|
||||
// The current strategy only supports InnerLike and LeftSemi joins because, for other types,
|
||||
// it breaks SQL semantic if we run the join condition as a filter after join. If we pass
|
||||
|
@ -179,10 +180,9 @@ object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateH
|
|||
}
|
||||
// If condition expression contains python udf, it will be moved out from
|
||||
// the new join conditions.
|
||||
val (udf, rest) =
|
||||
splitConjunctivePredicates(condition.get).partition(hasPythonUDF)
|
||||
val (udf, rest) = splitConjunctivePredicates(cond).partition(hasUnevaluablePythonUDF(_, j))
|
||||
val newCondition = if (rest.isEmpty) {
|
||||
logWarning(s"The join condition:$condition of the join plan contains PythonUDF only," +
|
||||
logWarning(s"The join condition:$cond of the join plan contains PythonUDF only," +
|
||||
s" it will be moved out and the join plan will be turned to cross join.")
|
||||
None
|
||||
} else {
|
||||
|
|
|
@ -17,8 +17,6 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.optimizer
|
||||
|
||||
import org.scalatest.Matchers._
|
||||
|
||||
import org.apache.spark.api.python.PythonEvalType
|
||||
import org.apache.spark.sql.AnalysisException
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
|
@ -28,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans._
|
|||
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
|
||||
import org.apache.spark.sql.catalyst.rules.RuleExecutor
|
||||
import org.apache.spark.sql.internal.SQLConf._
|
||||
import org.apache.spark.sql.types.BooleanType
|
||||
import org.apache.spark.sql.types.{BooleanType, IntegerType}
|
||||
|
||||
class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
|
||||
|
||||
|
@ -40,13 +38,29 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
|
|||
CheckCartesianProducts) :: Nil
|
||||
}
|
||||
|
||||
val testRelationLeft = LocalRelation('a.int, 'b.int)
|
||||
val testRelationRight = LocalRelation('c.int, 'd.int)
|
||||
val attrA = 'a.int
|
||||
val attrB = 'b.int
|
||||
val attrC = 'c.int
|
||||
val attrD = 'd.int
|
||||
|
||||
// Dummy python UDF for testing. Unable to execute.
|
||||
val pythonUDF = PythonUDF("pythonUDF", null,
|
||||
val testRelationLeft = LocalRelation(attrA, attrB)
|
||||
val testRelationRight = LocalRelation(attrC, attrD)
|
||||
|
||||
// This join condition refers to attributes from 2 tables, but the PythonUDF inside it only
|
||||
// refers to attributes from one side.
|
||||
val evaluableJoinCond = {
|
||||
val pythonUDF = PythonUDF("evaluable", null,
|
||||
IntegerType,
|
||||
Seq(attrA),
|
||||
PythonEvalType.SQL_BATCHED_UDF,
|
||||
udfDeterministic = true)
|
||||
pythonUDF === attrC
|
||||
}
|
||||
|
||||
// This join condition is a PythonUDF which refers to attributes from 2 tables.
|
||||
val unevaluableJoinCond = PythonUDF("unevaluable", null,
|
||||
BooleanType,
|
||||
Seq.empty,
|
||||
Seq(attrA, attrC),
|
||||
PythonEvalType.SQL_BATCHED_UDF,
|
||||
udfDeterministic = true)
|
||||
|
||||
|
@ -66,62 +80,76 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
|
|||
}
|
||||
}
|
||||
|
||||
test("inner join condition with python udf only") {
|
||||
val query = testRelationLeft.join(
|
||||
test("inner join condition with python udf") {
|
||||
val query1 = testRelationLeft.join(
|
||||
testRelationRight,
|
||||
joinType = Inner,
|
||||
condition = Some(pythonUDF))
|
||||
val expected = testRelationLeft.join(
|
||||
condition = Some(unevaluableJoinCond))
|
||||
val expected1 = testRelationLeft.join(
|
||||
testRelationRight,
|
||||
joinType = Inner,
|
||||
condition = None).where(pythonUDF).analyze
|
||||
comparePlanWithCrossJoinEnable(query, expected)
|
||||
condition = None).where(unevaluableJoinCond).analyze
|
||||
comparePlanWithCrossJoinEnable(query1, expected1)
|
||||
|
||||
// evaluable PythonUDF will not be touched
|
||||
val query2 = testRelationLeft.join(
|
||||
testRelationRight,
|
||||
joinType = Inner,
|
||||
condition = Some(evaluableJoinCond))
|
||||
comparePlans(Optimize.execute(query2), query2)
|
||||
}
|
||||
|
||||
test("left semi join condition with python udf only") {
|
||||
val query = testRelationLeft.join(
|
||||
test("left semi join condition with python udf") {
|
||||
val query1 = testRelationLeft.join(
|
||||
testRelationRight,
|
||||
joinType = LeftSemi,
|
||||
condition = Some(pythonUDF))
|
||||
val expected = testRelationLeft.join(
|
||||
condition = Some(unevaluableJoinCond))
|
||||
val expected1 = testRelationLeft.join(
|
||||
testRelationRight,
|
||||
joinType = Inner,
|
||||
condition = None).where(pythonUDF).select('a, 'b).analyze
|
||||
comparePlanWithCrossJoinEnable(query, expected)
|
||||
condition = None).where(unevaluableJoinCond).select('a, 'b).analyze
|
||||
comparePlanWithCrossJoinEnable(query1, expected1)
|
||||
|
||||
// evaluable PythonUDF will not be touched
|
||||
val query2 = testRelationLeft.join(
|
||||
testRelationRight,
|
||||
joinType = LeftSemi,
|
||||
condition = Some(evaluableJoinCond))
|
||||
comparePlans(Optimize.execute(query2), query2)
|
||||
}
|
||||
|
||||
test("python udf and common condition") {
|
||||
test("unevaluable python udf and common condition") {
|
||||
val query = testRelationLeft.join(
|
||||
testRelationRight,
|
||||
joinType = Inner,
|
||||
condition = Some(pythonUDF && 'a.attr === 'c.attr))
|
||||
condition = Some(unevaluableJoinCond && 'a.attr === 'c.attr))
|
||||
val expected = testRelationLeft.join(
|
||||
testRelationRight,
|
||||
joinType = Inner,
|
||||
condition = Some('a.attr === 'c.attr)).where(pythonUDF).analyze
|
||||
condition = Some('a.attr === 'c.attr)).where(unevaluableJoinCond).analyze
|
||||
val optimized = Optimize.execute(query.analyze)
|
||||
comparePlans(optimized, expected)
|
||||
}
|
||||
|
||||
test("python udf or common condition") {
|
||||
test("unevaluable python udf or common condition") {
|
||||
val query = testRelationLeft.join(
|
||||
testRelationRight,
|
||||
joinType = Inner,
|
||||
condition = Some(pythonUDF || 'a.attr === 'c.attr))
|
||||
condition = Some(unevaluableJoinCond || 'a.attr === 'c.attr))
|
||||
val expected = testRelationLeft.join(
|
||||
testRelationRight,
|
||||
joinType = Inner,
|
||||
condition = None).where(pythonUDF || 'a.attr === 'c.attr).analyze
|
||||
condition = None).where(unevaluableJoinCond || 'a.attr === 'c.attr).analyze
|
||||
comparePlanWithCrossJoinEnable(query, expected)
|
||||
}
|
||||
|
||||
test("pull out whole complex condition with multiple python udf") {
|
||||
test("pull out whole complex condition with multiple unevaluable python udf") {
|
||||
val pythonUDF1 = PythonUDF("pythonUDF1", null,
|
||||
BooleanType,
|
||||
Seq.empty,
|
||||
Seq(attrA, attrC),
|
||||
PythonEvalType.SQL_BATCHED_UDF,
|
||||
udfDeterministic = true)
|
||||
val condition = (pythonUDF || 'a.attr === 'c.attr) && pythonUDF1
|
||||
val condition = (unevaluableJoinCond || 'a.attr === 'c.attr) && pythonUDF1
|
||||
|
||||
val query = testRelationLeft.join(
|
||||
testRelationRight,
|
||||
|
@ -134,13 +162,13 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
|
|||
comparePlanWithCrossJoinEnable(query, expected)
|
||||
}
|
||||
|
||||
test("partial pull out complex condition with multiple python udf") {
|
||||
test("partial pull out complex condition with multiple unevaluable python udf") {
|
||||
val pythonUDF1 = PythonUDF("pythonUDF1", null,
|
||||
BooleanType,
|
||||
Seq.empty,
|
||||
Seq(attrA, attrC),
|
||||
PythonEvalType.SQL_BATCHED_UDF,
|
||||
udfDeterministic = true)
|
||||
val condition = (pythonUDF || pythonUDF1) && 'a.attr === 'c.attr
|
||||
val condition = (unevaluableJoinCond || pythonUDF1) && 'a.attr === 'c.attr
|
||||
|
||||
val query = testRelationLeft.join(
|
||||
testRelationRight,
|
||||
|
@ -149,23 +177,41 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
|
|||
val expected = testRelationLeft.join(
|
||||
testRelationRight,
|
||||
joinType = Inner,
|
||||
condition = Some('a.attr === 'c.attr)).where(pythonUDF || pythonUDF1).analyze
|
||||
condition = Some('a.attr === 'c.attr)).where(unevaluableJoinCond || pythonUDF1).analyze
|
||||
val optimized = Optimize.execute(query.analyze)
|
||||
comparePlans(optimized, expected)
|
||||
}
|
||||
|
||||
test("pull out unevaluable python udf when it's mixed with evaluable one") {
|
||||
val query = testRelationLeft.join(
|
||||
testRelationRight,
|
||||
joinType = Inner,
|
||||
condition = Some(evaluableJoinCond && unevaluableJoinCond))
|
||||
val expected = testRelationLeft.join(
|
||||
testRelationRight,
|
||||
joinType = Inner,
|
||||
condition = Some(evaluableJoinCond)).where(unevaluableJoinCond).analyze
|
||||
val optimized = Optimize.execute(query.analyze)
|
||||
comparePlans(optimized, expected)
|
||||
}
|
||||
|
||||
test("throw an exception for not support join type") {
|
||||
for (joinType <- unsupportedJoinTypes) {
|
||||
val thrownException = the [AnalysisException] thrownBy {
|
||||
val e = intercept[AnalysisException] {
|
||||
val query = testRelationLeft.join(
|
||||
testRelationRight,
|
||||
joinType,
|
||||
condition = Some(pythonUDF))
|
||||
condition = Some(unevaluableJoinCond))
|
||||
Optimize.execute(query.analyze)
|
||||
}
|
||||
assert(thrownException.message.contentEquals(
|
||||
assert(e.message.contentEquals(
|
||||
s"Using PythonUDF in join condition of join type $joinType is not supported."))
|
||||
|
||||
val query2 = testRelationLeft.join(
|
||||
testRelationRight,
|
||||
joinType,
|
||||
condition = Some(evaluableJoinCond))
|
||||
comparePlans(Optimize.execute(query2), query2)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue