[SPARK-11393] [SQL] CoGroupedIterator should respect the fact that GroupedIterator.hasNext is not idempotent
When we cogroup 2 `GroupedIterator`s in `CoGroupedIterator`, if the right side is smaller, we will consume right data and keep the left data unchanged. Then we call `hasNext` which will call `left.hasNext`. This will make `GroupedIterator` generate an extra group as the previous one has not been comsumed yet. Author: Wenchen Fan <wenchen@databricks.com> Closes #9346 from cloud-fan/cogroup and squashes the following commits: 9be67c8 [Wenchen Fan] SPARK-11393
This commit is contained in:
parent
59db9e9c38
commit
14d08b9908
|
@ -38,17 +38,19 @@ class CoGroupedIterator(
|
|||
private var currentLeftData: (InternalRow, Iterator[InternalRow]) = _
|
||||
private var currentRightData: (InternalRow, Iterator[InternalRow]) = _
|
||||
|
||||
override def hasNext: Boolean = left.hasNext || right.hasNext
|
||||
|
||||
override def next(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
|
||||
if (currentLeftData.eq(null) && left.hasNext) {
|
||||
override def hasNext: Boolean = {
|
||||
if (currentLeftData == null && left.hasNext) {
|
||||
currentLeftData = left.next()
|
||||
}
|
||||
if (currentRightData.eq(null) && right.hasNext) {
|
||||
if (currentRightData == null && right.hasNext) {
|
||||
currentRightData = right.next()
|
||||
}
|
||||
|
||||
assert(currentLeftData.ne(null) || currentRightData.ne(null))
|
||||
currentLeftData != null || currentRightData != null
|
||||
}
|
||||
|
||||
override def next(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
|
||||
assert(hasNext)
|
||||
|
||||
if (currentLeftData.eq(null)) {
|
||||
// left is null, right is not null, consume the right data.
|
||||
|
|
|
@ -48,4 +48,28 @@ class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
Nil
|
||||
)
|
||||
}
|
||||
|
||||
test("SPARK-11393: respect the fact that GroupedIterator.hasNext is not idempotent") {
|
||||
val leftInput = Seq(create_row(2, "a")).iterator
|
||||
val rightInput = Seq(create_row(1, 2L)).iterator
|
||||
val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string))
|
||||
val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long))
|
||||
val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int))
|
||||
|
||||
val result = cogrouped.map {
|
||||
case (key, leftData, rightData) =>
|
||||
assert(key.numFields == 1)
|
||||
(key.getInt(0), leftData.toSeq, rightData.toSeq)
|
||||
}.toSeq
|
||||
|
||||
assert(result ==
|
||||
(1,
|
||||
Seq.empty,
|
||||
Seq(create_row(1, 2L))) ::
|
||||
(2,
|
||||
Seq(create_row(2, "a")),
|
||||
Seq.empty) ::
|
||||
Nil
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue