[SPARK-25178][SQL] Directly ship the StructType objects of the keySchema / valueSchema for xxxHashMapGenerator
## What changes were proposed in this pull request? This PR generates code that refers to a `StructType` generated in the Scala code instead of generating the `StructType` in Java code. The original code has two issues. 1. It uses the field name directly, such as `key.name` 1. It does not support complicated schemas (e.g. nested DataType) At first, [the JIRA entry](https://issues.apache.org/jira/browse/SPARK-25178) proposed to change the generated field name of the keySchema / valueSchema to a dummy name in `RowBasedHashMapGenerator` and `VectorizedHashMapGenerator.scala`. This proposal can address issue 1. Ueshin suggested an approach to refer to a `StructType` generated in the Scala code using `ctx.addReferenceObj()`. This approach can address issues 1 and 2. Finally, this PR uses this approach. ## How was this patch tested? Existing UTs Closes #22187 from kiszk/SPARK-25178. Authored-by: Kazuaki Ishizaki <ishizaki@jp.ibm.com> Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
This commit is contained in:
parent
9b6baeb7b9
commit
ab33028957
|
@ -44,31 +44,8 @@ class RowBasedHashMapGenerator(
|
|||
groupingKeySchema, bufferSchema) {
|
||||
|
||||
override protected def initializeAggregateHashMap(): String = {
|
||||
val generatedKeySchema: String =
|
||||
s"new org.apache.spark.sql.types.StructType()" +
|
||||
groupingKeySchema.map { key =>
|
||||
val keyName = ctx.addReferenceObj("keyName", key.name)
|
||||
key.dataType match {
|
||||
case d: DecimalType =>
|
||||
s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType(
|
||||
|${d.precision}, ${d.scale}))""".stripMargin
|
||||
case _ =>
|
||||
s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})"""
|
||||
}
|
||||
}.mkString("\n").concat(";")
|
||||
|
||||
val generatedValueSchema: String =
|
||||
s"new org.apache.spark.sql.types.StructType()" +
|
||||
bufferSchema.map { key =>
|
||||
val keyName = ctx.addReferenceObj("keyName", key.name)
|
||||
key.dataType match {
|
||||
case d: DecimalType =>
|
||||
s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType(
|
||||
|${d.precision}, ${d.scale}))""".stripMargin
|
||||
case _ =>
|
||||
s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})"""
|
||||
}
|
||||
}.mkString("\n").concat(";")
|
||||
val keySchema = ctx.addReferenceObj("keySchemaTerm", groupingKeySchema)
|
||||
val valueSchema = ctx.addReferenceObj("valueSchemaTerm", bufferSchema)
|
||||
|
||||
s"""
|
||||
| private org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch batch;
|
||||
|
@ -78,8 +55,6 @@ class RowBasedHashMapGenerator(
|
|||
| private int numBuckets = (int) (capacity / loadFactor);
|
||||
| private int maxSteps = 2;
|
||||
| private int numRows = 0;
|
||||
| private org.apache.spark.sql.types.StructType keySchema = $generatedKeySchema
|
||||
| private org.apache.spark.sql.types.StructType valueSchema = $generatedValueSchema
|
||||
| private Object emptyVBase;
|
||||
| private long emptyVOff;
|
||||
| private int emptyVLen;
|
||||
|
@ -90,9 +65,9 @@ class RowBasedHashMapGenerator(
|
|||
| org.apache.spark.memory.TaskMemoryManager taskMemoryManager,
|
||||
| InternalRow emptyAggregationBuffer) {
|
||||
| batch = org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch
|
||||
| .allocate(keySchema, valueSchema, taskMemoryManager, capacity);
|
||||
| .allocate($keySchema, $valueSchema, taskMemoryManager, capacity);
|
||||
|
|
||||
| final UnsafeProjection valueProjection = UnsafeProjection.create(valueSchema);
|
||||
| final UnsafeProjection valueProjection = UnsafeProjection.create($valueSchema);
|
||||
| final byte[] emptyBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
|
||||
|
|
||||
| emptyVBase = emptyBuffer;
|
||||
|
|
|
@ -52,31 +52,9 @@ class VectorizedHashMapGenerator(
|
|||
groupingKeySchema, bufferSchema) {
|
||||
|
||||
override protected def initializeAggregateHashMap(): String = {
|
||||
val generatedSchema: String =
|
||||
s"new org.apache.spark.sql.types.StructType()" +
|
||||
(groupingKeySchema ++ bufferSchema).map { key =>
|
||||
val keyName = ctx.addReferenceObj("keyName", key.name)
|
||||
key.dataType match {
|
||||
case d: DecimalType =>
|
||||
s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType(
|
||||
|${d.precision}, ${d.scale}))""".stripMargin
|
||||
case _ =>
|
||||
s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})"""
|
||||
}
|
||||
}.mkString("\n").concat(";")
|
||||
|
||||
val generatedAggBufferSchema: String =
|
||||
s"new org.apache.spark.sql.types.StructType()" +
|
||||
bufferSchema.map { key =>
|
||||
val keyName = ctx.addReferenceObj("keyName", key.name)
|
||||
key.dataType match {
|
||||
case d: DecimalType =>
|
||||
s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType(
|
||||
|${d.precision}, ${d.scale}))""".stripMargin
|
||||
case _ =>
|
||||
s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})"""
|
||||
}
|
||||
}.mkString("\n").concat(";")
|
||||
val schemaStructType = new StructType((groupingKeySchema ++ bufferSchema).toArray)
|
||||
val schema = ctx.addReferenceObj("schemaTerm", schemaStructType)
|
||||
val aggBufferSchemaFieldsLength = bufferSchema.fields.length
|
||||
|
||||
s"""
|
||||
| private ${classOf[OnHeapColumnVector].getName}[] vectors;
|
||||
|
@ -88,18 +66,15 @@ class VectorizedHashMapGenerator(
|
|||
| private int numBuckets = (int) (capacity / loadFactor);
|
||||
| private int maxSteps = 2;
|
||||
| private int numRows = 0;
|
||||
| private org.apache.spark.sql.types.StructType schema = $generatedSchema
|
||||
| private org.apache.spark.sql.types.StructType aggregateBufferSchema =
|
||||
| $generatedAggBufferSchema
|
||||
|
|
||||
| public $generatedClassName() {
|
||||
| vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, schema);
|
||||
| vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, $schema);
|
||||
| batch = new ${classOf[ColumnarBatch].getName}(vectors);
|
||||
|
|
||||
| // Generates a projection to return the aggregate buffer only.
|
||||
| ${classOf[OnHeapColumnVector].getName}[] aggBufferVectors =
|
||||
| new ${classOf[OnHeapColumnVector].getName}[aggregateBufferSchema.fields().length];
|
||||
| for (int i = 0; i < aggregateBufferSchema.fields().length; i++) {
|
||||
| new ${classOf[OnHeapColumnVector].getName}[$aggBufferSchemaFieldsLength];
|
||||
| for (int i = 0; i < $aggBufferSchemaFieldsLength; i++) {
|
||||
| aggBufferVectors[i] = vectors[i + ${groupingKeys.length}];
|
||||
| }
|
||||
| aggBufferRow = new ${classOf[MutableColumnarRow].getName}(aggBufferVectors);
|
||||
|
|
Loading…
Reference in a new issue