[SPARK-22238] Fix plan resolution bug caused by EnsureStatefulOpPartitioning
## What changes were proposed in this pull request? In EnsureStatefulOpPartitioning, we check that the inputRDD to a SparkPlan has the expected partitioning for Streaming Stateful Operators. The problem is that we are not allowed to access this information during planning. The reason we added that check was because CoalesceExec could actually create RDDs with 0 partitions. We should fix it such that when CoalesceExec says that there is a SinglePartition, there is in fact an inputRDD of 1 partition instead of 0 partitions. ## How was this patch tested? Regression test in StreamingQuerySuite Author: Burak Yavuz <brkyvz@gmail.com> Closes #19467 from brkyvz/stateful-op.
This commit is contained in:
parent
014dc84712
commit
e8547ffb49
|
@ -49,7 +49,9 @@ case object AllTuples extends Distribution
|
|||
* can mean such tuples are either co-located in the same partition or they will be contiguous
|
||||
* within a single partition.
|
||||
*/
|
||||
case class ClusteredDistribution(clustering: Seq[Expression]) extends Distribution {
|
||||
case class ClusteredDistribution(
|
||||
clustering: Seq[Expression],
|
||||
numPartitions: Option[Int] = None) extends Distribution {
|
||||
require(
|
||||
clustering != Nil,
|
||||
"The clustering expressions of a ClusteredDistribution should not be Nil. " +
|
||||
|
@ -221,6 +223,7 @@ case object SinglePartition extends Partitioning {
|
|||
|
||||
override def satisfies(required: Distribution): Boolean = required match {
|
||||
case _: BroadcastDistribution => false
|
||||
case ClusteredDistribution(_, desiredPartitions) => desiredPartitions.forall(_ == 1)
|
||||
case _ => true
|
||||
}
|
||||
|
||||
|
@ -243,8 +246,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
|
|||
|
||||
override def satisfies(required: Distribution): Boolean = required match {
|
||||
case UnspecifiedDistribution => true
|
||||
case ClusteredDistribution(requiredClustering) =>
|
||||
expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
|
||||
case ClusteredDistribution(requiredClustering, desiredPartitions) =>
|
||||
expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
|
||||
desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true
|
||||
case _ => false
|
||||
}
|
||||
|
||||
|
@ -289,8 +293,9 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
|
|||
case OrderedDistribution(requiredOrdering) =>
|
||||
val minSize = Seq(requiredOrdering.size, ordering.size).min
|
||||
requiredOrdering.take(minSize) == ordering.take(minSize)
|
||||
case ClusteredDistribution(requiredClustering) =>
|
||||
ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x)))
|
||||
case ClusteredDistribution(requiredClustering, desiredPartitions) =>
|
||||
ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
|
||||
desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true
|
||||
case _ => false
|
||||
}
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
|
|||
import scala.concurrent.{ExecutionContext, Future}
|
||||
import scala.concurrent.duration.Duration
|
||||
|
||||
import org.apache.spark.{InterruptibleIterator, TaskContext}
|
||||
import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
|
||||
import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD}
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
|
@ -590,10 +590,33 @@ case class CoalesceExec(numPartitions: Int, child: SparkPlan) extends UnaryExecN
|
|||
}
|
||||
|
||||
protected override def doExecute(): RDD[InternalRow] = {
|
||||
child.execute().coalesce(numPartitions, shuffle = false)
|
||||
if (numPartitions == 1 && child.execute().getNumPartitions < 1) {
|
||||
// Make sure we don't output an RDD with 0 partitions, when claiming that we have a
|
||||
// `SinglePartition`.
|
||||
new CoalesceExec.EmptyRDDWithPartitions(sparkContext, numPartitions)
|
||||
} else {
|
||||
child.execute().coalesce(numPartitions, shuffle = false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
object CoalesceExec {
|
||||
/** A simple RDD with no data, but with the given number of partitions. */
|
||||
class EmptyRDDWithPartitions(
|
||||
@transient private val sc: SparkContext,
|
||||
numPartitions: Int) extends RDD[InternalRow](sc, Nil) {
|
||||
|
||||
override def getPartitions: Array[Partition] =
|
||||
Array.tabulate(numPartitions)(i => EmptyPartition(i))
|
||||
|
||||
override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
|
||||
Iterator.empty
|
||||
}
|
||||
}
|
||||
|
||||
case class EmptyPartition(index: Int) extends Partition
|
||||
}
|
||||
|
||||
/**
|
||||
* Physical plan for a subquery.
|
||||
*/
|
||||
|
|
|
@ -44,13 +44,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
|
|||
|
||||
/**
|
||||
* Given a required distribution, returns a partitioning that satisfies that distribution.
|
||||
* @param requiredDistribution The distribution that is required by the operator
|
||||
* @param numPartitions Used when the distribution doesn't require a specific number of partitions
|
||||
*/
|
||||
private def createPartitioning(
|
||||
requiredDistribution: Distribution,
|
||||
numPartitions: Int): Partitioning = {
|
||||
requiredDistribution match {
|
||||
case AllTuples => SinglePartition
|
||||
case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions)
|
||||
case ClusteredDistribution(clustering, desiredPartitions) =>
|
||||
HashPartitioning(clustering, desiredPartitions.getOrElse(numPartitions))
|
||||
case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions)
|
||||
case dist => sys.error(s"Do not know how to satisfy distribution $dist")
|
||||
}
|
||||
|
|
|
@ -64,7 +64,7 @@ case class FlatMapGroupsWithStateExec(
|
|||
|
||||
/** Distribute by grouping attributes */
|
||||
override def requiredChildDistribution: Seq[Distribution] =
|
||||
ClusteredDistribution(groupingAttributes) :: Nil
|
||||
ClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) :: Nil
|
||||
|
||||
/** Ordering needed for using GroupingIterator */
|
||||
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
|
||||
|
|
|
@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib
|
|||
import org.apache.spark.sql.catalyst.rules.Rule
|
||||
import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode}
|
||||
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.streaming.OutputMode
|
||||
|
||||
/**
|
||||
|
@ -61,6 +62,10 @@ class IncrementalExecution(
|
|||
StreamingDeduplicationStrategy :: Nil
|
||||
}
|
||||
|
||||
private val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key)
|
||||
.map(SQLConf.SHUFFLE_PARTITIONS.valueConverter)
|
||||
.getOrElse(sparkSession.sessionState.conf.numShufflePartitions)
|
||||
|
||||
/**
|
||||
* See [SPARK-18339]
|
||||
* Walk the optimized logical plan and replace CurrentBatchTimestamp
|
||||
|
@ -83,7 +88,11 @@ class IncrementalExecution(
|
|||
/** Get the state info of the next stateful operator */
|
||||
private def nextStatefulOperationStateInfo(): StatefulOperatorStateInfo = {
|
||||
StatefulOperatorStateInfo(
|
||||
checkpointLocation, runId, statefulOperatorId.getAndIncrement(), currentBatchId)
|
||||
checkpointLocation,
|
||||
runId,
|
||||
statefulOperatorId.getAndIncrement(),
|
||||
currentBatchId,
|
||||
numStateStores)
|
||||
}
|
||||
|
||||
/** Locates save/restore pairs surrounding aggregation. */
|
||||
|
@ -130,34 +139,8 @@ class IncrementalExecution(
|
|||
}
|
||||
}
|
||||
|
||||
override def preparations: Seq[Rule[SparkPlan]] =
|
||||
Seq(state, EnsureStatefulOpPartitioning) ++ super.preparations
|
||||
override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations
|
||||
|
||||
/** No need assert supported, as this check has already been done */
|
||||
override def assertSupported(): Unit = { }
|
||||
}
|
||||
|
||||
object EnsureStatefulOpPartitioning extends Rule[SparkPlan] {
|
||||
// Needs to be transformUp to avoid extra shuffles
|
||||
override def apply(plan: SparkPlan): SparkPlan = plan transformUp {
|
||||
case so: StatefulOperator =>
|
||||
val numPartitions = plan.sqlContext.sessionState.conf.numShufflePartitions
|
||||
val distributions = so.requiredChildDistribution
|
||||
val children = so.children.zip(distributions).map { case (child, reqDistribution) =>
|
||||
val expectedPartitioning = reqDistribution match {
|
||||
case AllTuples => SinglePartition
|
||||
case ClusteredDistribution(keys) => HashPartitioning(keys, numPartitions)
|
||||
case _ => throw new AnalysisException("Unexpected distribution expected for " +
|
||||
s"Stateful Operator: $so. Expect AllTuples or ClusteredDistribution but got " +
|
||||
s"$reqDistribution.")
|
||||
}
|
||||
if (child.outputPartitioning.guarantees(expectedPartitioning) &&
|
||||
child.execute().getNumPartitions == expectedPartitioning.numPartitions) {
|
||||
child
|
||||
} else {
|
||||
ShuffleExchangeExec(expectedPartitioning, child)
|
||||
}
|
||||
}
|
||||
so.withNewChildren(children)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -43,10 +43,11 @@ case class StatefulOperatorStateInfo(
|
|||
checkpointLocation: String,
|
||||
queryRunId: UUID,
|
||||
operatorId: Long,
|
||||
storeVersion: Long) {
|
||||
storeVersion: Long,
|
||||
numPartitions: Int) {
|
||||
override def toString(): String = {
|
||||
s"state info [ checkpoint = $checkpointLocation, runId = $queryRunId, " +
|
||||
s"opId = $operatorId, ver = $storeVersion]"
|
||||
s"opId = $operatorId, ver = $storeVersion, numPartitions = $numPartitions]"
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -239,7 +240,7 @@ case class StateStoreRestoreExec(
|
|||
if (keyExpressions.isEmpty) {
|
||||
AllTuples :: Nil
|
||||
} else {
|
||||
ClusteredDistribution(keyExpressions) :: Nil
|
||||
ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -386,7 +387,7 @@ case class StateStoreSaveExec(
|
|||
if (keyExpressions.isEmpty) {
|
||||
AllTuples :: Nil
|
||||
} else {
|
||||
ClusteredDistribution(keyExpressions) :: Nil
|
||||
ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -401,7 +402,7 @@ case class StreamingDeduplicateExec(
|
|||
|
||||
/** Distribute by grouping attributes */
|
||||
override def requiredChildDistribution: Seq[Distribution] =
|
||||
ClusteredDistribution(keyExpressions) :: Nil
|
||||
ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
|
||||
|
||||
override protected def doExecute(): RDD[InternalRow] = {
|
||||
metrics // force lazy init at driver
|
||||
|
|
|
@ -368,6 +368,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
|
|||
checkAnswer(
|
||||
testData.select('key).coalesce(1).select('key),
|
||||
testData.select('key).collect().toSeq)
|
||||
|
||||
assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 1)
|
||||
}
|
||||
|
||||
test("convert $\"attribute name\" into unresolved attribute") {
|
||||
|
|
|
@ -425,6 +425,23 @@ class PlannerSuite extends SharedSQLContext {
|
|||
}
|
||||
}
|
||||
|
||||
test("EnsureRequirements should respect ClusteredDistribution's num partitioning") {
|
||||
val distribution = ClusteredDistribution(Literal(1) :: Nil, Some(13))
|
||||
// Number of partitions differ
|
||||
val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 13)
|
||||
val childPartitioning = HashPartitioning(Literal(1) :: Nil, 5)
|
||||
assert(!childPartitioning.satisfies(distribution))
|
||||
val inputPlan = DummySparkPlan(
|
||||
children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil,
|
||||
requiredChildDistribution = Seq(distribution),
|
||||
requiredChildOrdering = Seq(Seq.empty))
|
||||
|
||||
val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
|
||||
val shuffle = outputPlan.collect { case e: ShuffleExchangeExec => e }
|
||||
assert(shuffle.size === 1)
|
||||
assert(shuffle.head.newPartitioning === finalPartitioning)
|
||||
}
|
||||
|
||||
test("Reuse exchanges") {
|
||||
val distribution = ClusteredDistribution(Literal(1) :: Nil)
|
||||
val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5)
|
||||
|
|
|
@ -214,7 +214,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
|
|||
path: String,
|
||||
queryRunId: UUID = UUID.randomUUID,
|
||||
version: Int = 0): StatefulOperatorStateInfo = {
|
||||
StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version)
|
||||
StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version, numPartitions = 5)
|
||||
}
|
||||
|
||||
private val increment = (store: StateStore, iter: Iterator[String]) => {
|
||||
|
|
|
@ -160,7 +160,7 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter
|
|||
|
||||
withTempDir { file =>
|
||||
val storeConf = new StateStoreConf()
|
||||
val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0)
|
||||
val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5)
|
||||
val manager = new SymmetricHashJoinStateManager(
|
||||
LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), storeConf, new Configuration)
|
||||
try {
|
||||
|
|
|
@ -19,12 +19,15 @@ package org.apache.spark.sql.streaming
|
|||
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
|
||||
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning, SinglePartition}
|
||||
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
|
||||
import org.apache.spark.sql.execution.streaming.MemoryStream
|
||||
import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingDeduplicateExec}
|
||||
import org.apache.spark.sql.execution.streaming.state.StateStore
|
||||
import org.apache.spark.sql.functions._
|
||||
|
||||
class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
|
||||
class DeduplicateSuite extends StateStoreMetricsTest
|
||||
with BeforeAndAfterAll
|
||||
with StatefulOperatorTest {
|
||||
|
||||
import testImplicits._
|
||||
|
||||
|
@ -41,6 +44,8 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
|
|||
AddData(inputData, "a"),
|
||||
CheckLastBatch("a"),
|
||||
assertNumStateRows(total = 1, updated = 1),
|
||||
AssertOnQuery(sq =>
|
||||
checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("value"))),
|
||||
AddData(inputData, "a"),
|
||||
CheckLastBatch(),
|
||||
assertNumStateRows(total = 1, updated = 0),
|
||||
|
@ -58,6 +63,8 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
|
|||
AddData(inputData, "a" -> 1),
|
||||
CheckLastBatch("a" -> 1),
|
||||
assertNumStateRows(total = 1, updated = 1),
|
||||
AssertOnQuery(sq =>
|
||||
checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("_1"))),
|
||||
AddData(inputData, "a" -> 2), // Dropped
|
||||
CheckLastBatch(),
|
||||
assertNumStateRows(total = 1, updated = 0),
|
||||
|
|
|
@ -1,138 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.streaming
|
||||
|
||||
import java.util.UUID
|
||||
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.DataFrame
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
|
||||
import org.apache.spark.sql.catalyst.expressions.Attribute
|
||||
import org.apache.spark.sql.catalyst.plans.physical._
|
||||
import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode}
|
||||
import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
|
||||
import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, StatefulOperator, StatefulOperatorStateInfo}
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
|
||||
class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLContext {
|
||||
|
||||
import testImplicits._
|
||||
|
||||
private var baseDf: DataFrame = null
|
||||
|
||||
override def beforeAll(): Unit = {
|
||||
super.beforeAll()
|
||||
baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char")
|
||||
}
|
||||
|
||||
test("ClusteredDistribution generates Exchange with HashPartitioning") {
|
||||
testEnsureStatefulOpPartitioning(
|
||||
baseDf.queryExecution.sparkPlan,
|
||||
requiredDistribution = keys => ClusteredDistribution(keys),
|
||||
expectedPartitioning =
|
||||
keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions),
|
||||
expectShuffle = true)
|
||||
}
|
||||
|
||||
test("ClusteredDistribution with coalesce(1) generates Exchange with HashPartitioning") {
|
||||
testEnsureStatefulOpPartitioning(
|
||||
baseDf.coalesce(1).queryExecution.sparkPlan,
|
||||
requiredDistribution = keys => ClusteredDistribution(keys),
|
||||
expectedPartitioning =
|
||||
keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions),
|
||||
expectShuffle = true)
|
||||
}
|
||||
|
||||
test("AllTuples generates Exchange with SinglePartition") {
|
||||
testEnsureStatefulOpPartitioning(
|
||||
baseDf.queryExecution.sparkPlan,
|
||||
requiredDistribution = _ => AllTuples,
|
||||
expectedPartitioning = _ => SinglePartition,
|
||||
expectShuffle = true)
|
||||
}
|
||||
|
||||
test("AllTuples with coalesce(1) doesn't need Exchange") {
|
||||
testEnsureStatefulOpPartitioning(
|
||||
baseDf.coalesce(1).queryExecution.sparkPlan,
|
||||
requiredDistribution = _ => AllTuples,
|
||||
expectedPartitioning = _ => SinglePartition,
|
||||
expectShuffle = false)
|
||||
}
|
||||
|
||||
/**
|
||||
* For `StatefulOperator` with the given `requiredChildDistribution`, and child SparkPlan
|
||||
* `inputPlan`, ensures that the incremental planner adds exchanges, if required, in order to
|
||||
* ensure the expected partitioning.
|
||||
*/
|
||||
private def testEnsureStatefulOpPartitioning(
|
||||
inputPlan: SparkPlan,
|
||||
requiredDistribution: Seq[Attribute] => Distribution,
|
||||
expectedPartitioning: Seq[Attribute] => Partitioning,
|
||||
expectShuffle: Boolean): Unit = {
|
||||
val operator = TestStatefulOperator(inputPlan, requiredDistribution(inputPlan.output.take(1)))
|
||||
val executed = executePlan(operator, OutputMode.Complete())
|
||||
if (expectShuffle) {
|
||||
val exchange = executed.children.find(_.isInstanceOf[Exchange])
|
||||
if (exchange.isEmpty) {
|
||||
fail(s"Was expecting an exchange but didn't get one in:\n$executed")
|
||||
}
|
||||
assert(exchange.get ===
|
||||
ShuffleExchangeExec(expectedPartitioning(inputPlan.output.take(1)), inputPlan),
|
||||
s"Exchange didn't have expected properties:\n${exchange.get}")
|
||||
} else {
|
||||
assert(!executed.children.exists(_.isInstanceOf[Exchange]),
|
||||
s"Unexpected exchange found in:\n$executed")
|
||||
}
|
||||
}
|
||||
|
||||
/** Executes a SparkPlan using the IncrementalPlanner used for Structured Streaming. */
|
||||
private def executePlan(
|
||||
p: SparkPlan,
|
||||
outputMode: OutputMode = OutputMode.Append()): SparkPlan = {
|
||||
val execution = new IncrementalExecution(
|
||||
spark,
|
||||
null,
|
||||
OutputMode.Complete(),
|
||||
"chk",
|
||||
UUID.randomUUID(),
|
||||
0L,
|
||||
OffsetSeqMetadata()) {
|
||||
override lazy val sparkPlan: SparkPlan = p transform {
|
||||
case plan: SparkPlan =>
|
||||
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
|
||||
plan transformExpressions {
|
||||
case UnresolvedAttribute(Seq(u)) =>
|
||||
inputMap.getOrElse(u,
|
||||
sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
|
||||
}
|
||||
}
|
||||
}
|
||||
execution.executedPlan
|
||||
}
|
||||
}
|
||||
|
||||
/** Used to emulate a `StatefulOperator` with the given requiredDistribution. */
|
||||
case class TestStatefulOperator(
|
||||
child: SparkPlan,
|
||||
requiredDist: Distribution) extends UnaryExecNode with StatefulOperator {
|
||||
override def output: Seq[Attribute] = child.output
|
||||
override def doExecute(): RDD[InternalRow] = child.execute()
|
||||
override def requiredChildDistribution: Seq[Distribution] = requiredDist :: Nil
|
||||
override def stateInfo: Option[StatefulOperatorStateInfo] = None
|
||||
}
|
|
@ -41,7 +41,9 @@ case class RunningCount(count: Long)
|
|||
|
||||
case class Result(key: Long, count: Int)
|
||||
|
||||
class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
|
||||
class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
|
||||
with BeforeAndAfterAll
|
||||
with StatefulOperatorTest {
|
||||
|
||||
import testImplicits._
|
||||
import GroupStateImpl._
|
||||
|
@ -544,6 +546,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
|
|||
AddData(inputData, "a"),
|
||||
CheckLastBatch(("a", "1")),
|
||||
assertNumStateRows(total = 1, updated = 1),
|
||||
AssertOnQuery(sq => checkChildOutputHashPartitioning[FlatMapGroupsWithStateExec](
|
||||
sq, Seq("value"))),
|
||||
AddData(inputData, "a", "b"),
|
||||
CheckLastBatch(("a", "2"), ("b", "1")),
|
||||
assertNumStateRows(total = 2, updated = 2),
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.streaming
|
||||
|
||||
import org.apache.spark.sql.catalyst.plans.physical._
|
||||
import org.apache.spark.sql.execution.streaming._
|
||||
|
||||
trait StatefulOperatorTest {
|
||||
/**
|
||||
* Check that the output partitioning of a child operator of a Stateful operator satisfies the
|
||||
* distribution that we expect for our Stateful operator.
|
||||
*/
|
||||
protected def checkChildOutputHashPartitioning[T <: StatefulOperator](
|
||||
sq: StreamingQuery,
|
||||
colNames: Seq[String]): Boolean = {
|
||||
val attr = sq.asInstanceOf[StreamExecution].lastExecution.analyzed.output
|
||||
val partitions = sq.sparkSession.sessionState.conf.numShufflePartitions
|
||||
val groupingAttr = attr.filter(a => colNames.contains(a.name))
|
||||
checkChildOutputPartitioning(sq, HashPartitioning(groupingAttr, partitions))
|
||||
}
|
||||
|
||||
/**
|
||||
* Check that the output partitioning of a child operator of a Stateful operator satisfies the
|
||||
* distribution that we expect for our Stateful operator.
|
||||
*/
|
||||
protected def checkChildOutputPartitioning[T <: StatefulOperator](
|
||||
sq: StreamingQuery,
|
||||
expectedPartitioning: Partitioning): Boolean = {
|
||||
val operator = sq.asInstanceOf[StreamExecution].lastExecution
|
||||
.executedPlan.collect { case p: T => p }
|
||||
operator.head.children.forall(
|
||||
_.outputPartitioning.numPartitions == expectedPartitioning.numPartitions)
|
||||
}
|
||||
}
|
|
@ -44,7 +44,7 @@ object FailureSingleton {
|
|||
}
|
||||
|
||||
class StreamingAggregationSuite extends StateStoreMetricsTest
|
||||
with BeforeAndAfterAll with Assertions {
|
||||
with BeforeAndAfterAll with Assertions with StatefulOperatorTest {
|
||||
|
||||
override def afterAll(): Unit = {
|
||||
super.afterAll()
|
||||
|
@ -281,6 +281,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
|
|||
AddData(inputData, 0L, 5L, 5L, 10L),
|
||||
AdvanceManualClock(10 * 1000),
|
||||
CheckLastBatch((0L, 1), (5L, 2), (10L, 1)),
|
||||
AssertOnQuery(sq =>
|
||||
checkChildOutputHashPartitioning[StateStoreRestoreExec](sq, Seq("value"))),
|
||||
|
||||
// advance clock to 20 seconds, should retain keys >= 10
|
||||
AddData(inputData, 15L, 15L, 20L),
|
||||
|
@ -455,8 +457,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
|
|||
},
|
||||
AddBlockData(inputSource), // create an empty trigger
|
||||
CheckLastBatch(1),
|
||||
AssertOnQuery("Verify addition of exchange operator") { se =>
|
||||
checkAggregationChain(se, expectShuffling = true, 1)
|
||||
AssertOnQuery("Verify that no exchange is required") { se =>
|
||||
checkAggregationChain(se, expectShuffling = false, 1)
|
||||
},
|
||||
AddBlockData(inputSource, Seq(2, 3)),
|
||||
CheckLastBatch(3),
|
||||
|
|
|
@ -330,7 +330,7 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with
|
|||
val queryId = UUID.randomUUID
|
||||
val opId = 0
|
||||
val path = Utils.createDirectory(tempDir.getAbsolutePath, Random.nextString(10)).toString
|
||||
val stateInfo = StatefulOperatorStateInfo(path, queryId, opId, 0L)
|
||||
val stateInfo = StatefulOperatorStateInfo(path, queryId, opId, 0L, 5)
|
||||
|
||||
implicit val sqlContext = spark.sqlContext
|
||||
val coordinatorRef = sqlContext.streams.stateStoreCoordinator
|
||||
|
|
|
@ -652,6 +652,19 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
|
|||
}
|
||||
}
|
||||
|
||||
test("SPARK-22238: don't check for RDD partitions during streaming aggregation preparation") {
|
||||
val stream = MemoryStream[(Int, Int)]
|
||||
val baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char").where("char = 'A'")
|
||||
val otherDf = stream.toDF().toDF("num", "numSq")
|
||||
.join(broadcast(baseDf), "num")
|
||||
.groupBy('char)
|
||||
.agg(sum('numSq))
|
||||
|
||||
testStream(otherDf, OutputMode.Complete())(
|
||||
AddData(stream, (1, 1), (2, 4)),
|
||||
CheckLastBatch(("A", 1)))
|
||||
}
|
||||
|
||||
/** Create a streaming DF that only execute one batch in which it returns the given static DF */
|
||||
private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = {
|
||||
require(!triggerDF.isStreaming)
|
||||
|
|
Loading…
Reference in a new issue