[SPARK-33427][SQL] Add subexpression elimination for interpreted expression evaluation

### What changes were proposed in this pull request?

This patch adds subexpression elimination to interpreted expression evaluation. Interpreted expression evaluation is used when codegen cannot be used, for example with a complex schema.

### Why are the changes needed?

Currently, subexpression elimination is done only for codegen. In some cases Spark has to run interpreted expression evaluation instead: for example, when codegen fails to compile and falls back to interpreted mode, or when the input/output schema of the expressions is complex. Complex schemas produced by expressions are common and can also be introduced by the query optimizer, e.g. SPARK-32945.

We should also support subexpression elimination for interpreted evaluation. That could reduce the performance gap when Spark falls back from codegen to interpreted expression evaluation, and improve Spark usability.
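
As a rough illustration of the idea (not the code in this patch), the sketch below evaluates a small expression tree in which the same subtree appears twice. All names here (`Expr`, `Input`, `Slow`, `Add`, `Memo`) are hypothetical and use only the Scala standard library. A counter shows the shared subtree being computed once instead of twice when it is wrapped in a caching node.

```scala
import scala.collection.mutable

object SubExprSketch {
  // Hypothetical mini expression tree, not Spark's Catalyst `Expression`.
  sealed trait Expr { def eval(row: Int): Int }

  // Counts how often the expensive node is actually evaluated.
  var slowEvals = 0

  case class Input() extends Expr { def eval(row: Int): Int = row }
  case class Slow(child: Expr) extends Expr {
    def eval(row: Int): Int = { slowEvals += 1; child.eval(row) * 2 }
  }
  case class Add(l: Expr, r: Expr) extends Expr {
    def eval(row: Int): Int = l.eval(row) + r.eval(row)
  }

  // Wraps a shared subtree and serves repeated evaluations from a cache.
  class Memo(child: Expr, cache: mutable.Map[Expr, Int]) extends Expr {
    def eval(row: Int): Int = cache.getOrElseUpdate(child, child.eval(row))
  }

  // Evaluates Slow(input) + Slow(input) without any caching.
  def evalPlain(row: Int): (Int, Int) = {
    slowEvals = 0
    val shared = Slow(Input())
    val result = Add(shared, shared).eval(row)
    (result, slowEvals)
  }

  // Same tree, but the shared subtree is wrapped in a caching node.
  def evalMemoized(row: Int): (Int, Int) = {
    slowEvals = 0
    val cache = mutable.Map.empty[Expr, Int]
    val shared = new Memo(Slow(Input()), cache)
    val result = Add(shared, shared).eval(row)
    (result, slowEvals)
  }
}
```

Both functions return the same result for a given row; only the number of `Slow` evaluations differs. The cache in this sketch is valid for a single row only and would have to be cleared before the next one.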

#### Benchmark

Update `SubExprEliminationBenchmark`:

Before:

```
OpenJDK 64-Bit Server VM 1.8.0_265-b01 on Mac OS X 10.15.6
Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
from_json as subExpr:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------------------------------------
subexpressionElimination on, codegen off           24707          25688         903          0.0   247068775.9       1.0X
```

After:
```
OpenJDK 64-Bit Server VM 1.8.0_265-b01 on Mac OS X 10.15.6
Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
from_json as subExpr:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------------------------------------
subexpressionElimination on, codegen off            2360           2435          87          0.0    23604320.7      11.2X
```
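
For a direct before/after comparison of the row shown in both tables, the best times can be divided by hand. The snippet below is only a convenience calculation with hypothetical names (`BenchmarkSpeedup`, `beforeBestMs`, `afterBestMs`); the two numbers are copied from the tables above.

```scala
object BenchmarkSpeedup {
  // Best times (ms) for "subexpressionElimination on, codegen off",
  // copied from the Before and After tables.
  val beforeBestMs = 24707.0
  val afterBestMs = 2360.0

  // Ratio of the old best time to the new best time.
  val speedup: Double = beforeBestMs / afterBestMs

  def main(args: Array[String]): Unit =
    println(f"speedup: $speedup%.1fx")
}
```

This gives a ratio of roughly 10.5. The `11.2X` in the `Relative` column is a different number: it appears to be computed against the first row of the updated results file (26503 ms for elimination off, codegen on) rather than against the old run.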

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit tests. Benchmarked manually.

Closes #30341 from viirya/SPARK-33427.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Liang-Chi Hsieh 2020-11-17 14:29:37 +00:00 committed by Wenchen Fan
parent 09bb9bedcd
commit 928348408e
6 changed files with 281 additions and 10 deletions


@@ -20,6 +20,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter}
 import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{UserDefinedType, _}
 import org.apache.spark.unsafe.Platform
@@ -33,6 +34,15 @@ import org.apache.spark.unsafe.Platform
 class InterpretedUnsafeProjection(expressions: Array[Expression]) extends UnsafeProjection {
   import InterpretedUnsafeProjection._
+  private[this] val subExprEliminationEnabled = SQLConf.get.subexpressionEliminationEnabled
+  private[this] lazy val runtime =
+    new SubExprEvaluationRuntime(SQLConf.get.subexpressionEliminationCacheMaxEntries)
+  private[this] val exprs = if (subExprEliminationEnabled) {
+    runtime.proxyExpressions(expressions)
+  } else {
+    expressions.toSeq
+  }
   /** Number of (top level) fields in the resulting row. */
   private[this] val numFields = expressions.length
@@ -63,17 +73,21 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe
   }
   override def initialize(partitionIndex: Int): Unit = {
-    expressions.foreach(_.foreach {
+    exprs.foreach(_.foreach {
       case n: Nondeterministic => n.initialize(partitionIndex)
       case _ =>
     })
   }
   override def apply(row: InternalRow): UnsafeRow = {
+    if (subExprEliminationEnabled) {
+      runtime.setInput(row)
+    }
     // Put the expression results in the intermediate row.
     var i = 0
     while (i < numFields) {
-      values(i) = expressions(i).eval(row)
+      values(i) = exprs(i).eval(row)
       i += 1
     }
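
The per-row protocol in `apply` above (set the input, which clears the cache, then evaluate every output field) can be sketched without Spark as follows. This is a simplified stand-in under stated assumptions, not the patch's implementation: `Runtime`, `project`, and the use of `Seq[Int]` as a row are all hypothetical, and a plain mutable map replaces the Guava cache.

```scala
import scala.collection.mutable

object PerRowCacheSketch {
  // Hypothetical stand-in for the runtime: it holds the current row
  // and a result cache that is only valid for that row.
  class Runtime {
    private var currentInput: Seq[Int] = Seq.empty
    private val cache = mutable.Map.empty[Int, Int]
    var loads = 0 // number of real (non-cached) evaluations

    // A new row invalidates everything cached for the previous row.
    def setInput(row: Seq[Int]): Unit = {
      currentInput = row
      cache.clear()
    }

    // Returns the cached result for `id`, computing it on a miss.
    def getEval(id: Int, compute: Seq[Int] => Int): Int =
      cache.getOrElseUpdate(id, { loads += 1; compute(currentInput) })
  }

  // Projects each row to two fields that share one subexpression (id 0).
  def project(rows: Seq[Seq[Int]]): (Seq[(Int, Int)], Int) = {
    val runtime = new Runtime
    val sum: Seq[Int] => Int = _.sum
    val out = rows.map { row =>
      runtime.setInput(row)
      (runtime.getEval(0, sum) + 1, runtime.getEval(0, sum) * 2)
    }
    (out, runtime.loads)
  }
}
```

With two input rows the shared subexpression is computed twice in total, once per row, even though it is referenced four times. Skipping the `setInput` call in this sketch would make the second row reuse the first row's cached value, which is why the cache has to be reset for every row.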


@@ -0,0 +1,145 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions
import java.util.IdentityHashMap
import scala.collection.JavaConverters._
import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.DataType
/**
 * This class helps subexpression elimination for interpreted evaluation
 * such as `InterpretedUnsafeProjection`. It maintains an evaluation cache.
 * This class wraps `ExpressionProxy` around given expressions. The `ExpressionProxy`
 * intercepts expression evaluation and loads from the cache first.
 */
class SubExprEvaluationRuntime(cacheMaxEntries: Int) {
  // The id assigned to `ExpressionProxy`. `SubExprEvaluationRuntime` will use assigned ids of
  // `ExpressionProxy` to decide the equality when loading from cache. `SubExprEvaluationRuntime`
  // won't be used by multiple threads, so we don't need to consider concurrency here.
  private var proxyExpressionCurrentId = 0

  private[sql] val cache: LoadingCache[ExpressionProxy, ResultProxy] = CacheBuilder.newBuilder()
    .maximumSize(cacheMaxEntries)
    .build(
      new CacheLoader[ExpressionProxy, ResultProxy]() {
        override def load(expr: ExpressionProxy): ResultProxy = {
          ResultProxy(expr.proxyEval(currentInput))
        }
      })

  private var currentInput: InternalRow = null

  def getEval(proxy: ExpressionProxy): Any = try {
    cache.get(proxy).result
  } catch {
    // Cache.get() may wrap the original exception. See the following URL
    // http://google.github.io/guava/releases/14.0/api/docs/com/google/common/cache/
    // Cache.html#get(K,%20java.util.concurrent.Callable)
    case e @ (_: UncheckedExecutionException | _: ExecutionError) =>
      throw e.getCause
  }

  /**
   * Sets given input row as current row for evaluating expressions. This cleans up the cache
   * too as new input comes.
   */
  def setInput(input: InternalRow = null): Unit = {
    currentInput = input
    cache.invalidateAll()
  }

  /**
   * Recursively replaces expression with its proxy expression in `proxyMap`.
   */
  private def replaceWithProxy(
      expr: Expression,
      proxyMap: IdentityHashMap[Expression, ExpressionProxy]): Expression = {
    if (proxyMap.containsKey(expr)) {
      proxyMap.get(expr)
    } else {
      expr.mapChildren(replaceWithProxy(_, proxyMap))
    }
  }

  /**
   * Finds subexpressions and wraps them with `ExpressionProxy`.
   */
  def proxyExpressions(expressions: Seq[Expression]): Seq[Expression] = {
    val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
    expressions.foreach(equivalentExpressions.addExprTree(_))
    val proxyMap = new IdentityHashMap[Expression, ExpressionProxy]
    val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
    commonExprs.foreach { e =>
      val expr = e.head
      val proxy = ExpressionProxy(expr, proxyExpressionCurrentId, this)
      proxyExpressionCurrentId += 1
      proxyMap.putAll(e.map(_ -> proxy).toMap.asJava)
    }
    // Only add proxies if we find subexpressions.
    if (!proxyMap.isEmpty) {
      expressions.map(replaceWithProxy(_, proxyMap))
    } else {
      expressions
    }
  }
}
/**
 * A proxy for a Catalyst `Expression`. Given a runtime object `SubExprEvaluationRuntime`,
 * when this is asked to evaluate, it will load from the evaluation cache in the runtime first.
 */
case class ExpressionProxy(
    child: Expression,
    id: Int,
    runtime: SubExprEvaluationRuntime) extends Expression {
  final override def dataType: DataType = child.dataType
  final override def nullable: Boolean = child.nullable
  final override def children: Seq[Expression] = child :: Nil

  // `ExpressionProxy` is for interpreted expression evaluation only, so it cannot `doGenCode`.
  final override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    throw new UnsupportedOperationException(s"Cannot generate code for expression: $this")

  def proxyEval(input: InternalRow = null): Any = child.eval(input)

  override def eval(input: InternalRow = null): Any = runtime.getEval(this)

  override def equals(obj: Any): Boolean = obj match {
    case other: ExpressionProxy => this.id == other.id
    case _ => false
  }

  override def hashCode(): Int = this.id.hashCode()
}

/**
 * A simple wrapper for holding `Any` in the cache of `SubExprEvaluationRuntime`.
 */
case class ResultProxy(result: Any)
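
The effect of bounding the cache with `maximumSize` can be illustrated with a much simpler structure. The sketch below is a swapped-in technique, not what the patch uses: it builds a strict LRU cache on `java.util.LinkedHashMap`, whereas Guava's size-based eviction is, as far as its documentation describes it, only approximately LRU. The names `BoundedCacheSketch`, `LruCache`, and `demo` are hypothetical.

```scala
import java.util.{LinkedHashMap => JLinkedHashMap, Map => JMap}

object BoundedCacheSketch {
  // Hypothetical LRU cache. The third constructor argument `true` selects
  // access order, so the eldest entry is the least recently used one.
  class LruCache[K, V](maxEntries: Int)
    extends JLinkedHashMap[K, V](16, 0.75f, true) {
    // Called by LinkedHashMap after each insertion; returning true
    // removes the eldest entry.
    override protected def removeEldestEntry(eldest: JMap.Entry[K, V]): Boolean =
      size() > maxEntries
  }

  // Inserts three entries into a cache limited to two.
  def demo(): (Int, Boolean, Boolean) = {
    val cache = new LruCache[String, String](2)
    cache.put("expr0", "a")
    cache.put("expr1", "b")
    cache.put("expr2", "c") // evicts "expr0", the least recently used key
    (cache.size(), cache.containsKey("expr0"), cache.containsKey("expr2"))
  }
}
```

An evicted result is not an error in either design: the next lookup for that key simply recomputes it, so a small limit costs repeated evaluation rather than correctness.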


@@ -539,6 +539,15 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
+  val SUBEXPRESSION_ELIMINATION_CACHE_MAX_ENTRIES =
+    buildConf("spark.sql.subexpressionElimination.cache.maxEntries")
+      .internal()
+      .doc("The maximum entries of the cache used for interpreted subexpression elimination.")
+      .version("3.1.0")
+      .intConf
+      .checkValue(_ >= 0, "The maximum must not be negative")
+      .createWithDefault(100)
   val CASE_SENSITIVE = buildConf("spark.sql.caseSensitive")
     .internal()
     .doc("Whether the query analyzer should be case sensitive or not. " +
@@ -3258,6 +3267,9 @@ class SQLConf extends Serializable with Logging {
   def subexpressionEliminationEnabled: Boolean =
     getConf(SUBEXPRESSION_ELIMINATION_ENABLED)
+  def subexpressionEliminationCacheMaxEntries: Int =
+    getConf(SUBEXPRESSION_ELIMINATION_CACHE_MAX_ENTRIES)
   def autoBroadcastJoinThreshold: Long = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
   def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR)


@@ -0,0 +1,100 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.IntegerType
class SubExprEvaluationRuntimeSuite extends SparkFunSuite {
  test("Evaluate ExpressionProxy should create cached result") {
    val runtime = new SubExprEvaluationRuntime(1)
    val proxy = ExpressionProxy(Literal(1), 0, runtime)
    assert(runtime.cache.size() == 0)
    proxy.eval()
    assert(runtime.cache.size() == 1)
    assert(runtime.cache.get(proxy) == ResultProxy(1))
  }

  test("SubExprEvaluationRuntime cannot exceed configured max entries") {
    val runtime = new SubExprEvaluationRuntime(2)
    assert(runtime.cache.size() == 0)
    val proxy1 = ExpressionProxy(Literal(1), 0, runtime)
    proxy1.eval()
    assert(runtime.cache.size() == 1)
    assert(runtime.cache.get(proxy1) == ResultProxy(1))
    val proxy2 = ExpressionProxy(Literal(2), 1, runtime)
    proxy2.eval()
    assert(runtime.cache.size() == 2)
    assert(runtime.cache.get(proxy2) == ResultProxy(2))
    val proxy3 = ExpressionProxy(Literal(3), 2, runtime)
    proxy3.eval()
    assert(runtime.cache.size() == 2)
    assert(runtime.cache.get(proxy3) == ResultProxy(3))
  }

  test("setInput should empty cached result") {
    val runtime = new SubExprEvaluationRuntime(2)
    val proxy1 = ExpressionProxy(Literal(1), 0, runtime)
    assert(runtime.cache.size() == 0)
    proxy1.eval()
    assert(runtime.cache.size() == 1)
    assert(runtime.cache.get(proxy1) == ResultProxy(1))
    val proxy2 = ExpressionProxy(Literal(2), 1, runtime)
    proxy2.eval()
    assert(runtime.cache.size() == 2)
    assert(runtime.cache.get(proxy2) == ResultProxy(2))
    runtime.setInput()
    assert(runtime.cache.size() == 0)
  }

  test("Wrap ExpressionProxy on subexpressions") {
    val runtime = new SubExprEvaluationRuntime(1)
    val one = Literal(1)
    val two = Literal(2)
    val mul = Multiply(one, two)
    val mul2 = Multiply(mul, mul)
    val sqrt = Sqrt(mul2)
    val sum = Add(mul2, sqrt)
    // ( (one * two) * (one * two) ) + sqrt( (one * two) * (one * two) )
    val proxyExpressions = runtime.proxyExpressions(Seq(sum))
    val proxys = proxyExpressions.flatMap(_.collect {
      case p: ExpressionProxy => p
    })
    // ( (one * two) * (one * two) )
    assert(proxys.size == 2)
    val expected = ExpressionProxy(mul2, 0, runtime)
    assert(proxys.forall(_ == expected))
  }

  test("ExpressionProxy won't be on non deterministic") {
    val runtime = new SubExprEvaluationRuntime(1)
    val sum = Add(Rand(0), Rand(0))
    val proxys = runtime.proxyExpressions(Seq(sum, sum)).flatMap(_.collect {
      case p: ExpressionProxy => p
    })
    assert(proxys.isEmpty)
  }
}


@@ -7,9 +7,9 @@ OpenJDK 64-Bit Server VM 11.0.9+11 on Mac OS X 10.15.6
 Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
 from_json as subExpr:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
 -------------------------------------------------------------------------------------------------------------------------
-subexpressionElimination off, codegen on           26809          27731         898          0.0   268094225.4       1.0X
-subexpressionElimination off, codegen off          25117          26612        1357          0.0   251166638.4       1.1X
-subexpressionElimination on, codegen on             2582           2906         282          0.0    25819408.7      10.4X
-subexpressionElimination on, codegen off           25635          26131         804          0.0   256346873.1       1.0X
+subexpressionElimination off, codegen on           25932          26908         916          0.0   259320042.3       1.0X
+subexpressionElimination off, codegen off          26085          26159          65          0.0   260848905.0       1.0X
+subexpressionElimination on, codegen on             2860           2939          72          0.0    28603312.9       9.1X
+subexpressionElimination on, codegen off            2517           2617          93          0.0    25165157.7      10.3X


@@ -7,9 +7,9 @@ OpenJDK 64-Bit Server VM 1.8.0_265-b01 on Mac OS X 10.15.6
 Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
 from_json as subExpr:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
 -------------------------------------------------------------------------------------------------------------------------
-subexpressionElimination off, codegen on           24841          25365         803          0.0   248412787.5       1.0X
-subexpressionElimination off, codegen off          25344          26205         941          0.0   253442656.5       1.0X
-subexpressionElimination on, codegen on             2883           3019         119          0.0    28833086.8       8.6X
-subexpressionElimination on, codegen off           24707          25688         903          0.0   247068775.9       1.0X
+subexpressionElimination off, codegen on           26503          27622        1937          0.0   265033362.4       1.0X
+subexpressionElimination off, codegen off          24920          25376         430          0.0   249196978.2       1.1X
+subexpressionElimination on, codegen on             2421           2466          39          0.0    24213606.1      10.9X
+subexpressionElimination on, codegen off            2360           2435          87          0.0    23604320.7      11.2X