[SPARK-25829][SQL][FOLLOWUP] Refactor MapConcat in order to check properly the limit size
## What changes were proposed in this pull request?

The PR starts from the [comment](https://github.com/apache/spark/pull/23124#discussion_r236112390) in the main one and it aims at:
- simplifying the code for `MapConcat`;
- being more precise in checking the limit size.

## How was this patch tested?

Existing tests.

Closes #23217 from mgaido91/SPARK-25829_followup.

Authored-by: Marco Gaido <marcogaido91@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
180f969c97
commit
7143e9d722
|
@ -554,13 +554,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
|
|||
return null
|
||||
}
|
||||
|
||||
val numElements = maps.foldLeft(0L)((sum, ad) => sum + ad.numElements())
|
||||
if (numElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
|
||||
throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements " +
|
||||
s"elements due to exceeding the map size limit " +
|
||||
s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
|
||||
}
|
||||
|
||||
for (map <- maps) {
|
||||
mapBuilder.putAll(map.keyArray(), map.valueArray())
|
||||
}
|
||||
|
@ -569,8 +562,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
|
|||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val mapCodes = children.map(_.genCode(ctx))
|
||||
val keyType = dataType.keyType
|
||||
val valueType = dataType.valueType
|
||||
val argsName = ctx.freshName("args")
|
||||
val hasNullName = ctx.freshName("hasNull")
|
||||
val builderTerm = ctx.addReferenceObj("mapBuilder", mapBuilder)
|
||||
|
@ -610,41 +601,12 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
|
|||
)
|
||||
|
||||
val idxName = ctx.freshName("idx")
|
||||
val numElementsName = ctx.freshName("numElems")
|
||||
val finKeysName = ctx.freshName("finalKeys")
|
||||
val finValsName = ctx.freshName("finalValues")
|
||||
|
||||
val keyConcat = genCodeForArrays(ctx, keyType, false)
|
||||
|
||||
val valueConcat =
|
||||
if (valueType.sameType(keyType) &&
|
||||
!(CodeGenerator.isPrimitiveType(valueType) && dataType.valueContainsNull)) {
|
||||
keyConcat
|
||||
} else {
|
||||
genCodeForArrays(ctx, valueType, dataType.valueContainsNull)
|
||||
}
|
||||
|
||||
val keyArgsName = ctx.freshName("keyArgs")
|
||||
val valArgsName = ctx.freshName("valArgs")
|
||||
|
||||
val mapMerge =
|
||||
s"""
|
||||
|ArrayData[] $keyArgsName = new ArrayData[${mapCodes.size}];
|
||||
|ArrayData[] $valArgsName = new ArrayData[${mapCodes.size}];
|
||||
|long $numElementsName = 0;
|
||||
|for (int $idxName = 0; $idxName < $argsName.length; $idxName++) {
|
||||
| $keyArgsName[$idxName] = $argsName[$idxName].keyArray();
|
||||
| $valArgsName[$idxName] = $argsName[$idxName].valueArray();
|
||||
| $numElementsName += $argsName[$idxName].numElements();
|
||||
| $builderTerm.putAll($argsName[$idxName].keyArray(), $argsName[$idxName].valueArray());
|
||||
|}
|
||||
|if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
|
||||
| throw new RuntimeException("Unsuccessful attempt to concat maps with " +
|
||||
| $numElementsName + " elements due to exceeding the map size limit " +
|
||||
| "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
|
||||
|}
|
||||
|ArrayData $finKeysName = $keyConcat($keyArgsName, (int) $numElementsName);
|
||||
|ArrayData $finValsName = $valueConcat($valArgsName, (int) $numElementsName);
|
||||
|${ev.value} = $builderTerm.from($finKeysName, $finValsName);
|
||||
|${ev.value} = $builderTerm.build();
|
||||
""".stripMargin
|
||||
|
||||
ev.copy(
|
||||
|
@ -660,41 +622,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
|
|||
""".stripMargin)
|
||||
}
|
||||
|
||||
/**
 * Builds (and registers on `ctx`) a generated Java helper function that
 * concatenates an array of `ArrayData` values of the given element type into a
 * single `ArrayData` of `numElements` entries.
 *
 * @param ctx codegen context used for fresh names and function registration
 * @param elementType element type of the arrays being concatenated
 * @param checkForNull whether the generated per-element copy must handle nulls
 * @return the name of the generated helper function
 */
private def genCodeForArrays(
    ctx: CodegenContext,
    elementType: DataType,
    checkForNull: Boolean): String = {
  // Fresh identifiers for the generated Java code (freshName hints kept stable).
  val writePos = ctx.freshName("counter")
  val resultArray = ctx.freshName("arrayData")
  val inputArrays = ctx.freshName("args")
  val elemCount = ctx.freshName("numElements")
  val outerIdx = ctx.freshName("y")
  val innerIdx = ctx.freshName("z")

  // Allocation of the result array, and the statement copying one element from
  // the current input array into the result at `writePos`.
  val allocation = CodeGenerator.createArrayData(
    resultArray, elementType, elemCount, s" $prettyName failed.")
  val copyElement = CodeGenerator.createArrayAssignment(
    resultArray, elementType, s"$inputArrays[$outerIdx]", writePos, innerIdx, checkForNull)

  val funcName = ctx.freshName("concat")
  val funcBody =
    s"""
       |private ArrayData $funcName(ArrayData[] $inputArrays, int $elemCount) {
       |  $allocation
       |  int $writePos = 0;
       |  for (int $outerIdx = 0; $outerIdx < ${children.length}; $outerIdx++) {
       |    for (int $innerIdx = 0; $innerIdx < $inputArrays[$outerIdx].numElements(); $innerIdx++) {
       |      $copyElement
       |      $writePos++;
       |    }
       |  }
       |  return $resultArray;
       |}
     """.stripMargin

  ctx.addNewFunction(funcName, funcBody)
}
|
||||
|
||||
// SQL-facing name of this expression; also interpolated into the generated
// allocation-failure message in genCodeForArrays above.
override def prettyName: String = "map_concat"
|
||||
}
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ import scala.collection.mutable
|
|||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.array.ByteArrayMethods
|
||||
|
||||
/**
|
||||
* A builder of [[ArrayBasedMapData]], which fails if a null map key is detected, and removes
|
||||
|
@ -54,6 +55,10 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
|
|||
|
||||
val index = keyToIndex.getOrDefault(key, -1)
|
||||
if (index == -1) {
|
||||
if (size >= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
|
||||
throw new RuntimeException(s"Unsuccessful attempt to build maps with $size elements " +
|
||||
s"due to exceeding the map size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
|
||||
}
|
||||
keyToIndex.put(key, values.length)
|
||||
keys.append(key)
|
||||
values.append(value)
|
||||
|
@ -117,4 +122,9 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
|
|||
build()
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Returns the number of entries accumulated so far, i.e. the size of the map
 * this builder would currently produce.
 */
def size: Int = keys.length
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue