diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
index d5508275c4..ca59bb145f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
@@ -44,31 +44,8 @@ class RowBasedHashMapGenerator(
     groupingKeySchema, bufferSchema) {

   override protected def initializeAggregateHashMap(): String = {
-    val generatedKeySchema: String =
-      s"new org.apache.spark.sql.types.StructType()" +
-        groupingKeySchema.map { key =>
-          val keyName = ctx.addReferenceObj("keyName", key.name)
-          key.dataType match {
-            case d: DecimalType =>
-              s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType(
-                |${d.precision}, ${d.scale}))""".stripMargin
-            case _ =>
-              s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})"""
-          }
-        }.mkString("\n").concat(";")
-
-    val generatedValueSchema: String =
-      s"new org.apache.spark.sql.types.StructType()" +
-        bufferSchema.map { key =>
-          val keyName = ctx.addReferenceObj("keyName", key.name)
-          key.dataType match {
-            case d: DecimalType =>
-              s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType(
-                |${d.precision}, ${d.scale}))""".stripMargin
-            case _ =>
-              s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})"""
-          }
-        }.mkString("\n").concat(";")
+    val keySchema = ctx.addReferenceObj("keySchemaTerm", groupingKeySchema)
+    val valueSchema = ctx.addReferenceObj("valueSchemaTerm", bufferSchema)

     s"""
       | private org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch batch;
@@ -78,8 +55,6 @@ class RowBasedHashMapGenerator(
       | private int numBuckets = (int) (capacity / loadFactor);
       | private int maxSteps = 2;
       | private int numRows = 0;
-      | private org.apache.spark.sql.types.StructType keySchema = $generatedKeySchema
-      | private org.apache.spark.sql.types.StructType valueSchema = $generatedValueSchema
       | private Object emptyVBase;
       | private long emptyVOff;
       | private int emptyVLen;
@@ -90,9 +65,9 @@ class RowBasedHashMapGenerator(
       |   org.apache.spark.memory.TaskMemoryManager taskMemoryManager,
       |   InternalRow emptyAggregationBuffer) {
       |   batch = org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch
-      |     .allocate(keySchema, valueSchema, taskMemoryManager, capacity);
+      |     .allocate($keySchema, $valueSchema, taskMemoryManager, capacity);
       |
-      |   final UnsafeProjection valueProjection = UnsafeProjection.create(valueSchema);
+      |   final UnsafeProjection valueProjection = UnsafeProjection.create($valueSchema);
       |   final byte[] emptyBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
       |
       |   emptyVBase = emptyBuffer;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
index 7b3580cecc..95ebefed08 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
@@ -52,31 +52,9 @@ class VectorizedHashMapGenerator(
     groupingKeySchema, bufferSchema) {

   override protected def initializeAggregateHashMap(): String = {
-    val generatedSchema: String =
-      s"new org.apache.spark.sql.types.StructType()" +
-        (groupingKeySchema ++ bufferSchema).map { key =>
-          val keyName = ctx.addReferenceObj("keyName", key.name)
-          key.dataType match {
-            case d: DecimalType =>
-              s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType(
-                |${d.precision}, ${d.scale}))""".stripMargin
-            case _ =>
-              s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})"""
-          }
-        }.mkString("\n").concat(";")
-
-    val generatedAggBufferSchema: String =
-      s"new org.apache.spark.sql.types.StructType()" +
-        bufferSchema.map { key =>
-          val keyName = ctx.addReferenceObj("keyName", key.name)
-          key.dataType match {
-            case d: DecimalType =>
-              s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType(
-                |${d.precision}, ${d.scale}))""".stripMargin
-            case _ =>
-              s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})"""
-          }
-        }.mkString("\n").concat(";")
+    val schemaStructType = new StructType((groupingKeySchema ++ bufferSchema).toArray)
+    val schema = ctx.addReferenceObj("schemaTerm", schemaStructType)
+    val aggBufferSchemaFieldsLength = bufferSchema.fields.length

     s"""
       | private ${classOf[OnHeapColumnVector].getName}[] vectors;
@@ -88,18 +66,15 @@ class VectorizedHashMapGenerator(
       | private int numBuckets = (int) (capacity / loadFactor);
       | private int maxSteps = 2;
       | private int numRows = 0;
-      | private org.apache.spark.sql.types.StructType schema = $generatedSchema
-      | private org.apache.spark.sql.types.StructType aggregateBufferSchema =
-      |   $generatedAggBufferSchema
       |
       | public $generatedClassName() {
-      |   vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, schema);
+      |   vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, $schema);
       |   batch = new ${classOf[ColumnarBatch].getName}(vectors);
       |
       |   // Generates a projection to return the aggregate buffer only.
       |   ${classOf[OnHeapColumnVector].getName}[] aggBufferVectors =
-      |     new ${classOf[OnHeapColumnVector].getName}[aggregateBufferSchema.fields().length];
-      |   for (int i = 0; i < aggregateBufferSchema.fields().length; i++) {
+      |     new ${classOf[OnHeapColumnVector].getName}[$aggBufferSchemaFieldsLength];
+      |   for (int i = 0; i < $aggBufferSchemaFieldsLength; i++) {
       |     aggBufferVectors[i] = vectors[i + ${groupingKeys.length}];
       |   }
       |   aggBufferRow = new ${classOf[MutableColumnarRow].getName}(aggBufferVectors);
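For context on the pattern both files switch to: `CodegenContext.addReferenceObj` stores a Scala-side object in the generated class's `references` array and returns a Java expression that reads it back, so the generated code no longer has to rebuild the `StructType` field by field. The sketch below is illustrative only; the object name `AddReferenceObjSketch` and the two-column schema are hypothetical, and only the `addReferenceObj` call mirrors this patch.

```scala
// Illustrative sketch (not part of this patch) of the addReferenceObj
// pattern: register a Scala-side object once, then splice the returned
// accessor into generated Java source instead of emitting code that
// reconstructs the object. Names and schema here are hypothetical.
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types.{IntegerType, LongType, StructType}

object AddReferenceObjSketch {
  def main(args: Array[String]): Unit = {
    val ctx = new CodegenContext
    val schema = new StructType().add("k", IntegerType).add("v", LongType)

    // Returns a Java expression (a String) that dereferences the registered
    // object, e.g. "((org.apache.spark.sql.types.StructType) references[0])".
    val schemaTerm = ctx.addReferenceObj("schemaTerm", schema)

    // The accessor is spliced straight into the generated source, so the
    // generated class receives the schema at construction time rather than
    // rebuilding it.
    println(s"UnsafeProjection.create($schemaTerm);")
  }
}
```

Besides shrinking the generated source, passing the schemas by reference drops the old reliance on `org.apache.spark.sql.types.DataTypes.${key.dataType}` rendering to a valid static field access, which only holds for the simple types `DataTypes` exposes as constants (hence the previous `DecimalType` special case).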