[SPARK-14244][SQL] Don't use SizeBasedWindowFunction.n created on executor side when evaluating window functions
## What changes were proposed in this pull request?

`SizeBasedWindowFunction.n` is a global singleton attribute created for evaluating size-based aggregate window functions like `CUME_DIST`. However, this attribute gets a different expression ID each time it is created, so the instances created on the driver side and on the executor side don't match. This PR adds a `withPartitionSize` method to `SizeBasedWindowFunction` so that we can easily rewrite `SizeBasedWindowFunction.n` on the executor side.

## How was this patch tested?

A test case is added in `HiveSparkSubmitSuite`, which supports launching multi-process clusters.

Author: Cheng Lian <lian@databricks.com>

Closes #12040 from liancheng/spark-14244-fix-sized-window-function.
This commit is contained in:

parent 4fc35e6f5c
commit 27e71a2cd9
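For context on the root cause, here is a REPL-style sketch against Catalyst's `AttributeReference` (internal API; the attribute name mirrors the one `SizeBasedWindowFunction` uses but is only illustrative here). Every construction mints a fresh expression ID, so a "singleton" recreated in another JVM never compares equal to the driver's:

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.IntegerType

// Two "identical" attributes created in different JVMs: each constructor call
// mints a fresh expression ID, so the two cannot be bound to each other.
val onDriver   = AttributeReference("window__partition__size", IntegerType, nullable = false)()
val onExecutor = AttributeReference("window__partition__size", IntegerType, nullable = false)()

assert(onDriver.exprId != onExecutor.exprId)
assert(onDriver != onExecutor) // AttributeReference equality includes the expression ID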
```diff
@@ -451,7 +451,11 @@ abstract class RowNumberLike extends AggregateWindowFunction {
  * A [[SizeBasedWindowFunction]] needs the size of the current window for its calculation.
  */
 trait SizeBasedWindowFunction extends AggregateWindowFunction {
-  protected def n: AttributeReference = SizeBasedWindowFunction.n
+  // It's made a val so that the attribute created on driver side is serialized to executor side.
+  // Otherwise, if it's defined as a function, when it's called on executor side, it actually
+  // returns the singleton value instantiated on executor side, which has different expression ID
+  // from the one created on driver side.
+  val n: AttributeReference = SizeBasedWindowFunction.n
 }

 object SizeBasedWindowFunction {
```
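The new comment is the whole fix in miniature. A self-contained sketch of the `def`-vs-`val` distinction under Java serialization (plain Scala, no Spark types; `PerJvm` and the demo names are made up for illustration):

```scala
import java.io._

// Stand-in for a per-JVM singleton: bumping `id` simulates "another JVM
// initialized its own copy of the singleton".
object PerJvm { var id: Int = 0 }

class DefBased extends Serializable { def n: Int = PerJvm.id } // evaluated at each call site
class ValBased extends Serializable { val n: Int = PerJvm.id } // captured at construction time

object DefVsValDemo extends App {
  // Serialize and deserialize an object through an in-memory byte buffer.
  def roundTrip[T](t: T): T = {
    val bytes = new ByteArrayOutputStream()
    new ObjectOutputStream(bytes).writeObject(t)
    new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
      .readObject().asInstanceOf[T]
  }

  PerJvm.id = 1                // "driver" side creates the objects
  val (d, v) = (new DefBased, new ValBased)

  PerJvm.id = 2                // simulate deserialization on an "executor"
  assert(roundTrip(d).n == 2)  // def: re-reads the executor-side singleton
  assert(roundTrip(v).n == 1)  // val: the driver-side value traveled with the object
}
```

This is exactly why the trait pins `n` with a `val`: the driver-side `AttributeReference`, expression ID included, rides along in the serialized function object.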
```diff
@@ -874,7 +874,8 @@ private[execution] final class UnboundedFollowingWindowFunctionFrame(
  * processor class.
  */
 private[execution] object AggregateProcessor {
-  def apply(functions: Array[Expression],
+  def apply(
+      functions: Array[Expression],
       ordinal: Int,
       inputAttributes: Seq[Attribute],
       newMutableProjection: (Seq[Expression], Seq[Attribute]) => () => MutableProjection):
```
```diff
@@ -885,11 +886,20 @@ private[execution] object AggregateProcessor {
     val evaluateExpressions = mutable.Buffer.fill[Expression](ordinal)(NoOp)
     val imperatives = mutable.Buffer.empty[ImperativeAggregate]

+    // SPARK-14244: `SizeBasedWindowFunction`s are firstly created on driver side and then
+    // serialized to executor side. These functions all reference a global singleton window
+    // partition size attribute reference, i.e., `SizeBasedWindowFunction.n`. Here we must collect
+    // the singleton instance created on driver side instead of using executor side
+    // `SizeBasedWindowFunction.n` to avoid binding failure caused by mismatching expression ID.
+    val partitionSize: Option[AttributeReference] = {
+      val aggs = functions.flatMap(_.collectFirst { case f: SizeBasedWindowFunction => f })
+      aggs.headOption.map(_.n)
+    }
+
     // Check if there are any SizeBasedWindowFunctions. If there are, we add the partition size to
     // the aggregation buffer. Note that the ordinal of the partition size value will always be 0.
-    val trackPartitionSize = functions.exists(_.isInstanceOf[SizeBasedWindowFunction])
-    if (trackPartitionSize) {
-      aggBufferAttributes += SizeBasedWindowFunction.n
+    partitionSize.foreach { n =>
+      aggBufferAttributes += n
       initialValues += NoOp
       updateExpressions += NoOp
     }
```
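The collection step works because each serialized expression tree still carries the driver-side attribute inside it; the patch scans for it and reuses that instance. A toy model of the same pattern (the `Expr`/`Attr`/`SizeBased` types are illustrative stand-ins, not Spark's):

```scala
sealed trait Expr {
  def children: Seq[Expr]
  // Depth-first search for the first node matching the partial function,
  // analogous in spirit to Catalyst's TreeNode.collectFirst.
  def collectFirst[B](pf: PartialFunction[Expr, B]): Option[B] =
    pf.lift(this).orElse(children.view.flatMap(_.collectFirst(pf)).headOption)
}
case class Attr(name: String, exprId: Long) extends Expr { def children = Nil }
case class SizeBased(n: Attr) extends Expr { def children = Seq(n) }
case class Alias(child: Expr) extends Expr { def children = Seq(child) }

object CollectDemo extends App {
  val driverN = Attr("window__partition__size", exprId = 7L)
  val functions: Array[Expr] = Array(Alias(SizeBased(driverN)))

  // Mirrors the patch: take the attribute from the first size-based function found.
  val partitionSize: Option[Attr] =
    functions.flatMap(_.collectFirst { case f: SizeBased => f }).headOption.map(_.n)

  assert(partitionSize.contains(driverN)) // the deserialized driver-side instance is reused
}
```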
```diff
@@ -920,7 +930,7 @@ private[execution] object AggregateProcessor {
     // Create the projections.
     val initialProjection = newMutableProjection(
       initialValues,
-      Seq(SizeBasedWindowFunction.n))()
+      partitionSize.toSeq)()
     val updateProjection = newMutableProjection(
       updateExpressions,
       aggBufferAttributes ++ inputAttributes)()
```
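The `partitionSize.toSeq` substitution leans on a small `Option` idiom: the initial projection's input schema now degrades to empty when no size-based function is present, instead of unconditionally carrying the singleton. In a REPL or worksheet:

```scala
// Option#toSeq: empty when absent, one element when present.
assert((None: Option[Int]).toSeq == Seq.empty[Int])
assert(Option(7).toSeq == Seq(7))
```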
```diff
@@ -935,7 +945,7 @@ private[execution] object AggregateProcessor {
       updateProjection,
       evaluateProjection,
       imperatives.toArray,
-      trackPartitionSize)
+      partitionSize.isDefined)
   }
 }
```
```diff
@@ -107,7 +107,9 @@ private[hive] class HiveFunctionRegistry(
           // If there is any other error, we throw an AnalysisException.
           val errorMessage = s"No handler for Hive udf ${functionInfo.getFunctionClass} " +
             s"because: ${throwable.getMessage}."
-          throw new AnalysisException(errorMessage)
+          val analysisException = new AnalysisException(errorMessage)
+          analysisException.setStackTrace(throwable.getStackTrace)
+          throw analysisException
         }
       }
     }
```
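This `HiveFunctionRegistry` hunk is a small diagnostic cleanup: rather than discarding the original failure's location, the new `AnalysisException` adopts its stack trace. The pattern in isolation (a generic sketch, not the Spark code itself):

```scala
// Wrap a low-level failure in a domain exception while keeping the frames
// that point at the real failure site.
def wrapPreservingTrace(cause: Throwable, message: String): RuntimeException = {
  val wrapped = new RuntimeException(message)
  wrapped.setStackTrace(cause.getStackTrace)
  wrapped
}
```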
```diff
@@ -31,7 +31,7 @@ import org.scalatest.time.SpanSugar._

 import org.apache.spark._
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{QueryTest, SQLContext}
+import org.apache.spark.sql.{QueryTest, Row, SQLContext}
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext}
 import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer
```
```diff
@@ -135,6 +135,19 @@ class HiveSparkSubmitSuite
     runSparkSubmit(args)
   }

+  test("SPARK-14244 fix window partition size attribute binding failure") {
+    val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+    val args = Seq(
+      "--class", SPARK_14244.getClass.getName.stripSuffix("$"),
+      "--name", "SparkSQLConfTest",
+      "--master", "local-cluster[2,1,1024]",
+      "--conf", "spark.ui.enabled=false",
+      "--conf", "spark.master.rest.enabled=false",
+      "--driver-java-options", "-Dderby.system.durability=test",
+      unusedJar.toString)
+    runSparkSubmit(args)
+  }
+
   // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
   // This is copied from org.apache.spark.deploy.SparkSubmitSuite
   private def runSparkSubmit(args: Seq[String]): Unit = {
```
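A note on the submit arguments: `local-cluster[2,1,1024]` launches two separate worker JVMs (1 core and 1024 MB each), so the window functions genuinely cross a serialization boundary. A plain `local[*]` master runs driver and executors in one JVM, where `SizeBasedWindowFunction.n` really is a single instance, and would never have exposed the mismatching expression IDs.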
```diff
@@ -378,3 +391,32 @@ object SPARK_11009 extends QueryTest {
     }
   }
 }
+
+object SPARK_14244 extends QueryTest {
+  import org.apache.spark.sql.expressions.Window
+  import org.apache.spark.sql.functions._
+
+  protected var sqlContext: SQLContext = _
+
+  def main(args: Array[String]): Unit = {
+    Utils.configTestLog4j("INFO")
+
+    val sparkContext = new SparkContext(
+      new SparkConf()
+        .set("spark.ui.enabled", "false")
+        .set("spark.sql.shuffle.partitions", "100"))
+
+    val hiveContext = new TestHiveContext(sparkContext)
+    sqlContext = hiveContext
+
+    import hiveContext.implicits._
+
+    try {
+      val window = Window.orderBy('id)
+      val df = sqlContext.range(2).select(cume_dist().over(window).as('cdist)).orderBy('cdist)
+      checkAnswer(df, Seq(Row(0.5D), Row(1.0D)))
+    } finally {
+      sparkContext.stop()
+    }
+  }
+}
```
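For reference, `cume_dist` computes (number of rows ordered at or before the current row) / (partition size), which is why it is a `SizeBasedWindowFunction` in the first place. Over `range(2)` ordered by `id`, that yields 1/2 = 0.5 for id 0 and 2/2 = 1.0 for id 1, matching the `checkAnswer` expectation. Before the fix, evaluating the denominator on an executor failed to bind because the executor's local `n` carried a different expression ID than the one in the serialized expressions.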