diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index e8e42d72d4..52c2971b73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -334,7 +334,7 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
 
     // try to compile, helpful for debug
     val cleanedSource = CodeFormatter.stripExtraNewLines(source)
-    // println(s"${CodeFormatter.format(cleanedSource)}")
+    logDebug(s"${CodeFormatter.format(cleanedSource)}")
     CodeGenerator.compile(cleanedSource)
 
     val rdds = child.asInstanceOf[CodegenSupport].upstreams()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 6ebbc8be6f..6e2a5aa4f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -74,8 +74,27 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
 }
 
-case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode with CodegenSupport {
-  override def output: Seq[Attribute] = child.output
+case class Filter(condition: Expression, child: SparkPlan)
+  extends UnaryNode with CodegenSupport with PredicateHelper {
+
+  // Split out all the IsNotNull predicates from the condition.
+  private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
+    case IsNotNull(a) if child.output.contains(a) => true
+    case _ => false
+  }
+
+  // The columns whose nulls are filtered out by `IsNotNull` can be considered not nullable.
+  private val notNullAttributes = notNullPreds.flatMap(_.references)
+
+  override def output: Seq[Attribute] = {
+    child.output.map { a =>
+      if (a.nullable && notNullAttributes.contains(a)) {
+        a.withNullability(false)
+      } else {
+        a
+      }
+    }
+  }
 
   private[sql] override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
 
@@ -90,20 +109,42 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
     val numOutput = metricTerm(ctx, "numOutputRows")
-    val expr = ExpressionCanonicalizer.execute(
-      BindReferences.bindReference(condition, child.output))
+
+    // Generate code that skips rows with nulls in the `IsNotNull` columns.
+    val filterOutNull = notNullAttributes.map { a =>
+      val idx = child.output.indexOf(a)
+      s"if (${input(idx).isNull}) continue;"
+    }.mkString("\n")
+
     ctx.currentVars = input
-    val eval = expr.gen(ctx)
-    val nullCheck = if (expr.nullable) {
-      s"!${eval.isNull} &&"
-    } else {
-      s""
+    val predicates = otherPreds.map { e =>
+      val bound = ExpressionCanonicalizer.execute(
+        BindReferences.bindReference(e, output))
+      val ev = bound.gen(ctx)
+      val nullCheck = if (bound.nullable) {
+        s"${ev.isNull} || "
+      } else {
+        s""
+      }
+      s"""
+         |${ev.code}
+         |if (${nullCheck}!${ev.value}) continue;
+       """.stripMargin
+    }.mkString("\n")
+
+    // Reset isNull to false for the not-null columns, so that downstream operators can
+    // generate better code (remove dead branches).
+    val resultVars = input.zipWithIndex.map { case (ev, i) =>
+      if (notNullAttributes.contains(child.output(i))) {
+        ev.isNull = "false"
+      }
+      ev
     }
     s"""
-       |${eval.code}
-       |if (!($nullCheck ${eval.value})) continue;
+       |$filterOutNull
+       |$predicates
        |$numOutput.add(1);
-       |${consume(ctx, ctx.currentVars)}
+       |${consume(ctx, resultVars)}
      """.stripMargin
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index d83486df02..4143e944e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -55,7 +55,9 @@ case class BroadcastNestedLoopJoin(
       UnsafeProjection.create(output, output)
     } else {
       // Always put the stream side on left to simplify implementation
-      UnsafeProjection.create(output, streamed.output ++ broadcast.output)
+      // Both the left and the right side could be null here.
+      UnsafeProjection.create(
+        output, (streamed.output ++ broadcast.output).map(_.withNullability(true)))
     }
   }
 
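The Filter change above is user-visible in the schema: once an `IS NOT NULL` predicate is part of the filter, the corresponding attribute comes out of the physical plan marked non-nullable, which is what lets downstream codegen drop its null branches. A minimal spark-shell sketch of that effect (the column name `a` and the data are hypothetical, not part of this patch):

// Hypothetical spark-shell session; `sqlContext` is the shell's built-in SQLContext.
val df = sqlContext.range(10).selectExpr("if(id % 2 = 0, id, null) AS a")
df.schema("a").nullable
// => true: `a` may contain nulls
val filtered = df.filter("a IS NOT NULL AND a > 3")
filtered.queryExecution.executedPlan.output.head.nullable
// => false with this patch: the IsNotNull predicate tightens the output nullability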
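For readers unfamiliar with `PredicateHelper`: the new `Filter` relies on `splitConjunctivePredicates` to flatten the `AND` tree before partitioning out the `IsNotNull` checks. A self-contained sketch of that split, with hypothetical attributes `a` and `b` (the `SplitSketch` object is made up for this example):

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

object SplitSketch extends PredicateHelper {
  val a = AttributeReference("a", IntegerType)()
  val b = AttributeReference("b", IntegerType)()

  // a IS NOT NULL AND a > 5 AND b < 2
  val condition = And(IsNotNull(a), And(GreaterThan(a, Literal(5)), LessThan(b, Literal(2))))

  // splitConjunctivePredicates yields Seq(IsNotNull(a), a > 5, b < 2);
  // partition then separates the IsNotNull checks from the rest, as Filter does above.
  val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
    case IsNotNull(_: Attribute) => true
    case _ => false
  }
  // notNullPreds == Seq(IsNotNull(a)); otherPreds == Seq(a > 5, b < 2)
}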
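Finally, the BroadcastNestedLoopJoin fix: in the outer variants of this join, unmatched rows are padded with nulls on one side, so the projection's input schema must allow nulls even for columns a child declares non-nullable. A hypothetical sketch of a plan that takes this code path (column names and data made up):

// `range` produces a non-nullable `id`; a non-equi condition plus the "outer" join type
// sends this through BroadcastNestedLoopJoin, which null-pads unmatched rows.
val left = sqlContext.range(5).selectExpr("id AS l_id")
val right = sqlContext.range(3, 8).selectExpr("id AS r_id")
val joined = left.join(right, left("l_id") < right("r_id"), "outer")
// Without .map(_.withNullability(true)), the generated UnsafeProjection would assume
// l_id/r_id are never null and could produce wrong results for the null-padded rows.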