[SPARK-32862][SS] Left semi stream-stream join

### What changes were proposed in this pull request?

This is to support left semi join in stream-stream join. The implementation of left semi join is (mostly in `StreamingSymmetricHashJoinExec` and `SymmetricHashJoinStateManager`):
* For left side input row, check if there's a match on right side state store.
  * if there's a match, output the left side row, but do not put the row in left side state store (no need to put in state store).
  * if there's no match, output nothing, but put the row in left side state store (with "matched" field to set to false in state store).
* For right side input row, check if there's a match on left side state store.
  * For all matched left rows in state store, output the rows with "matched" field as false. Set all left rows with "matched" field to be true. Only output the left side rows matched for the first time to guarantee left semi join semantics.
* State store eviction: evict rows from left/right side state store below watermark, same as inner join.

Note a followup optimization can be to evict matched left side rows from state store earlier, even when the rows are still above watermark. However this needs more change in `SymmetricHashJoinStateManager`, so will leave this as a followup.

### Why are the changes needed?

Current stream-stream join supports inner, left outer and right outer join (https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala#L166 ). We do see internally a lot of users are using left semi stream-stream join (not spark structured streaming), e.g. I want to get the ad impression (join left side) which has click (joint right side), but I don't care how many clicks per ad (left semi semantics).

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added unit tests in `UnsupportedOperationChecker.scala` and `StreamingJoinSuite.scala`.

Closes #30076 from c21/stream-join.

Authored-by: Cheng Su <chengsu@fb.com>
Signed-off-by: Jungtaek Lim (HeartSaVioR) <kabhwan.opensource@gmail.com>
This commit is contained in:
Cheng Su 2020-10-26 13:33:06 +09:00 committed by Jungtaek Lim (HeartSaVioR)
parent 369cc614f3
commit d87a0bb2ca
6 changed files with 544 additions and 179 deletions

View file

@ -291,17 +291,17 @@ object UnsupportedOperationChecker extends Logging {
throwError("Full outer joins with streaming DataFrames/Datasets are not supported")
}
case LeftSemi | LeftAnti =>
case LeftAnti =>
if (right.isStreaming) {
throwError("Left semi/anti joins with a streaming DataFrame/Dataset " +
throwError("Left anti joins with a streaming DataFrame/Dataset " +
"on the right are not supported")
}
// We support streaming left outer joins with static on the right always, and with
// stream on both sides under the appropriate conditions.
case LeftOuter =>
// We support streaming left outer and left semi joins with static on the right always,
// and with stream on both sides under the appropriate conditions.
case LeftOuter | LeftSemi =>
if (!left.isStreaming && right.isStreaming) {
throwError("Left outer join with a streaming DataFrame/Dataset " +
throwError(s"$joinType join with a streaming DataFrame/Dataset " +
"on the right and a static DataFrame/Dataset on the left is not supported")
} else if (left.isStreaming && right.isStreaming) {
val watermarkInJoinKeys = StreamingJoinHelper.isWatermarkInJoinKeys(subPlan)
@ -311,7 +311,8 @@ object UnsupportedOperationChecker extends Logging {
left.outputSet, right.outputSet, condition, Some(1000000)).isDefined
if (!watermarkInJoinKeys && !hasValidWatermarkRange) {
throwError("Stream-stream outer join between two streaming DataFrame/Datasets " +
throwError(
s"Stream-stream $joinType join between two streaming DataFrame/Datasets " +
"is not supported without a watermark in the join keys, or a watermark on " +
"the nullable side and an appropriate range condition")
}

View file

@ -55,6 +55,16 @@ class JoinedRow extends InternalRow {
this
}
/** Gets this JoinedRow's left base row. */
def getLeft: InternalRow = {
row1
}
/** Gets this JoinedRow's right base row. */
def getRight: InternalRow = {
row2
}
override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = {
assert(fieldTypes.length == row1.numFields + row2.numFields)
val (left, right) = fieldTypes.splitAt(row1.numFields)

View file

@ -490,7 +490,69 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
_.join(_, joinType = LeftSemi),
streamStreamSupported = false,
batchStreamSupported = false,
expectedMsg = "left semi/anti joins")
expectedMsg = "LeftSemi join")
// Left semi joins: update and complete mode not allowed
assertNotSupportedInStreamingPlan(
"left semi join with stream-stream relations and update mode",
streamRelation.join(streamRelation, joinType = LeftSemi,
condition = Some(attribute === attribute)),
OutputMode.Update(),
Seq("is not supported in Update output mode"))
assertNotSupportedInStreamingPlan(
"left semi join with stream-stream relations and complete mode",
Aggregate(Nil, aggExprs("d"), streamRelation.join(streamRelation, joinType = LeftSemi,
condition = Some(attribute === attribute))),
OutputMode.Complete(),
Seq("is not supported in Complete output mode"))
// Left semi joins: stream-stream allowed with join on watermark attribute
// Note that the attribute need not be watermarked on both sides.
assertSupportedInStreamingPlan(
"left semi join with stream-stream relations and join on attribute with left watermark",
streamRelation.join(streamRelation, joinType = LeftSemi,
condition = Some(attributeWithWatermark === attribute)),
OutputMode.Append())
assertSupportedInStreamingPlan(
"left semi join with stream-stream relations and join on attribute with right watermark",
streamRelation.join(streamRelation, joinType = LeftSemi,
condition = Some(attribute === attributeWithWatermark)),
OutputMode.Append())
assertNotSupportedInStreamingPlan(
"left semi join with stream-stream relations and join on non-watermarked attribute",
streamRelation.join(streamRelation, joinType = LeftSemi,
condition = Some(attribute === attribute)),
OutputMode.Append(),
Seq("without a watermark in the join keys"))
// Left semi joins: stream-stream allowed with range condition yielding state value watermark
assertSupportedInStreamingPlan(
"left semi join with stream-stream relations and state value watermark", {
val leftRelation = streamRelation
val rightTimeWithWatermark =
AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata)
val rightRelation = new TestStreamingRelation(rightTimeWithWatermark)
leftRelation.join(
rightRelation,
joinType = LeftSemi,
condition = Some(attribute > rightTimeWithWatermark + 10))
},
OutputMode.Append())
// Left semi joins: stream-stream not allowed with insufficient range condition
assertNotSupportedInStreamingPlan(
"left semi join with stream-stream relations and state value watermark", {
val leftRelation = streamRelation
val rightTimeWithWatermark =
AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata)
val rightRelation = new TestStreamingRelation(rightTimeWithWatermark)
leftRelation.join(
rightRelation,
joinType = LeftSemi,
condition = Some(attribute < rightTimeWithWatermark + 10))
},
OutputMode.Append(),
Seq("appropriate range condition"))
// Left anti joins: stream-* not allowed
testBinaryOperationInStreamingPlan(
@ -498,7 +560,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
_.join(_, joinType = LeftAnti),
streamStreamSupported = false,
batchStreamSupported = false,
expectedMsg = "left semi/anti joins")
expectedMsg = "Left anti join")
// Right outer joins: stream-* not allowed
testBinaryOperationInStreamingPlan(

View file

@ -152,7 +152,8 @@ case class StreamingSymmetricHashJoinExec(
}
if (stateFormatVersion < 2 && joinType != Inner) {
throw new IllegalArgumentException("The query is using stream-stream outer join with state" +
throw new IllegalArgumentException(
s"The query is using stream-stream $joinType join with state" +
s" format version ${stateFormatVersion} - correctness issue is discovered. Please discard" +
" the checkpoint and rerun the query. See SPARK-26154 for more details.")
}
@ -165,7 +166,7 @@ case class StreamingSymmetricHashJoinExec(
}
require(
joinType == Inner || joinType == LeftOuter || joinType == RightOuter,
joinType == Inner || joinType == LeftOuter || joinType == RightOuter || joinType == LeftSemi,
errorMessageForJoinType)
require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType))
@ -185,6 +186,7 @@ case class StreamingSymmetricHashJoinExec(
case _: InnerLike => left.output ++ right.output
case LeftOuter => left.output ++ right.output.map(_.withNullability(true))
case RightOuter => left.output.map(_.withNullability(true)) ++ right.output
case LeftSemi => left.output
case _ => throwBadJoinTypeException()
}
@ -193,6 +195,7 @@ case class StreamingSymmetricHashJoinExec(
PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
case LeftSemi => left.outputPartitioning
case _ => throwBadJoinTypeException()
}
@ -246,14 +249,21 @@ case class StreamingSymmetricHashJoinExec(
// Join one side input using the other side's buffered/state rows. Here is how it is done.
//
// - `leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner)` generates all rows from
// matching new left input with stored right input, and also stores all the left input
// - `leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner)`
// - Inner, Left Outer, Right Outer Join: generates all rows from matching new left input
// with stored right input, and also stores all the left input.
// - Left Semi Join: generates all new left input rows from matching new left input with
// stored right input, and also stores all the non-matched left input.
//
// - `rightSideJoiner.storeAndJoinWithOtherSide(leftSideJoiner)` generates all rows from
// matching new right input with stored left input, and also stores all the right input.
// It also generates all rows from matching new left input with new right input, since
// the new left input has become stored by that point. This tiny asymmetry is necessary
// to avoid duplication.
// - `rightSideJoiner.storeAndJoinWithOtherSide(leftSideJoiner)`
// - Inner, Left Outer, Right Outer Join: generates all rows from matching new right input
// with stored left input, and also stores all the right input.
// It also generates all rows from matching new left input with new right input, since
// the new left input has become stored by that point. This tiny asymmetry is necessary
// to avoid duplication.
// - Left Semi Join: generates all stored left input rows, from matching new right input
// with stored left input, and also stores all the right input. Note only first-time
// matched left input rows will be generated, this is to guarantee left semi semantics.
val leftOutputIter = leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner) {
(input: InternalRow, matched: InternalRow) => joinedRow.withLeft(input).withRight(matched)
}
@ -261,22 +271,21 @@ case class StreamingSymmetricHashJoinExec(
(input: InternalRow, matched: InternalRow) => joinedRow.withLeft(matched).withRight(input)
}
// We need to save the time that the inner join output iterator completes, since outer join
// output counts as both update and removal time.
var innerOutputCompletionTimeNs: Long = 0
def onInnerOutputCompletion = {
innerOutputCompletionTimeNs = System.nanoTime
// We need to save the time that the one side hash join output iterator completes, since
// other join output counts as both update and removal time.
var hashJoinOutputCompletionTimeNs: Long = 0
def onHashJoinOutputCompletion(): Unit = {
hashJoinOutputCompletionTimeNs = System.nanoTime
}
// This is the iterator which produces the inner join rows. For outer joins, this will be
// prepended to a second iterator producing outer join rows; for inner joins, this is the full
// output.
val innerOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]](
(leftOutputIter ++ rightOutputIter), onInnerOutputCompletion)
// This is the iterator which produces the inner and left semi join rows. For other joins,
// this will be prepended to a second iterator producing other rows; for inner and left semi
// joins, this is the full output.
val hashJoinOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]](
leftOutputIter ++ rightOutputIter, onHashJoinOutputCompletion())
val outputIter: Iterator[InternalRow] = joinType match {
case Inner =>
innerOutputIter
case Inner | LeftSemi =>
hashJoinOutputIter
case LeftOuter =>
// We generate the outer join input by:
// * Getting an iterator over the rows that have aged out on the left side. These rows are
@ -311,7 +320,7 @@ case class StreamingSymmetricHashJoinExec(
}
}.map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
innerOutputIter ++ outerOutputIter
hashJoinOutputIter ++ outerOutputIter
case RightOuter =>
// See comments for left outer case.
def matchesWithLeftSideState(rightKeyValue: UnsafeRowPair) = {
@ -330,11 +339,15 @@ case class StreamingSymmetricHashJoinExec(
}
}.map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
innerOutputIter ++ outerOutputIter
hashJoinOutputIter ++ outerOutputIter
case _ => throwBadJoinTypeException()
}
val outputProjection = UnsafeProjection.create(left.output ++ right.output, output)
val outputProjection = if (joinType == LeftSemi) {
UnsafeProjection.create(output, output)
} else {
UnsafeProjection.create(left.output ++ right.output, output)
}
val outputIterWithMetrics = outputIter.map { row =>
numOutputRows += 1
outputProjection(row)
@ -345,24 +358,28 @@ case class StreamingSymmetricHashJoinExec(
// All processing time counts as update time.
allUpdatesTimeMs += math.max(NANOSECONDS.toMillis(System.nanoTime - updateStartTimeNs), 0)
// Processing time between inner output completion and here comes from the outer portion of a
// join, and thus counts as removal time as we remove old state from one side while iterating.
if (innerOutputCompletionTimeNs != 0) {
// Processing time between one side hash join output completion and here comes from the
// outer portion of a join, and thus counts as removal time as we remove old state from
// one side while iterating.
if (hashJoinOutputCompletionTimeNs != 0) {
allRemovalsTimeMs +=
math.max(NANOSECONDS.toMillis(System.nanoTime - innerOutputCompletionTimeNs), 0)
math.max(NANOSECONDS.toMillis(System.nanoTime - hashJoinOutputCompletionTimeNs), 0)
}
allRemovalsTimeMs += timeTakenMs {
// Remove any remaining state rows which aren't needed because they're below the watermark.
//
// For inner joins, we have to remove unnecessary state rows from both sides if possible.
// For inner and left semi joins, we have to remove unnecessary state rows from both sides
// if possible.
//
// For outer joins, we have already removed unnecessary state rows from the outer side
// (e.g., left side for left outer join) while generating the outer "null" outputs. Now, we
// have to remove unnecessary state rows from the other side (e.g., right side for the left
// outer join) if possible. In all cases, nothing needs to be outputted, hence the removal
// needs to be done greedily by immediately consuming the returned iterator.
val cleanupIter = joinType match {
case Inner => leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState()
case Inner | LeftSemi =>
leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState()
case LeftOuter => rightSideJoiner.removeOldState()
case RightOuter => leftSideJoiner.removeOldState()
case _ => throwBadJoinTypeException()
@ -481,6 +498,26 @@ case class StreamingSymmetricHashJoinExec(
case _ => (_: InternalRow) => Iterator.empty
}
val excludeRowsAlreadyMatched = joinType == LeftSemi && joinSide == RightSide
val generateOutputIter: (InternalRow, Iterator[JoinedRow]) => Iterator[InternalRow] =
joinSide match {
case LeftSide if joinType == LeftSemi =>
(input: InternalRow, joinedRowIter: Iterator[JoinedRow]) =>
// For left side of left semi join, generate one left row if there is matched
// rows from right side. Otherwise, generate nothing.
if (joinedRowIter.nonEmpty) {
Iterator.single(input)
} else {
Iterator.empty
}
case RightSide if joinType == LeftSemi =>
(_: InternalRow, joinedRowIter: Iterator[JoinedRow]) =>
// For right side of left semi join, generate matched left rows only.
joinedRowIter.map(_.getLeft)
case _ => (_: InternalRow, joinedRowIter: Iterator[JoinedRow]) => joinedRowIter
}
nonLateRows.flatMap { row =>
val thisRow = row.asInstanceOf[UnsafeRow]
// If this row fails the pre join filter, that means it can never satisfy the full join
@ -489,8 +526,12 @@ case class StreamingSymmetricHashJoinExec(
// the case of inner join).
if (preJoinFilter(thisRow)) {
val key = keyGenerator(thisRow)
val outputIter: Iterator[JoinedRow] = otherSideJoiner.joinStateManager
.getJoinedRows(key, thatRow => generateJoinedRow(thisRow, thatRow), postJoinFilter)
val joinedRowIter: Iterator[JoinedRow] = otherSideJoiner.joinStateManager.getJoinedRows(
key,
thatRow => generateJoinedRow(thisRow, thatRow),
postJoinFilter,
excludeRowsAlreadyMatched)
val outputIter = generateOutputIter(thisRow, joinedRowIter)
new AddingProcessedRowToStateCompletionIterator(key, thisRow, outputIter)
} else {
generateFilteredJoinedRow(thisRow)
@ -501,13 +542,19 @@ case class StreamingSymmetricHashJoinExec(
private class AddingProcessedRowToStateCompletionIterator(
key: UnsafeRow,
thisRow: UnsafeRow,
subIter: Iterator[JoinedRow])
extends CompletionIterator[JoinedRow, Iterator[JoinedRow]](subIter) {
subIter: Iterator[InternalRow])
extends CompletionIterator[InternalRow, Iterator[InternalRow]](subIter) {
private val iteratorNotEmpty: Boolean = super.hasNext
override def completion(): Unit = {
val shouldAddToState = // add only if both removal predicates do not match
!stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow)
val isLeftSemiWithMatch =
joinType == LeftSemi && joinSide == LeftSide && iteratorNotEmpty
// Add to state store only if both removal predicates do not match,
// and the row is not matched for left side of left semi join.
val shouldAddToState =
!stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow) &&
!isLeftSemiWithMatch
if (shouldAddToState) {
joinStateManager.append(key, thisRow, matched = iteratorNotEmpty)
updatedStateRowsCount += 1

View file

@ -99,13 +99,20 @@ class SymmetricHashJoinStateManager(
/**
* Get all the matched values for given join condition, with marking matched.
* This method is designed to mark joined rows properly without exposing internal index of row.
*
* @param excludeRowsAlreadyMatched Do not join with rows already matched previously.
* This is used for right side of left semi join in
* [[StreamingSymmetricHashJoinExec]] only.
*/
def getJoinedRows(
key: UnsafeRow,
generateJoinedRow: InternalRow => JoinedRow,
predicate: JoinedRow => Boolean): Iterator[JoinedRow] = {
predicate: JoinedRow => Boolean,
excludeRowsAlreadyMatched: Boolean = false): Iterator[JoinedRow] = {
val numValues = keyToNumValues.get(key)
keyWithIndexToValue.getAll(key, numValues).map { keyIdxToValue =>
keyWithIndexToValue.getAll(key, numValues).filterNot { keyIdxToValue =>
excludeRowsAlreadyMatched && keyIdxToValue.matched
}.map { keyIdxToValue =>
val joinedRow = generateJoinedRow(keyIdxToValue.value)
if (predicate(joinedRow)) {
if (!keyIdxToValue.matched) {

View file

@ -41,18 +41,174 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
abstract class StreamingJoinSuite
extends StreamTest with StateStoreMetricsTest with BeforeAndAfter {
class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter {
import testImplicits._
before {
SparkSession.setActiveSessionInternal(spark) // set this before force initializing 'joinExec'
spark.streams.stateStoreCoordinator // initialize the lazy coordinator
SparkSession.setActiveSessionInternal(spark) // set this before force initializing 'joinExec'
spark.streams.stateStoreCoordinator // initialize the lazy coordinator
}
after {
StateStore.stop()
}
protected def setupStream(prefix: String, multiplier: Int): (MemoryStream[Int], DataFrame) = {
val input = MemoryStream[Int]
val df = input.toDF
.select(
'value as "key",
timestamp_seconds($"value") as s"${prefix}Time",
('value * multiplier) as s"${prefix}Value")
.withWatermark(s"${prefix}Time", "10 seconds")
(input, df)
}
protected def setupWindowedJoin(joinType: String)
: (MemoryStream[Int], MemoryStream[Int], DataFrame) = {
val (input1, df1) = setupStream("left", 2)
val (input2, df2) = setupStream("right", 3)
val windowed1 = df1.select('key, window('leftTime, "10 second"), 'leftValue)
val windowed2 = df2.select('key, window('rightTime, "10 second"), 'rightValue)
val joined = windowed1.join(windowed2, Seq("key", "window"), joinType)
val select = if (joinType == "left_semi") {
joined.select('key, $"window.end".cast("long"), 'leftValue)
} else {
joined.select('key, $"window.end".cast("long"), 'leftValue, 'rightValue)
}
(input1, input2, select)
}
protected def setupWindowedJoinWithLeftCondition(joinType: String)
: (MemoryStream[Int], MemoryStream[Int], DataFrame) = {
val (leftInput, df1) = setupStream("left", 2)
val (rightInput, df2) = setupStream("right", 3)
// Use different schemas to ensure the null row is being generated from the correct side.
val left = df1.select('key, window('leftTime, "10 second"), 'leftValue)
val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string"))
val joined = left.join(
right,
left("key") === right("key")
&& left("window") === right("window")
&& 'leftValue > 4,
joinType)
val select = if (joinType == "left_semi") {
joined.select(left("key"), left("window.end").cast("long"), 'leftValue)
} else if (joinType == "left_outer") {
joined.select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue)
} else if (joinType == "right_outer") {
joined.select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue)
} else {
joined
}
(leftInput, rightInput, select)
}
protected def setupWindowedJoinWithRightCondition(joinType: String)
: (MemoryStream[Int], MemoryStream[Int], DataFrame) = {
val (leftInput, df1) = setupStream("left", 2)
val (rightInput, df2) = setupStream("right", 3)
// Use different schemas to ensure the null row is being generated from the correct side.
val left = df1.select('key, window('leftTime, "10 second"), 'leftValue)
val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string"))
val joined = left.join(
right,
left("key") === right("key")
&& left("window") === right("window")
&& 'rightValue.cast("int") > 7,
joinType)
val select = if (joinType == "left_semi") {
joined.select(left("key"), left("window.end").cast("long"), 'leftValue)
} else if (joinType == "left_outer") {
joined.select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue)
} else if (joinType == "right_outer") {
joined.select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue)
} else {
joined
}
(leftInput, rightInput, select)
}
protected def setupWindowedJoinWithRangeCondition(joinType: String)
: (MemoryStream[(Int, Int)], MemoryStream[(Int, Int)], DataFrame) = {
val leftInput = MemoryStream[(Int, Int)]
val rightInput = MemoryStream[(Int, Int)]
val df1 = leftInput.toDF.toDF("leftKey", "time")
.select('leftKey, timestamp_seconds($"time") as "leftTime", ('leftKey * 2) as "leftValue")
.withWatermark("leftTime", "10 seconds")
val df2 = rightInput.toDF.toDF("rightKey", "time")
.select('rightKey, timestamp_seconds($"time") as "rightTime",
('rightKey * 3) as "rightValue")
.withWatermark("rightTime", "10 seconds")
val joined =
df1.join(
df2,
expr("leftKey = rightKey AND " +
"leftTime BETWEEN rightTime - interval 5 seconds AND rightTime + interval 5 seconds"),
joinType)
val select = if (joinType == "left_semi") {
joined.select('leftKey, 'leftTime.cast("int"))
} else {
joined.select('leftKey, 'rightKey, 'leftTime.cast("int"), 'rightTime.cast("int"))
}
(leftInput, rightInput, select)
}
protected def setupWindowedSelfJoin(joinType: String)
: (MemoryStream[(Int, Long)], DataFrame) = {
val inputStream = MemoryStream[(Int, Long)]
val df = inputStream.toDS()
.select(col("_1").as("value"), timestamp_seconds($"_2").as("timestamp"))
val leftStream = df.select(col("value").as("leftId"), col("timestamp").as("leftTime"))
val rightStream = df
// Introduce misses for ease of debugging
.where(col("value") % 2 === 0)
.select(col("value").as("rightId"), col("timestamp").as("rightTime"))
val joined = leftStream
.withWatermark("leftTime", "5 seconds")
.join(
rightStream.withWatermark("rightTime", "5 seconds"),
expr("leftId = rightId AND rightTime >= leftTime AND " +
"rightTime <= leftTime + interval 5 seconds"),
joinType)
val select = if (joinType == "left_semi") {
joined.select(col("leftId"), col("leftTime").cast("int"))
} else {
joined.select(col("leftId"), col("leftTime").cast("int"),
col("rightId"), col("rightTime").cast("int"))
}
(inputStream, select)
}
}
class StreamingInnerJoinSuite extends StreamingJoinSuite {
import testImplicits._
test("stream stream inner join on non-time column") {
val input1 = MemoryStream[Int]
@ -486,58 +642,13 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with
}
class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter {
class StreamingOuterJoinSuite extends StreamingJoinSuite {
import testImplicits._
import org.apache.spark.sql.functions._
before {
SparkSession.setActiveSessionInternal(spark) // set this before force initializing 'joinExec'
spark.streams.stateStoreCoordinator // initialize the lazy coordinator
}
after {
StateStore.stop()
}
private def setupStream(prefix: String, multiplier: Int): (MemoryStream[Int], DataFrame) = {
val input = MemoryStream[Int]
val df = input.toDF
.select(
'value as "key",
timestamp_seconds($"value") as s"${prefix}Time",
('value * multiplier) as s"${prefix}Value")
.withWatermark(s"${prefix}Time", "10 seconds")
return (input, df)
}
private def setupWindowedJoin(joinType: String):
(MemoryStream[Int], MemoryStream[Int], DataFrame) = {
val (input1, df1) = setupStream("left", 2)
val (input2, df2) = setupStream("right", 3)
val windowed1 = df1.select('key, window('leftTime, "10 second"), 'leftValue)
val windowed2 = df2.select('key, window('rightTime, "10 second"), 'rightValue)
val joined = windowed1.join(windowed2, Seq("key", "window"), joinType)
.select('key, $"window.end".cast("long"), 'leftValue, 'rightValue)
(input1, input2, joined)
}
test("left outer early state exclusion on left") {
val (leftInput, df1) = setupStream("left", 2)
val (rightInput, df2) = setupStream("right", 3)
// Use different schemas to ensure the null row is being generated from the correct side.
val left = df1.select('key, window('leftTime, "10 second"), 'leftValue)
val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string"))
val joined = left.join(
right,
left("key") === right("key")
&& left("window") === right("window")
&& 'leftValue > 4,
"left_outer")
.select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue)
val (leftInput, rightInput, joined) = setupWindowedJoinWithLeftCondition("left_outer")
testStream(joined)(
MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5),
@ -554,19 +665,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
}
test("left outer early state exclusion on right") {
val (leftInput, df1) = setupStream("left", 2)
val (rightInput, df2) = setupStream("right", 3)
// Use different schemas to ensure the null row is being generated from the correct side.
val left = df1.select('key, window('leftTime, "10 second"), 'leftValue)
val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string"))
val joined = left.join(
right,
left("key") === right("key")
&& left("window") === right("window")
&& 'rightValue.cast("int") > 7,
"left_outer")
.select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue)
val (leftInput, rightInput, joined) = setupWindowedJoinWithRightCondition("left_outer")
testStream(joined)(
MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3),
@ -583,19 +682,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
}
test("right outer early state exclusion on left") {
val (leftInput, df1) = setupStream("left", 2)
val (rightInput, df2) = setupStream("right", 3)
// Use different schemas to ensure the null row is being generated from the correct side.
val left = df1.select('key, window('leftTime, "10 second"), 'leftValue)
val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string"))
val joined = left.join(
right,
left("key") === right("key")
&& left("window") === right("window")
&& 'leftValue > 4,
"right_outer")
.select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue)
val (leftInput, rightInput, joined) = setupWindowedJoinWithLeftCondition("right_outer")
testStream(joined)(
MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5),
@ -612,19 +699,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
}
test("right outer early state exclusion on right") {
val (leftInput, df1) = setupStream("left", 2)
val (rightInput, df2) = setupStream("right", 3)
// Use different schemas to ensure the null row is being generated from the correct side.
val left = df1.select('key, window('leftTime, "10 second"), 'leftValue)
val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string"))
val joined = left.join(
right,
left("key") === right("key")
&& left("window") === right("window")
&& 'rightValue.cast("int") > 7,
"right_outer")
.select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue)
val (leftInput, rightInput, joined) = setupWindowedJoinWithRightCondition("right_outer")
testStream(joined)(
MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3),
@ -681,27 +756,8 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
("right_outer", Row(null, 2, null, 5))
).foreach { case (joinType: String, outerResult) =>
test(s"${joinType.replaceAllLiterally("_", " ")} with watermark range condition") {
import org.apache.spark.sql.functions._
val (leftInput, rightInput, joined) = setupWindowedJoinWithRangeCondition(joinType)
val leftInput = MemoryStream[(Int, Int)]
val rightInput = MemoryStream[(Int, Int)]
val df1 = leftInput.toDF.toDF("leftKey", "time")
.select('leftKey, timestamp_seconds($"time") as "leftTime", ('leftKey * 2) as "leftValue")
.withWatermark("leftTime", "10 seconds")
val df2 = rightInput.toDF.toDF("rightKey", "time")
.select('rightKey, timestamp_seconds($"time") as "rightTime",
('rightKey * 3) as "rightValue")
.withWatermark("rightTime", "10 seconds")
val joined =
df1.join(
df2,
expr("leftKey = rightKey AND " +
"leftTime BETWEEN rightTime - interval 5 seconds AND rightTime + interval 5 seconds"),
joinType)
.select('leftKey, 'rightKey, 'leftTime.cast("int"), 'rightTime.cast("int"))
testStream(joined)(
AddData(leftInput, (1, 5), (3, 5)),
CheckAnswer(),
@ -780,27 +836,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
}
test("SPARK-26187 self left outer join should not return outer nulls for already matched rows") {
val inputStream = MemoryStream[(Int, Long)]
val df = inputStream.toDS()
.select(col("_1").as("value"), timestamp_seconds($"_2").as("timestamp"))
val leftStream = df.select(col("value").as("leftId"), col("timestamp").as("leftTime"))
val rightStream = df
// Introduce misses for ease of debugging
.where(col("value") % 2 === 0)
.select(col("value").as("rightId"), col("timestamp").as("rightTime"))
val query = leftStream
.withWatermark("leftTime", "5 seconds")
.join(
rightStream.withWatermark("rightTime", "5 seconds"),
expr("leftId = rightId AND rightTime >= leftTime AND " +
"rightTime <= leftTime + interval 5 seconds"),
joinType = "leftOuter")
.select(col("leftId"), col("leftTime").cast("int"),
col("rightId"), col("rightTime").cast("int"))
val (inputStream, query) = setupWindowedSelfJoin("left_outer")
testStream(query)(
AddData(inputStream, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)),
@ -938,7 +974,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
throw writer.exception.get
}
assert(e.getMessage.toLowerCase(Locale.ROOT)
.contains("the query is using stream-stream outer join with state format version 1"))
.contains("the query is using stream-stream leftouter join with state format version 1"))
}
test("SPARK-29438: ensure UNION doesn't lead stream-stream join to use shifted partition IDs") {
@ -1041,3 +1077,205 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
)
}
}
class StreamingLeftSemiJoinSuite extends StreamingJoinSuite {
import testImplicits._
test("windowed left semi join") {
val (leftInput, rightInput, joined) = setupWindowedJoin("left_semi")
testStream(joined)(
MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7),
CheckNewAnswer(Row(3, 10, 6), Row(4, 10, 8), Row(5, 10, 10)),
// states
// left: 1, 2, 3, 4 ,5
// right: 3, 4, 5, 6, 7
assertNumStateRows(total = 10, updated = 10),
MultiAddData(leftInput, 21)(rightInput, 22),
// Watermark = 11, should remove rows having window=[0,10].
CheckNewAnswer(),
// states
// left: 21
// right: 22
//
// states evicted
// left: 1, 2, 3, 4 ,5 (below watermark)
// right: 3, 4, 5, 6, 7 (below watermark)
assertNumStateRows(total = 2, updated = 2),
AddData(leftInput, 22),
CheckNewAnswer(Row(22, 30, 44)),
// Unlike inner/outer joins, given left input row matches with right input row,
// we don't buffer the matched left input row to the state store.
//
// states
// left: 21
// right: 22
assertNumStateRows(total = 2, updated = 0),
StopStream,
StartStream(),
AddData(leftInput, 1),
// Row not add as 1 < state key watermark = 12.
CheckNewAnswer(),
// states
// left: 21
// right: 22
assertNumStateRows(total = 2, updated = 0, droppedByWatermark = 1),
AddData(rightInput, 5),
// Row not add as 5 < state key watermark = 12.
CheckNewAnswer(),
// states
// left: 21
// right: 22
assertNumStateRows(total = 2, updated = 0, droppedByWatermark = 1)
)
}
test("left semi early state exclusion on left") {
val (leftInput, rightInput, joined) = setupWindowedJoinWithLeftCondition("left_semi")
testStream(joined)(
MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5),
// The left rows with leftValue <= 4 should not generate their semi join rows and
// not get added to the state.
CheckNewAnswer(Row(3, 10, 6)),
// states
// left: 3
// right: 3, 4, 5
assertNumStateRows(total = 4, updated = 4),
// We shouldn't get more semi join rows when the watermark advances.
MultiAddData(leftInput, 20)(rightInput, 21),
CheckNewAnswer(),
// states
// left: 20
// right: 21
//
// states evicted
// left: 3 (below watermark)
// right: 3, 4, 5 (below watermark)
assertNumStateRows(total = 2, updated = 2),
AddData(rightInput, 20),
CheckNewAnswer((20, 30, 40)),
// states
// left: 20
// right: 21, 20
assertNumStateRows(total = 3, updated = 1)
)
}
test("left semi early state exclusion on right") {
val (leftInput, rightInput, joined) = setupWindowedJoinWithRightCondition("left_semi")
testStream(joined)(
MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3),
// The right rows with rightValue <= 7 should never be added to the state.
// The right row with rightValue = 9 > 7, hence joined and added to state.
CheckNewAnswer(Row(3, 10, 6)),
// states
// left: 3, 4, 5
// right: 3
assertNumStateRows(total = 4, updated = 4),
// We shouldn't get more semi join rows when the watermark advances.
MultiAddData(leftInput, 20)(rightInput, 21),
CheckNewAnswer(),
// states
// left: 20
// right: 21
//
// states evicted
// left: 3, 4, 5 (below watermark)
// right: 3 (below watermark)
assertNumStateRows(total = 2, updated = 2),
AddData(rightInput, 20),
CheckNewAnswer((20, 30, 40)),
// states
// left: 20
// right: 21, 20
assertNumStateRows(total = 3, updated = 1)
)
}
test("left semi join with watermark range condition") {
val (leftInput, rightInput, joined) = setupWindowedJoinWithRangeCondition("left_semi")
testStream(joined)(
AddData(leftInput, (1, 5), (3, 5)),
CheckNewAnswer(),
// states
// left: (1, 5), (3, 5)
// right: nothing
assertNumStateRows(total = 2, updated = 2),
AddData(rightInput, (1, 10), (2, 5)),
// Match left row in the state.
CheckNewAnswer((1, 5)),
// states
// left: (1, 5), (3, 5)
// right: (1, 10), (2, 5)
assertNumStateRows(total = 4, updated = 2),
AddData(rightInput, (1, 9)),
// No match as left row is already matched.
CheckNewAnswer(),
// states
// left: (1, 5), (3, 5)
// right: (1, 10), (2, 5), (1, 9)
assertNumStateRows(total = 5, updated = 1),
// Increase event time watermark to 20s by adding data with time = 30s on both inputs.
AddData(leftInput, (1, 7), (1, 30)),
CheckNewAnswer((1, 7)),
// states
// left: (1, 5), (3, 5), (1, 30)
// right: (1, 10), (2, 5), (1, 9)
assertNumStateRows(total = 6, updated = 1),
// Watermark = 30 - 10 = 20, no matched row.
AddData(rightInput, (0, 30)),
CheckNewAnswer(),
// states
// left: (1, 30)
// right: (0, 30)
//
// states evicted
// left: (1, 5), (3, 5) (below watermark = 20)
// right: (1, 10), (2, 5), (1, 9) (below watermark = 20)
assertNumStateRows(total = 2, updated = 1)
)
}
test("self left semi join") {
val (inputStream, query) = setupWindowedSelfJoin("left_semi")
testStream(query)(
AddData(inputStream, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)),
CheckNewAnswer((2, 2), (4, 4)),
// batch 1 - global watermark = 0
// states
// left: (2, 2L), (4, 4L)
// (left rows with value % 2 != 0 is filtered per [[PushPredicateThroughJoin]])
// right: (2, 2L), (4, 4L)
// (right rows with value % 2 != 0 is filtered per [[PushPredicateThroughJoin]])
assertNumStateRows(total = 4, updated = 4),
AddData(inputStream, (6, 6L), (7, 7L), (8, 8L), (9, 9L), (10, 10L)),
CheckNewAnswer((6, 6), (8, 8), (10, 10)),
// batch 2 - global watermark = 5
// states
// left: (2, 2L), (4, 4L), (6, 6L), (8, 8L), (10, 10L)
// right: (6, 6L), (8, 8L), (10, 10L)
//
// states evicted
// left: nothing (it waits for 5 seconds more than watermark due to join condition)
// right: (2, 2L), (4, 4L)
assertNumStateRows(total = 8, updated = 6),
AddData(inputStream, (11, 11L), (12, 12L), (13, 13L), (14, 14L), (15, 15L)),
CheckNewAnswer((12, 12), (14, 14)),
// batch 3 - global watermark = 9
// states
// left: (4, 4L), (6, 6L), (8, 8L), (10, 10L), (12, 12L), (14, 14L)
// right: (10, 10L), (12, 12L), (14, 14L)
//
// states evicted
// left: (2, 2L)
// right: (6, 6L), (8, 8L)
assertNumStateRows(total = 9, updated = 4)
)
}
}