[SPARK-25178][SQL] Directly ship the StructType objects of the keySchema / valueSchema for xxxHashMapGenerator

## What changes were proposed in this pull request?

This PR changes the generated code to refer to a `StructType` object constructed in Scala code, instead of emitting Java code that rebuilds the `StructType` from scratch.

The original code had two issues:
1. It embedded field names such as `key.name` directly into the generated Java source.
1. It could not handle complicated schemas (e.g. nested `DataType`s); see the sketch below.
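
To make issue 2 concrete: the old template string-interpolated the `toString` of each field's `DataType` into the generated Java source, which only happens to compile for simple types. A minimal sketch (the `ids` field is a made-up example):

```scala
import org.apache.spark.sql.types._

// The old template emitted, per field:
//   s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})"""
val key = StructField("ids", ArrayType(IntegerType))
println(s"org.apache.spark.sql.types.DataTypes.${key.dataType}")
// Prints: org.apache.spark.sql.types.DataTypes.ArrayType(IntegerType,true)
// This is not valid Java: DataTypes has no ArrayType member, and the type's
// toString is not a Java expression, so compilation of the generated code fails.
```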

At first, [the JIRA entry](https://issues.apache.org/jira/browse/SPARK-25178) proposed changing the generated field names of the keySchema / valueSchema to dummy names in `RowBasedHashMapGenerator` and `VectorizedHashMapGenerator`. That proposal addresses only issue 1.

Ueshin then suggested referring to a `StructType` object constructed in Scala code via `ctx.addReferenceObj()`, which addresses both issues 1 and 2. This PR adopts that approach.
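
For illustration, a minimal sketch of the adopted approach (the schema here is a made-up example; `addReferenceObj` is the existing `CodegenContext` helper that stores an object in the generated class's `references` array and returns the Java expression that reads it back):

```scala
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types.{IntegerType, StructType}

val ctx = new CodegenContext
val groupingKeySchema = new StructType().add("k", IntegerType)

// Ship the Scala StructType object itself; the returned string is a
// Java expression along the lines of
//   ((org.apache.spark.sql.types.StructType) references[0] /* keySchemaTerm */)
val keySchemaTerm = ctx.addReferenceObj("keySchemaTerm", groupingKeySchema)

// The term can be spliced into generated Java source as-is, so any field
// name or nested DataType works without string-building a schema in Java.
println(s"UnsafeProjection.create($keySchemaTerm);")
```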

## How was this patch tested?

Existing UTs

Closes #22187 from kiszk/SPARK-25178.

Authored-by: Kazuaki Ishizaki <ishizaki@jp.ibm.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>

RowBasedHashMapGenerator.scala

@@ -44,31 +44,8 @@ class RowBasedHashMapGenerator(
     groupingKeySchema, bufferSchema) {
   override protected def initializeAggregateHashMap(): String = {
-    val generatedKeySchema: String =
-      s"new org.apache.spark.sql.types.StructType()" +
-        groupingKeySchema.map { key =>
-          val keyName = ctx.addReferenceObj("keyName", key.name)
-          key.dataType match {
-            case d: DecimalType =>
-              s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType(
-                  |${d.precision}, ${d.scale}))""".stripMargin
-            case _ =>
-              s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})"""
-          }
-        }.mkString("\n").concat(";")
-    val generatedValueSchema: String =
-      s"new org.apache.spark.sql.types.StructType()" +
-        bufferSchema.map { key =>
-          val keyName = ctx.addReferenceObj("keyName", key.name)
-          key.dataType match {
-            case d: DecimalType =>
-              s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType(
-                  |${d.precision}, ${d.scale}))""".stripMargin
-            case _ =>
-              s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})"""
-          }
-        }.mkString("\n").concat(";")
+    val keySchema = ctx.addReferenceObj("keySchemaTerm", groupingKeySchema)
+    val valueSchema = ctx.addReferenceObj("valueSchemaTerm", bufferSchema)
     s"""
       | private org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch batch;
@@ -78,8 +55,6 @@ class RowBasedHashMapGenerator(
       | private int numBuckets = (int) (capacity / loadFactor);
       | private int maxSteps = 2;
       | private int numRows = 0;
-      | private org.apache.spark.sql.types.StructType keySchema = $generatedKeySchema
-      | private org.apache.spark.sql.types.StructType valueSchema = $generatedValueSchema
       | private Object emptyVBase;
       | private long emptyVOff;
       | private int emptyVLen;
@@ -90,9 +65,9 @@ class RowBasedHashMapGenerator(
       |     org.apache.spark.memory.TaskMemoryManager taskMemoryManager,
       |     InternalRow emptyAggregationBuffer) {
       |   batch = org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch
-      |     .allocate(keySchema, valueSchema, taskMemoryManager, capacity);
+      |     .allocate($keySchema, $valueSchema, taskMemoryManager, capacity);
       |
-      |   final UnsafeProjection valueProjection = UnsafeProjection.create(valueSchema);
+      |   final UnsafeProjection valueProjection = UnsafeProjection.create($valueSchema);
       |   final byte[] emptyBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
       |
       |   emptyVBase = emptyBuffer;

VectorizedHashMapGenerator.scala

@@ -52,31 +52,9 @@ class VectorizedHashMapGenerator(
     groupingKeySchema, bufferSchema) {
   override protected def initializeAggregateHashMap(): String = {
-    val generatedSchema: String =
-      s"new org.apache.spark.sql.types.StructType()" +
-        (groupingKeySchema ++ bufferSchema).map { key =>
-          val keyName = ctx.addReferenceObj("keyName", key.name)
-          key.dataType match {
-            case d: DecimalType =>
-              s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType(
-                  |${d.precision}, ${d.scale}))""".stripMargin
-            case _ =>
-              s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})"""
-          }
-        }.mkString("\n").concat(";")
-    val generatedAggBufferSchema: String =
-      s"new org.apache.spark.sql.types.StructType()" +
-        bufferSchema.map { key =>
-          val keyName = ctx.addReferenceObj("keyName", key.name)
-          key.dataType match {
-            case d: DecimalType =>
-              s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType(
-                  |${d.precision}, ${d.scale}))""".stripMargin
-            case _ =>
-              s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})"""
-          }
-        }.mkString("\n").concat(";")
+    val schemaStructType = new StructType((groupingKeySchema ++ bufferSchema).toArray)
+    val schema = ctx.addReferenceObj("schemaTerm", schemaStructType)
+    val aggBufferSchemaFieldsLength = bufferSchema.fields.length
     s"""
       | private ${classOf[OnHeapColumnVector].getName}[] vectors;
@@ -88,18 +66,15 @@ class VectorizedHashMapGenerator(
       | private int numBuckets = (int) (capacity / loadFactor);
       | private int maxSteps = 2;
       | private int numRows = 0;
-      | private org.apache.spark.sql.types.StructType schema = $generatedSchema
-      | private org.apache.spark.sql.types.StructType aggregateBufferSchema =
-      |   $generatedAggBufferSchema
       |
       | public $generatedClassName() {
-      |   vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, schema);
+      |   vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, $schema);
       |   batch = new ${classOf[ColumnarBatch].getName}(vectors);
       |
       |   // Generates a projection to return the aggregate buffer only.
       |   ${classOf[OnHeapColumnVector].getName}[] aggBufferVectors =
-      |     new ${classOf[OnHeapColumnVector].getName}[aggregateBufferSchema.fields().length];
-      |   for (int i = 0; i < aggregateBufferSchema.fields().length; i++) {
+      |     new ${classOf[OnHeapColumnVector].getName}[$aggBufferSchemaFieldsLength];
+      |   for (int i = 0; i < $aggBufferSchemaFieldsLength; i++) {
       |     aggBufferVectors[i] = vectors[i + ${groupingKeys.length}];
       |   }
       |   aggBufferRow = new ${classOf[MutableColumnarRow].getName}(aggBufferVectors);