diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index bedfa9c472..202cbd0f4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -243,28 +243,11 @@ case class MapPartitionsInRWithArrowExec( // binary in a batch due to the limitation of R API. See also ARROW-4512. val columnarBatchIter = runner.compute(batchIter, -1) val outputProject = UnsafeProjection.create(output, output) - new Iterator[InternalRow] { - - private var currentIter = if (columnarBatchIter.hasNext) { - val batch = columnarBatchIter.next() - val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType()) - assert(outputTypes == actualDataTypes, "Invalid schema from dapply(): " + - s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}") - batch.rowIterator.asScala - } else { - Iterator.empty - } - - override def hasNext: Boolean = currentIter.hasNext || { - if (columnarBatchIter.hasNext) { - currentIter = columnarBatchIter.next().rowIterator.asScala - hasNext - } else { - false - } - } - - override def next(): InternalRow = currentIter.next() + columnarBatchIter.flatMap { batch => + val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType()) + assert(outputTypes == actualDataTypes, "Invalid schema from dapply(): " + + s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}") + batch.rowIterator.asScala }.map(outputProject) } }