[SPARK-22385][SQL] MapObjects should not access list element by index
## What changes were proposed in this pull request? This issue was discovered and investigated by Ohad Raviv and Sean Owen in https://issues.apache.org/jira/browse/SPARK-21657. The input data of `MapObjects` may be a `List` which has O(n) complexity for accessing by index. When converting input data to catalyst array, `MapObjects` gets element by index in each loop, and results to bad performance. This PR fixes this issue by accessing elements via Iterator. ## How was this patch tested? using the test script in https://issues.apache.org/jira/browse/SPARK-21657 ``` val BASE = 100000000 val N = 100000 val df = sc.parallelize(List(("1234567890", (BASE to (BASE+N)).map(x => (x.toString, (x+1).toString, (x+2).toString, (x+3).toString)).toList ))).toDF("c1", "c_arr") spark.time(df.queryExecution.toRdd.foreach(_ => ())) ``` We can see 50x speed up. Author: Wenchen Fan <wenchen@databricks.com> Closes #19603 from cloud-fan/map-objects.
This commit is contained in:
parent
9f5c77ae32
commit
9f02d7dc53
|
@ -591,18 +591,43 @@ case class MapObjects private(
|
|||
case _ => inputData.dataType
|
||||
}
|
||||
|
||||
val (getLength, getLoopVar) = inputDataType match {
|
||||
// `MapObjects` generates a while loop to traverse the elements of the input collection. We
|
||||
// need to take care of Seq and List because they may have O(n) complexity for indexed accessing
|
||||
// like `list.get(1)`. Here we use Iterator to traverse Seq and List.
|
||||
val (getLength, prepareLoop, getLoopVar) = inputDataType match {
|
||||
case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
|
||||
s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)"
|
||||
val it = ctx.freshName("it")
|
||||
(
|
||||
s"${genInputData.value}.size()",
|
||||
s"scala.collection.Iterator $it = ${genInputData.value}.toIterator();",
|
||||
s"$it.next()"
|
||||
)
|
||||
case ObjectType(cls) if cls.isArray =>
|
||||
s"${genInputData.value}.length" -> s"${genInputData.value}[$loopIndex]"
|
||||
(
|
||||
s"${genInputData.value}.length",
|
||||
"",
|
||||
s"${genInputData.value}[$loopIndex]"
|
||||
)
|
||||
case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
|
||||
s"${genInputData.value}.size()" -> s"${genInputData.value}.get($loopIndex)"
|
||||
val it = ctx.freshName("it")
|
||||
(
|
||||
s"${genInputData.value}.size()",
|
||||
s"java.util.Iterator $it = ${genInputData.value}.iterator();",
|
||||
s"$it.next()"
|
||||
)
|
||||
case ArrayType(et, _) =>
|
||||
s"${genInputData.value}.numElements()" -> ctx.getValue(genInputData.value, et, loopIndex)
|
||||
(
|
||||
s"${genInputData.value}.numElements()",
|
||||
"",
|
||||
ctx.getValue(genInputData.value, et, loopIndex)
|
||||
)
|
||||
case ObjectType(cls) if cls == classOf[Object] =>
|
||||
s"$seq == null ? $array.length : $seq.size()" ->
|
||||
s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)"
|
||||
val it = ctx.freshName("it")
|
||||
(
|
||||
s"$seq == null ? $array.length : $seq.size()",
|
||||
s"scala.collection.Iterator $it = $seq == null ? null : $seq.toIterator();",
|
||||
s"$it == null ? $array[$loopIndex] : $it.next()"
|
||||
)
|
||||
}
|
||||
|
||||
// Make a copy of the data if it's unsafe-backed
|
||||
|
@ -676,6 +701,7 @@ case class MapObjects private(
|
|||
$initCollection
|
||||
|
||||
int $loopIndex = 0;
|
||||
$prepareLoop
|
||||
while ($loopIndex < $dataLength) {
|
||||
$loopValue = ($elementJavaType) ($getLoopVar);
|
||||
$loopNullCheck
|
||||
|
|
Loading…
Reference in a new issue