[SPARK-22238] Fix plan resolution bug caused by EnsureStatefulOpPartitioning

## What changes were proposed in this pull request?

In EnsureStatefulOpPartitioning, we check that the inputRDD to a SparkPlan has the expected partitioning for Streaming Stateful Operators. The problem is that we are not allowed to access this information during planning.
We added that check because CoalesceExec could actually create RDDs with 0 partitions. Instead, we should fix CoalesceExec so that when it claims a SinglePartition, it actually produces an input RDD with 1 partition rather than 0.
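
For illustration only (this mirrors the assertion added to DataFrameSuite further down in this diff), the user-visible effect of the fix, assuming a local `spark` session:

```scala
spark.emptyDataFrame.rdd.getNumPartitions              // 0: the underlying RDD has no partitions
spark.emptyDataFrame.coalesce(1).rdd.getNumPartitions  // 1 after this patch (was 0 before)
```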

## How was this patch tested?

Regression test in StreamingQuerySuite

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #19467 from brkyvz/stateful-op.
Burak Yavuz 2017-10-14 17:39:15 -07:00 committed by Tathagata Das
parent 014dc84712
commit e8547ffb49
17 changed files with 160 additions and 189 deletions

View file

@ -49,7 +49,9 @@ case object AllTuples extends Distribution
* can mean such tuples are either co-located in the same partition or they will be contiguous
* within a single partition.
*/
case class ClusteredDistribution(clustering: Seq[Expression]) extends Distribution {
case class ClusteredDistribution(
clustering: Seq[Expression],
numPartitions: Option[Int] = None) extends Distribution {
require(
clustering != Nil,
"The clustering expressions of a ClusteredDistribution should not be Nil. " +
@ -221,6 +223,7 @@ case object SinglePartition extends Partitioning {
override def satisfies(required: Distribution): Boolean = required match {
case _: BroadcastDistribution => false
case ClusteredDistribution(_, desiredPartitions) => desiredPartitions.forall(_ == 1)
case _ => true
}
@ -243,8 +246,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
override def satisfies(required: Distribution): Boolean = required match {
case UnspecifiedDistribution => true
case ClusteredDistribution(requiredClustering) =>
expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
case ClusteredDistribution(requiredClustering, desiredPartitions) =>
expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true
case _ => false
}
@ -289,8 +293,9 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
case OrderedDistribution(requiredOrdering) =>
val minSize = Seq(requiredOrdering.size, ordering.size).min
requiredOrdering.take(minSize) == ordering.take(minSize)
case ClusteredDistribution(requiredClustering) =>
ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x)))
case ClusteredDistribution(requiredClustering, desiredPartitions) =>
ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true
case _ => false
}
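
As a quick, non-authoritative illustration of the new `numPartitions` handling above (the names come from the diff; the expected results match the `PlannerSuite` test added later in this change):

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.physical._

val keys = Seq(Literal(1))
HashPartitioning(keys, 5).satisfies(ClusteredDistribution(keys))           // true: no count requested
HashPartitioning(keys, 5).satisfies(ClusteredDistribution(keys, Some(5)))  // true: counts match
HashPartitioning(keys, 5).satisfies(ClusteredDistribution(keys, Some(13))) // false: a shuffle is needed
SinglePartition.satisfies(ClusteredDistribution(keys, Some(1)))            // true: one partition requested
```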

View file

@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration
import org.apache.spark.{InterruptibleIterator, TaskContext}
import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@ -590,10 +590,33 @@ case class CoalesceExec(numPartitions: Int, child: SparkPlan) extends UnaryExecN
}
protected override def doExecute(): RDD[InternalRow] = {
child.execute().coalesce(numPartitions, shuffle = false)
if (numPartitions == 1 && child.execute().getNumPartitions < 1) {
// Make sure we don't output an RDD with 0 partitions, when claiming that we have a
// `SinglePartition`.
new CoalesceExec.EmptyRDDWithPartitions(sparkContext, numPartitions)
} else {
child.execute().coalesce(numPartitions, shuffle = false)
}
}
}
object CoalesceExec {
/** A simple RDD with no data, but with the given number of partitions. */
class EmptyRDDWithPartitions(
@transient private val sc: SparkContext,
numPartitions: Int) extends RDD[InternalRow](sc, Nil) {
override def getPartitions: Array[Partition] =
Array.tabulate(numPartitions)(i => EmptyPartition(i))
override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
Iterator.empty
}
}
case class EmptyPartition(index: Int) extends Partition
}
/**
* Physical plan for a subquery.
*/

View file

@ -44,13 +44,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
/**
* Given a required distribution, returns a partitioning that satisfies that distribution.
* @param requiredDistribution The distribution that is required by the operator
* @param numPartitions Used when the distribution doesn't require a specific number of partitions
*/
private def createPartitioning(
requiredDistribution: Distribution,
numPartitions: Int): Partitioning = {
requiredDistribution match {
case AllTuples => SinglePartition
case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions)
case ClusteredDistribution(clustering, desiredPartitions) =>
HashPartitioning(clustering, desiredPartitions.getOrElse(numPartitions))
case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions)
case dist => sys.error(s"Do not know how to satisfy distribution $dist")
}
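
Illustratively (`createPartitioning` is private, so these are just traced cases of the method above, with `keys` standing in for some clustering expressions): an explicit `desiredPartitions` now takes precedence over the default count that `EnsureRequirements` passes in.

```scala
createPartitioning(ClusteredDistribution(keys), numPartitions = 200)
//   => HashPartitioning(keys, 200)   (no explicit requirement, the default is used)
createPartitioning(ClusteredDistribution(keys, Some(13)), numPartitions = 200)
//   => HashPartitioning(keys, 13)    (the stateful operator's requirement wins)
```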

View file

@ -64,7 +64,7 @@ case class FlatMapGroupsWithStateExec(
/** Distribute by grouping attributes */
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(groupingAttributes) :: Nil
ClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) :: Nil
/** Ordering needed for using GroupingIterator */
override def requiredChildOrdering: Seq[Seq[SortOrder]] =

View file

@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode
/**
@ -61,6 +62,10 @@ class IncrementalExecution(
StreamingDeduplicationStrategy :: Nil
}
private val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key)
.map(SQLConf.SHUFFLE_PARTITIONS.valueConverter)
.getOrElse(sparkSession.sessionState.conf.numShufflePartitions)
/**
* See [SPARK-18339]
* Walk the optimized logical plan and replace CurrentBatchTimestamp
@ -83,7 +88,11 @@ class IncrementalExecution(
/** Get the state info of the next stateful operator */
private def nextStatefulOperationStateInfo(): StatefulOperatorStateInfo = {
StatefulOperatorStateInfo(
checkpointLocation, runId, statefulOperatorId.getAndIncrement(), currentBatchId)
checkpointLocation,
runId,
statefulOperatorId.getAndIncrement(),
currentBatchId,
numStateStores)
}
/** Locates save/restore pairs surrounding aggregation. */
@ -130,34 +139,8 @@ class IncrementalExecution(
}
}
override def preparations: Seq[Rule[SparkPlan]] =
Seq(state, EnsureStatefulOpPartitioning) ++ super.preparations
override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations
/** No need assert supported, as this check has already been done */
override def assertSupported(): Unit = { }
}
object EnsureStatefulOpPartitioning extends Rule[SparkPlan] {
// Needs to be transformUp to avoid extra shuffles
override def apply(plan: SparkPlan): SparkPlan = plan transformUp {
case so: StatefulOperator =>
val numPartitions = plan.sqlContext.sessionState.conf.numShufflePartitions
val distributions = so.requiredChildDistribution
val children = so.children.zip(distributions).map { case (child, reqDistribution) =>
val expectedPartitioning = reqDistribution match {
case AllTuples => SinglePartition
case ClusteredDistribution(keys) => HashPartitioning(keys, numPartitions)
case _ => throw new AnalysisException("Unexpected distribution expected for " +
s"Stateful Operator: $so. Expect AllTuples or ClusteredDistribution but got " +
s"$reqDistribution.")
}
if (child.outputPartitioning.guarantees(expectedPartitioning) &&
child.execute().getNumPartitions == expectedPartitioning.numPartitions) {
child
} else {
ShuffleExchangeExec(expectedPartitioning, child)
}
}
so.withNewChildren(children)
}
}
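
With `EnsureStatefulOpPartitioning` removed above, the exchange decision now flows through the generic `EnsureRequirements` rule instead. A minimal sketch of the check that drives it (hypothetical helper, assuming the `ClusteredDistribution`/`HashPartitioning` changes earlier in this diff):

```scala
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical._

// keys: the operator's grouping expressions; numStateStores: carried in StatefulOperatorStateInfo
def needsShuffle(childPartitioning: Partitioning,
                 keys: Seq[Expression],
                 numStateStores: Int): Boolean = {
  // The stateful operator declares this as its requiredChildDistribution ...
  val required = ClusteredDistribution(keys, Some(numStateStores))
  // ... and EnsureRequirements inserts ShuffleExchangeExec(HashPartitioning(keys, numStateStores), child)
  // only when the child's output partitioning does not already satisfy it.
  !childPartitioning.satisfies(required)
}
```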

View file

@ -43,10 +43,11 @@ case class StatefulOperatorStateInfo(
checkpointLocation: String,
queryRunId: UUID,
operatorId: Long,
storeVersion: Long) {
storeVersion: Long,
numPartitions: Int) {
override def toString(): String = {
s"state info [ checkpoint = $checkpointLocation, runId = $queryRunId, " +
s"opId = $operatorId, ver = $storeVersion]"
s"opId = $operatorId, ver = $storeVersion, numPartitions = $numPartitions]"
}
}
@ -239,7 +240,7 @@ case class StateStoreRestoreExec(
if (keyExpressions.isEmpty) {
AllTuples :: Nil
} else {
ClusteredDistribution(keyExpressions) :: Nil
ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
}
}
}
@ -386,7 +387,7 @@ case class StateStoreSaveExec(
if (keyExpressions.isEmpty) {
AllTuples :: Nil
} else {
ClusteredDistribution(keyExpressions) :: Nil
ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
}
}
}
@ -401,7 +402,7 @@ case class StreamingDeduplicateExec(
/** Distribute by grouping attributes */
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(keyExpressions) :: Nil
ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
override protected def doExecute(): RDD[InternalRow] = {
metrics // force lazy init at driver

View file

@ -368,6 +368,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
checkAnswer(
testData.select('key).coalesce(1).select('key),
testData.select('key).collect().toSeq)
assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 1)
}
test("convert $\"attribute name\" into unresolved attribute") {

View file

@ -425,6 +425,23 @@ class PlannerSuite extends SharedSQLContext {
}
}
test("EnsureRequirements should respect ClusteredDistribution's num partitioning") {
val distribution = ClusteredDistribution(Literal(1) :: Nil, Some(13))
// Number of partitions differ
val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 13)
val childPartitioning = HashPartitioning(Literal(1) :: Nil, 5)
assert(!childPartitioning.satisfies(distribution))
val inputPlan = DummySparkPlan(
children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil,
requiredChildDistribution = Seq(distribution),
requiredChildOrdering = Seq(Seq.empty))
val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
val shuffle = outputPlan.collect { case e: ShuffleExchangeExec => e }
assert(shuffle.size === 1)
assert(shuffle.head.newPartitioning === finalPartitioning)
}
test("Reuse exchanges") {
val distribution = ClusteredDistribution(Literal(1) :: Nil)
val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5)

View file

@ -214,7 +214,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
path: String,
queryRunId: UUID = UUID.randomUUID,
version: Int = 0): StatefulOperatorStateInfo = {
StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version)
StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version, numPartitions = 5)
}
private val increment = (store: StateStore, iter: Iterator[String]) => {

View file

@ -160,7 +160,7 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter
withTempDir { file =>
val storeConf = new StateStoreConf()
val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0)
val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5)
val manager = new SymmetricHashJoinStateManager(
LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), storeConf, new Configuration)
try {

View file

@ -19,12 +19,15 @@ package org.apache.spark.sql.streaming
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning, SinglePartition}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingDeduplicateExec}
import org.apache.spark.sql.execution.streaming.state.StateStore
import org.apache.spark.sql.functions._
class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
class DeduplicateSuite extends StateStoreMetricsTest
with BeforeAndAfterAll
with StatefulOperatorTest {
import testImplicits._
@ -41,6 +44,8 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
AddData(inputData, "a"),
CheckLastBatch("a"),
assertNumStateRows(total = 1, updated = 1),
AssertOnQuery(sq =>
checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("value"))),
AddData(inputData, "a"),
CheckLastBatch(),
assertNumStateRows(total = 1, updated = 0),
@ -58,6 +63,8 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
AddData(inputData, "a" -> 1),
CheckLastBatch("a" -> 1),
assertNumStateRows(total = 1, updated = 1),
AssertOnQuery(sq =>
checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("_1"))),
AddData(inputData, "a" -> 2), // Dropped
CheckLastBatch(),
assertNumStateRows(total = 1, updated = 0),

View file

@ -1,138 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.streaming
import java.util.UUID
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode}
import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, StatefulOperator, StatefulOperatorStateInfo}
import org.apache.spark.sql.test.SharedSQLContext
class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLContext {
import testImplicits._
private var baseDf: DataFrame = null
override def beforeAll(): Unit = {
super.beforeAll()
baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char")
}
test("ClusteredDistribution generates Exchange with HashPartitioning") {
testEnsureStatefulOpPartitioning(
baseDf.queryExecution.sparkPlan,
requiredDistribution = keys => ClusteredDistribution(keys),
expectedPartitioning =
keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions),
expectShuffle = true)
}
test("ClusteredDistribution with coalesce(1) generates Exchange with HashPartitioning") {
testEnsureStatefulOpPartitioning(
baseDf.coalesce(1).queryExecution.sparkPlan,
requiredDistribution = keys => ClusteredDistribution(keys),
expectedPartitioning =
keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions),
expectShuffle = true)
}
test("AllTuples generates Exchange with SinglePartition") {
testEnsureStatefulOpPartitioning(
baseDf.queryExecution.sparkPlan,
requiredDistribution = _ => AllTuples,
expectedPartitioning = _ => SinglePartition,
expectShuffle = true)
}
test("AllTuples with coalesce(1) doesn't need Exchange") {
testEnsureStatefulOpPartitioning(
baseDf.coalesce(1).queryExecution.sparkPlan,
requiredDistribution = _ => AllTuples,
expectedPartitioning = _ => SinglePartition,
expectShuffle = false)
}
/**
* For `StatefulOperator` with the given `requiredChildDistribution`, and child SparkPlan
* `inputPlan`, ensures that the incremental planner adds exchanges, if required, in order to
* ensure the expected partitioning.
*/
private def testEnsureStatefulOpPartitioning(
inputPlan: SparkPlan,
requiredDistribution: Seq[Attribute] => Distribution,
expectedPartitioning: Seq[Attribute] => Partitioning,
expectShuffle: Boolean): Unit = {
val operator = TestStatefulOperator(inputPlan, requiredDistribution(inputPlan.output.take(1)))
val executed = executePlan(operator, OutputMode.Complete())
if (expectShuffle) {
val exchange = executed.children.find(_.isInstanceOf[Exchange])
if (exchange.isEmpty) {
fail(s"Was expecting an exchange but didn't get one in:\n$executed")
}
assert(exchange.get ===
ShuffleExchangeExec(expectedPartitioning(inputPlan.output.take(1)), inputPlan),
s"Exchange didn't have expected properties:\n${exchange.get}")
} else {
assert(!executed.children.exists(_.isInstanceOf[Exchange]),
s"Unexpected exchange found in:\n$executed")
}
}
/** Executes a SparkPlan using the IncrementalPlanner used for Structured Streaming. */
private def executePlan(
p: SparkPlan,
outputMode: OutputMode = OutputMode.Append()): SparkPlan = {
val execution = new IncrementalExecution(
spark,
null,
OutputMode.Complete(),
"chk",
UUID.randomUUID(),
0L,
OffsetSeqMetadata()) {
override lazy val sparkPlan: SparkPlan = p transform {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
plan transformExpressions {
case UnresolvedAttribute(Seq(u)) =>
inputMap.getOrElse(u,
sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
}
}
}
execution.executedPlan
}
}
/** Used to emulate a `StatefulOperator` with the given requiredDistribution. */
case class TestStatefulOperator(
child: SparkPlan,
requiredDist: Distribution) extends UnaryExecNode with StatefulOperator {
override def output: Seq[Attribute] = child.output
override def doExecute(): RDD[InternalRow] = child.execute()
override def requiredChildDistribution: Seq[Distribution] = requiredDist :: Nil
override def stateInfo: Option[StatefulOperatorStateInfo] = None
}

View file

@ -41,7 +41,9 @@ case class RunningCount(count: Long)
case class Result(key: Long, count: Int)
class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
with BeforeAndAfterAll
with StatefulOperatorTest {
import testImplicits._
import GroupStateImpl._
@ -544,6 +546,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
AddData(inputData, "a"),
CheckLastBatch(("a", "1")),
assertNumStateRows(total = 1, updated = 1),
AssertOnQuery(sq => checkChildOutputHashPartitioning[FlatMapGroupsWithStateExec](
sq, Seq("value"))),
AddData(inputData, "a", "b"),
CheckLastBatch(("a", "2"), ("b", "1")),
assertNumStateRows(total = 2, updated = 2),

View file

@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.streaming
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.streaming._
trait StatefulOperatorTest {
/**
* Check that the output partitioning of a child operator of a Stateful operator satisfies the
* distribution that we expect for our Stateful operator.
*/
protected def checkChildOutputHashPartitioning[T <: StatefulOperator](
sq: StreamingQuery,
colNames: Seq[String]): Boolean = {
val attr = sq.asInstanceOf[StreamExecution].lastExecution.analyzed.output
val partitions = sq.sparkSession.sessionState.conf.numShufflePartitions
val groupingAttr = attr.filter(a => colNames.contains(a.name))
checkChildOutputPartitioning(sq, HashPartitioning(groupingAttr, partitions))
}
/**
* Check that the output partitioning of a child operator of a Stateful operator satisfies the
* distribution that we expect for our Stateful operator.
*/
protected def checkChildOutputPartitioning[T <: StatefulOperator](
sq: StreamingQuery,
expectedPartitioning: Partitioning): Boolean = {
val operator = sq.asInstanceOf[StreamExecution].lastExecution
.executedPlan.collect { case p: T => p }
operator.head.children.forall(
_.outputPartitioning.numPartitions == expectedPartitioning.numPartitions)
}
}

View file

@ -44,7 +44,7 @@ object FailureSingleton {
}
class StreamingAggregationSuite extends StateStoreMetricsTest
with BeforeAndAfterAll with Assertions {
with BeforeAndAfterAll with Assertions with StatefulOperatorTest {
override def afterAll(): Unit = {
super.afterAll()
@ -281,6 +281,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
AddData(inputData, 0L, 5L, 5L, 10L),
AdvanceManualClock(10 * 1000),
CheckLastBatch((0L, 1), (5L, 2), (10L, 1)),
AssertOnQuery(sq =>
checkChildOutputHashPartitioning[StateStoreRestoreExec](sq, Seq("value"))),
// advance clock to 20 seconds, should retain keys >= 10
AddData(inputData, 15L, 15L, 20L),
@ -455,8 +457,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
},
AddBlockData(inputSource), // create an empty trigger
CheckLastBatch(1),
AssertOnQuery("Verify addition of exchange operator") { se =>
checkAggregationChain(se, expectShuffling = true, 1)
AssertOnQuery("Verify that no exchange is required") { se =>
checkAggregationChain(se, expectShuffling = false, 1)
},
AddBlockData(inputSource, Seq(2, 3)),
CheckLastBatch(3),

View file

@ -330,7 +330,7 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with
val queryId = UUID.randomUUID
val opId = 0
val path = Utils.createDirectory(tempDir.getAbsolutePath, Random.nextString(10)).toString
val stateInfo = StatefulOperatorStateInfo(path, queryId, opId, 0L)
val stateInfo = StatefulOperatorStateInfo(path, queryId, opId, 0L, 5)
implicit val sqlContext = spark.sqlContext
val coordinatorRef = sqlContext.streams.stateStoreCoordinator

View file

@ -652,6 +652,19 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
}
}
test("SPARK-22238: don't check for RDD partitions during streaming aggregation preparation") {
val stream = MemoryStream[(Int, Int)]
val baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char").where("char = 'A'")
val otherDf = stream.toDF().toDF("num", "numSq")
.join(broadcast(baseDf), "num")
.groupBy('char)
.agg(sum('numSq))
testStream(otherDf, OutputMode.Complete())(
AddData(stream, (1, 1), (2, 4)),
CheckLastBatch(("A", 1)))
}
/** Create a streaming DF that only execute one batch in which it returns the given static DF */
private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = {
require(!triggerDF.isStreaming)