[SPARK-26147][SQL] only pull out unevaluable python udf from join condition

## What changes were proposed in this pull request?

https://github.com/apache/spark/pull/22326 mistakenly assumed that every Python UDF in a join condition is unevaluable. In fact, only Python UDFs that refer to attributes from both join sides are unevaluable.

This PR fixes this mistake.

## How was this patch tested?

a new test

Closes #23153 from cloud-fan/join.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
Wenchen Fan 2018-11-28 20:38:42 +08:00
parent 438f8fd675
commit affe80958d
3 changed files with 106 additions and 48 deletions

View file

@ -209,6 +209,18 @@ class UDFTests(ReusedSQLTestCase):
with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
self.assertEqual(df.collect(), [Row(a=1, b=1)])
def test_udf_in_left_outer_join_condition(self):
    # regression test for SPARK-26147
    from pyspark.sql.functions import udf, col
    left = self.spark.createDataFrame([Row(a=1)])
    right = self.spark.createDataFrame([Row(b=1)])
    to_str = udf(lambda a: str(a), StringType())
    # The join condition can't be pushed down, as it refers to attributes from both sides.
    # The Python UDF only refers to attributes from one side, so it's evaluable.
    join_condition = to_str("a") == col("b").cast("string")
    joined = left.join(right, join_condition, how="left_outer")
    with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
        # The single left row matches the single right row, so exactly one joined row
        # is expected.
        self.assertEqual(joined.collect(), [Row(a=1, b=1)])
def test_udf_in_left_semi_join_condition(self):
# regression test for SPARK-25314
from pyspark.sql.functions import udf

View file

@ -155,19 +155,20 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
}
/**
* PythonUDF in join condition can not be evaluated, this rule will detect the PythonUDF
* and pull them out from join condition. For python udf accessing attributes from only one side,
* they are pushed down by operation push down rules. If not (e.g. user disables filter push
* down rules), we need to pull them out in this rule too.
* PythonUDF in join condition can't be evaluated if it refers to attributes from both join sides.
* See `ExtractPythonUDFs` for details. This rule will detect un-evaluable PythonUDF and pull them
* out from join condition.
*/
object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateHelper {
def hasPythonUDF(expression: Expression): Boolean = {
expression.collectFirst { case udf: PythonUDF => udf }.isDefined
/**
 * Returns true when `expr` contains a scalar PythonUDF that can be evaluated against neither
 * the left child nor the right child of the join `j`.
 */
private def hasUnevaluablePythonUDF(expr: Expression, j: Join): Boolean = {
  // A sub-expression is fine as long as at least one join side can evaluate it.
  def evaluableOnOneSide(e: Expression): Boolean =
    canEvaluate(e, j.left) || canEvaluate(e, j.right)

  expr.find(e => PythonUDF.isScalarPythonUDF(e) && !evaluableOnOneSide(e)).nonEmpty
}
override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case j @ Join(_, _, joinType, condition)
if condition.isDefined && hasPythonUDF(condition.get) =>
case j @ Join(_, _, joinType, Some(cond)) if hasUnevaluablePythonUDF(cond, j) =>
if (!joinType.isInstanceOf[InnerLike] && joinType != LeftSemi) {
// The current strategy only supports InnerLike and LeftSemi joins because for other types,
// it breaks SQL semantics if we run the join condition as a filter after the join. If we pass
@ -179,10 +180,9 @@ object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateH
}
// If condition expression contains python udf, it will be moved out from
// the new join conditions.
val (udf, rest) =
splitConjunctivePredicates(condition.get).partition(hasPythonUDF)
val (udf, rest) = splitConjunctivePredicates(cond).partition(hasUnevaluablePythonUDF(_, j))
val newCondition = if (rest.isEmpty) {
logWarning(s"The join condition:$condition of the join plan contains PythonUDF only," +
logWarning(s"The join condition:$cond of the join plan contains PythonUDF only," +
s" it will be moved out and the join plan will be turned to cross join.")
None
} else {

View file

@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.optimizer
import org.scalatest.Matchers._
import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
@ -28,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf._
import org.apache.spark.sql.types.BooleanType
import org.apache.spark.sql.types.{BooleanType, IntegerType}
class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
@ -40,13 +38,29 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
CheckCartesianProducts) :: Nil
}
val testRelationLeft = LocalRelation('a.int, 'b.int)
val testRelationRight = LocalRelation('c.int, 'd.int)
val attrA = 'a.int
val attrB = 'b.int
val attrC = 'c.int
val attrD = 'd.int
// Dummy python UDF for testing. Unable to execute.
val pythonUDF = PythonUDF("pythonUDF", null,
val testRelationLeft = LocalRelation(attrA, attrB)
val testRelationRight = LocalRelation(attrC, attrD)
// This join condition refers to attributes from 2 tables (`a` and `c`), but the PythonUDF
// inside it only refers to an attribute from one side (`a`).
val evaluableJoinCond = PythonUDF(
  "evaluable",
  null,
  IntegerType,
  Seq(attrA),
  PythonEvalType.SQL_BATCHED_UDF,
  udfDeterministic = true) === attrC
// This join condition is a PythonUDF which refers to attributes from 2 tables.
val unevaluableJoinCond = PythonUDF("unevaluable", null,
BooleanType,
Seq.empty,
Seq(attrA, attrC),
PythonEvalType.SQL_BATCHED_UDF,
udfDeterministic = true)
@ -66,62 +80,76 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
}
}
test("inner join condition with python udf only") {
val query = testRelationLeft.join(
test("inner join condition with python udf") {
val query1 = testRelationLeft.join(
testRelationRight,
joinType = Inner,
condition = Some(pythonUDF))
val expected = testRelationLeft.join(
condition = Some(unevaluableJoinCond))
val expected1 = testRelationLeft.join(
testRelationRight,
joinType = Inner,
condition = None).where(pythonUDF).analyze
comparePlanWithCrossJoinEnable(query, expected)
condition = None).where(unevaluableJoinCond).analyze
comparePlanWithCrossJoinEnable(query1, expected1)
// evaluable PythonUDF will not be touched
val query2 = testRelationLeft.join(
testRelationRight,
joinType = Inner,
condition = Some(evaluableJoinCond))
comparePlans(Optimize.execute(query2), query2)
}
test("left semi join condition with python udf only") {
val query = testRelationLeft.join(
test("left semi join condition with python udf") {
val query1 = testRelationLeft.join(
testRelationRight,
joinType = LeftSemi,
condition = Some(pythonUDF))
val expected = testRelationLeft.join(
condition = Some(unevaluableJoinCond))
val expected1 = testRelationLeft.join(
testRelationRight,
joinType = Inner,
condition = None).where(pythonUDF).select('a, 'b).analyze
comparePlanWithCrossJoinEnable(query, expected)
condition = None).where(unevaluableJoinCond).select('a, 'b).analyze
comparePlanWithCrossJoinEnable(query1, expected1)
// evaluable PythonUDF will not be touched
val query2 = testRelationLeft.join(
testRelationRight,
joinType = LeftSemi,
condition = Some(evaluableJoinCond))
comparePlans(Optimize.execute(query2), query2)
}
test("python udf and common condition") {
test("unevaluable python udf and common condition") {
val query = testRelationLeft.join(
testRelationRight,
joinType = Inner,
condition = Some(pythonUDF && 'a.attr === 'c.attr))
condition = Some(unevaluableJoinCond && 'a.attr === 'c.attr))
val expected = testRelationLeft.join(
testRelationRight,
joinType = Inner,
condition = Some('a.attr === 'c.attr)).where(pythonUDF).analyze
condition = Some('a.attr === 'c.attr)).where(unevaluableJoinCond).analyze
val optimized = Optimize.execute(query.analyze)
comparePlans(optimized, expected)
}
test("python udf or common condition") {
test("unevaluable python udf or common condition") {
val query = testRelationLeft.join(
testRelationRight,
joinType = Inner,
condition = Some(pythonUDF || 'a.attr === 'c.attr))
condition = Some(unevaluableJoinCond || 'a.attr === 'c.attr))
val expected = testRelationLeft.join(
testRelationRight,
joinType = Inner,
condition = None).where(pythonUDF || 'a.attr === 'c.attr).analyze
condition = None).where(unevaluableJoinCond || 'a.attr === 'c.attr).analyze
comparePlanWithCrossJoinEnable(query, expected)
}
test("pull out whole complex condition with multiple python udf") {
test("pull out whole complex condition with multiple unevaluable python udf") {
val pythonUDF1 = PythonUDF("pythonUDF1", null,
BooleanType,
Seq.empty,
Seq(attrA, attrC),
PythonEvalType.SQL_BATCHED_UDF,
udfDeterministic = true)
val condition = (pythonUDF || 'a.attr === 'c.attr) && pythonUDF1
val condition = (unevaluableJoinCond || 'a.attr === 'c.attr) && pythonUDF1
val query = testRelationLeft.join(
testRelationRight,
@ -134,13 +162,13 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
comparePlanWithCrossJoinEnable(query, expected)
}
test("partial pull out complex condition with multiple python udf") {
test("partial pull out complex condition with multiple unevaluable python udf") {
val pythonUDF1 = PythonUDF("pythonUDF1", null,
BooleanType,
Seq.empty,
Seq(attrA, attrC),
PythonEvalType.SQL_BATCHED_UDF,
udfDeterministic = true)
val condition = (pythonUDF || pythonUDF1) && 'a.attr === 'c.attr
val condition = (unevaluableJoinCond || pythonUDF1) && 'a.attr === 'c.attr
val query = testRelationLeft.join(
testRelationRight,
@ -149,23 +177,41 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
val expected = testRelationLeft.join(
testRelationRight,
joinType = Inner,
condition = Some('a.attr === 'c.attr)).where(pythonUDF || pythonUDF1).analyze
condition = Some('a.attr === 'c.attr)).where(unevaluableJoinCond || pythonUDF1).analyze
val optimized = Optimize.execute(query.analyze)
comparePlans(optimized, expected)
}
test("pull out unevaluable python udf when it's mixed with evaluable one") {
  // Conjunction of a condition whose UDF reads one side only and a UDF reading both sides.
  val mixedCondition = evaluableJoinCond && unevaluableJoinCond
  val query = testRelationLeft.join(
    testRelationRight,
    joinType = Inner,
    condition = Some(mixedCondition))

  // Expected plan: the evaluable conjunct stays in the join condition, while the
  // unevaluable UDF becomes a filter on top of the join.
  val expected = testRelationLeft
    .join(testRelationRight, joinType = Inner, condition = Some(evaluableJoinCond))
    .where(unevaluableJoinCond)
    .analyze

  comparePlans(Optimize.execute(query.analyze), expected)
}
test("throw an exception for not support join type") {
for (joinType <- unsupportedJoinTypes) {
val thrownException = the [AnalysisException] thrownBy {
val e = intercept[AnalysisException] {
val query = testRelationLeft.join(
testRelationRight,
joinType,
condition = Some(pythonUDF))
condition = Some(unevaluableJoinCond))
Optimize.execute(query.analyze)
}
assert(thrownException.message.contentEquals(
assert(e.message.contentEquals(
s"Using PythonUDF in join condition of join type $joinType is not supported."))
}
}
}
val query2 = testRelationLeft.join(
testRelationRight,
joinType,
condition = Some(evaluableJoinCond))
comparePlans(Optimize.execute(query2), query2)
}
}
}