[SPARK-25829][SQL][FOLLOWUP] Refactor MapConcat in order to check properly the limit size

## What changes were proposed in this pull request?

The PR starts from the [comment](https://github.com/apache/spark/pull/23124#discussion_r236112390) in the main one and it aims at:
 - simplifying the code for `MapConcat`;
 - being more precise in checking the size limit.

## How was this patch tested?

existing tests

Closes #23217 from mgaido91/SPARK-25829_followup.

Authored-by: Marco Gaido <marcogaido91@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
Marco Gaido 2018-12-05 09:12:24 +08:00 committed by Wenchen Fan
parent 180f969c97
commit 7143e9d722
2 changed files with 12 additions and 75 deletions

View file

@@ -554,13 +554,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
return null
}
val numElements = maps.foldLeft(0L)((sum, ad) => sum + ad.numElements())
if (numElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements " +
s"elements due to exceeding the map size limit " +
s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
}
for (map <- maps) {
mapBuilder.putAll(map.keyArray(), map.valueArray())
}
@@ -569,8 +562,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val mapCodes = children.map(_.genCode(ctx))
val keyType = dataType.keyType
val valueType = dataType.valueType
val argsName = ctx.freshName("args")
val hasNullName = ctx.freshName("hasNull")
val builderTerm = ctx.addReferenceObj("mapBuilder", mapBuilder)
@@ -610,41 +601,12 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
)
val idxName = ctx.freshName("idx")
val numElementsName = ctx.freshName("numElems")
val finKeysName = ctx.freshName("finalKeys")
val finValsName = ctx.freshName("finalValues")
val keyConcat = genCodeForArrays(ctx, keyType, false)
val valueConcat =
if (valueType.sameType(keyType) &&
!(CodeGenerator.isPrimitiveType(valueType) && dataType.valueContainsNull)) {
keyConcat
} else {
genCodeForArrays(ctx, valueType, dataType.valueContainsNull)
}
val keyArgsName = ctx.freshName("keyArgs")
val valArgsName = ctx.freshName("valArgs")
val mapMerge =
s"""
|ArrayData[] $keyArgsName = new ArrayData[${mapCodes.size}];
|ArrayData[] $valArgsName = new ArrayData[${mapCodes.size}];
|long $numElementsName = 0;
|for (int $idxName = 0; $idxName < $argsName.length; $idxName++) {
| $keyArgsName[$idxName] = $argsName[$idxName].keyArray();
| $valArgsName[$idxName] = $argsName[$idxName].valueArray();
| $numElementsName += $argsName[$idxName].numElements();
| $builderTerm.putAll($argsName[$idxName].keyArray(), $argsName[$idxName].valueArray());
|}
|if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| throw new RuntimeException("Unsuccessful attempt to concat maps with " +
| $numElementsName + " elements due to exceeding the map size limit " +
| "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
|}
|ArrayData $finKeysName = $keyConcat($keyArgsName, (int) $numElementsName);
|ArrayData $finValsName = $valueConcat($valArgsName, (int) $numElementsName);
|${ev.value} = $builderTerm.from($finKeysName, $finValsName);
|${ev.value} = $builderTerm.build();
""".stripMargin
ev.copy(
@@ -660,41 +622,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
""".stripMargin)
}
// NOTE(review): this method is the code REMOVED by this commit (the diff `-` markers
// were stripped during extraction); it is preserved here verbatim for reference.
//
// Generates (and registers via ctx.addNewFunction) a Java helper that concatenates
// several ArrayData instances into a single freshly-allocated ArrayData, returning
// the generated helper's name so the caller can invoke it from other generated code.
// - elementType: element type of the arrays being concatenated.
// - checkForNull: whether the per-element copy must handle null entries
//   (forwarded to CodeGenerator.createArrayAssignment).
private def genCodeForArrays(
ctx: CodegenContext,
elementType: DataType,
checkForNull: Boolean): String = {
// Fresh, collision-free identifiers for the generated Java locals.
val counter = ctx.freshName("counter")
val arrayData = ctx.freshName("arrayData")
val argsName = ctx.freshName("args")
val numElemName = ctx.freshName("numElements")
val y = ctx.freshName("y")
val z = ctx.freshName("z")
// Allocates the destination array; the message suffix names this expression
// (prettyName, i.e. "map_concat") if allocation fails.
val allocation = CodeGenerator.createArrayData(
arrayData, elementType, numElemName, s" $prettyName failed.")
// Copies element z of the y-th input array into slot `counter` of the destination.
val assignment = CodeGenerator.createArrayAssignment(
arrayData, elementType, s"$argsName[$y]", counter, z, checkForNull)
val concat = ctx.freshName("concat")
// The generated helper: nested loops flatten all input arrays, in order,
// into the single destination array.
// NOTE(review): the outer loop bound is ${children.length}, the number of map
// arguments to map_concat — presumably always equal to $argsName.length here.
val concatDef =
s"""
|private ArrayData $concat(ArrayData[] $argsName, int $numElemName) {
| $allocation
| int $counter = 0;
| for (int $y = 0; $y < ${children.length}; $y++) {
| for (int $z = 0; $z < $argsName[$y].numElements(); $z++) {
| $assignment
| $counter++;
| }
| }
| return $arrayData;
|}
""".stripMargin
// Register the helper with the codegen context; returns the (possibly renamed)
// function reference to embed in calling code.
ctx.addNewFunction(concat, concatDef)
}
// User-facing name of this expression; also embedded in generated-code failure
// messages (see the createArrayData error suffix above in this file).
override def prettyName: String = "map_concat"
}

View file

@@ -21,6 +21,7 @@ import scala.collection.mutable
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.array.ByteArrayMethods
/**
* A builder of [[ArrayBasedMapData]], which fails if a null map key is detected, and removes
@@ -54,6 +55,10 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
val index = keyToIndex.getOrDefault(key, -1)
if (index == -1) {
if (size >= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw new RuntimeException(s"Unsuccessful attempt to build maps with $size elements " +
s"due to exceeding the map size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
}
keyToIndex.put(key, values.length)
keys.append(key)
values.append(value)
@@ -117,4 +122,9 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
build()
}
}
/**
* Returns the current size of the map which is going to be produced by the current builder.
*/
def size: Int = keys.size
}