[SPARK-11393] [SQL] CoGroupedIterator should respect the fact that GroupedIterator.hasNext is not idempotent

When we cogroup 2 `GroupedIterator`s in `CoGroupedIterator`, if the right side is smaller, we will consume right data and keep the left data unchanged. Then we call `hasNext` which will call `left.hasNext`. This will make `GroupedIterator` generate an extra group as the previous one has not been comsumed yet.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9346 from cloud-fan/cogroup and squashes the following commits:

9be67c8 [Wenchen Fan] SPARK-11393
This commit is contained in:
Wenchen Fan 2015-10-30 12:17:51 +01:00 committed by Michael Armbrust
parent 59db9e9c38
commit 14d08b9908
2 changed files with 32 additions and 6 deletions

View file

@ -38,17 +38,19 @@ class CoGroupedIterator(
private var currentLeftData: (InternalRow, Iterator[InternalRow]) = _
private var currentRightData: (InternalRow, Iterator[InternalRow]) = _
override def hasNext: Boolean = left.hasNext || right.hasNext
override def next(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
if (currentLeftData.eq(null) && left.hasNext) {
override def hasNext: Boolean = {
if (currentLeftData == null && left.hasNext) {
currentLeftData = left.next()
}
if (currentRightData.eq(null) && right.hasNext) {
if (currentRightData == null && right.hasNext) {
currentRightData = right.next()
}
assert(currentLeftData.ne(null) || currentRightData.ne(null))
currentLeftData != null || currentRightData != null
}
override def next(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
assert(hasNext)
if (currentLeftData.eq(null)) {
// left is null, right is not null, consume the right data.

View file

@ -48,4 +48,28 @@ class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper {
Nil
)
}
test("SPARK-11393: respect the fact that GroupedIterator.hasNext is not idempotent") {
val leftInput = Seq(create_row(2, "a")).iterator
val rightInput = Seq(create_row(1, 2L)).iterator
val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string))
val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long))
val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int))
val result = cogrouped.map {
case (key, leftData, rightData) =>
assert(key.numFields == 1)
(key.getInt(0), leftData.toSeq, rightData.toSeq)
}.toSeq
assert(result ==
(1,
Seq.empty,
Seq(create_row(1, 2L))) ::
(2,
Seq(create_row(2, "a")),
Seq.empty) ::
Nil
)
}
}