[SPARK-22652][SQL] remove set methods in ColumnarRow
## What changes were proposed in this pull request?

As a step to make `ColumnVector` public, the `ColumnarRow` returned by `ColumnVector#getStruct` should be immutable. However, we do need the mutability of `ColumnarRow` for the fast vectorized hashmap in hash aggregate. To solve this, this PR introduces a `MutableColumnarRow` for that use case.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #19847 from cloud-fan/mutable-row.
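At a glance, the mutable API moves to the new class: `MutableColumnarRow` exposes a public `rowId` cursor over a set of writable vectors. Below is a minimal, illustrative sketch of that pattern; it assumes the 2.3-era `org.apache.spark.sql.execution.vectorized` classes touched by this diff, and `MutableRowSketch` is just a hypothetical harness:

```java
import org.apache.spark.sql.execution.vectorized.MutableColumnarRow;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class MutableRowSketch {
  public static void main(String[] args) {
    StructType schema = new StructType().add("count", DataTypes.LongType);
    // Writable on-heap vectors backing 16 rows.
    OnHeapColumnVector[] vectors = OnHeapColumnVector.allocateColumns(16, schema);

    // A single reusable row object: repositioning rowId moves the cursor, so the
    // vectorized hash map can hand out aggregate buffers without per-row allocation.
    MutableColumnarRow row = new MutableColumnarRow(vectors);
    for (int i = 0; i < 16; i++) {
      row.rowId = i;
      row.setLong(0, 0L); // zero-initialize each aggregate buffer slot
    }
  }
}
```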
This commit is contained in: parent 92cfbeeb5c, commit 444a2bbb67.
sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java

@@ -16,8 +16,6 @@
  */

 package org.apache.spark.sql.execution.vectorized;

-import java.math.BigDecimal;
-
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
 import org.apache.spark.sql.catalyst.util.MapData;

@@ -32,17 +30,10 @@ import org.apache.spark.unsafe.types.UTF8String;
 public final class ColumnarRow extends InternalRow {
   protected int rowId;
   private final ColumnVector[] columns;
-  private final WritableColumnVector[] writableColumns;

   // Ctor used if this is a struct.
   ColumnarRow(ColumnVector[] columns) {
     this.columns = columns;
-    this.writableColumns = new WritableColumnVector[this.columns.length];
-    for (int i = 0; i < this.columns.length; i++) {
-      if (this.columns[i] instanceof WritableColumnVector) {
-        this.writableColumns[i] = (WritableColumnVector) this.columns[i];
-      }
-    }
   }

   public ColumnVector[] columns() { return columns; }

@@ -205,97 +196,8 @@ public final class ColumnarRow extends InternalRow {
   }

   @Override
-  public void update(int ordinal, Object value) {
-    if (value == null) {
-      setNullAt(ordinal);
-    } else {
-      DataType dt = columns[ordinal].dataType();
-      if (dt instanceof BooleanType) {
-        setBoolean(ordinal, (boolean) value);
-      } else if (dt instanceof IntegerType) {
-        setInt(ordinal, (int) value);
-      } else if (dt instanceof ShortType) {
-        setShort(ordinal, (short) value);
-      } else if (dt instanceof LongType) {
-        setLong(ordinal, (long) value);
-      } else if (dt instanceof FloatType) {
-        setFloat(ordinal, (float) value);
-      } else if (dt instanceof DoubleType) {
-        setDouble(ordinal, (double) value);
-      } else if (dt instanceof DecimalType) {
-        DecimalType t = (DecimalType) dt;
-        setDecimal(ordinal, Decimal.apply((BigDecimal) value, t.precision(), t.scale()),
-          t.precision());
-      } else {
-        throw new UnsupportedOperationException("Datatype not supported " + dt);
-      }
-    }
-  }
+  public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); }

   @Override
-  public void setNullAt(int ordinal) {
-    getWritableColumn(ordinal).putNull(rowId);
-  }
-
-  @Override
-  public void setBoolean(int ordinal, boolean value) {
-    WritableColumnVector column = getWritableColumn(ordinal);
-    column.putNotNull(rowId);
-    column.putBoolean(rowId, value);
-  }
-
-  @Override
-  public void setByte(int ordinal, byte value) {
-    WritableColumnVector column = getWritableColumn(ordinal);
-    column.putNotNull(rowId);
-    column.putByte(rowId, value);
-  }
-
-  @Override
-  public void setShort(int ordinal, short value) {
-    WritableColumnVector column = getWritableColumn(ordinal);
-    column.putNotNull(rowId);
-    column.putShort(rowId, value);
-  }
-
-  @Override
-  public void setInt(int ordinal, int value) {
-    WritableColumnVector column = getWritableColumn(ordinal);
-    column.putNotNull(rowId);
-    column.putInt(rowId, value);
-  }
-
-  @Override
-  public void setLong(int ordinal, long value) {
-    WritableColumnVector column = getWritableColumn(ordinal);
-    column.putNotNull(rowId);
-    column.putLong(rowId, value);
-  }
-
-  @Override
-  public void setFloat(int ordinal, float value) {
-    WritableColumnVector column = getWritableColumn(ordinal);
-    column.putNotNull(rowId);
-    column.putFloat(rowId, value);
-  }
-
-  @Override
-  public void setDouble(int ordinal, double value) {
-    WritableColumnVector column = getWritableColumn(ordinal);
-    column.putNotNull(rowId);
-    column.putDouble(rowId, value);
-  }
-
-  @Override
-  public void setDecimal(int ordinal, Decimal value, int precision) {
-    WritableColumnVector column = getWritableColumn(ordinal);
-    column.putNotNull(rowId);
-    column.putDecimal(rowId, value, precision);
-  }
-
-  private WritableColumnVector getWritableColumn(int ordinal) {
-    WritableColumnVector column = writableColumns[ordinal];
-    assert (!column.isConstant);
-    return column;
-  }
+  public void setNullAt(int ordinal) { throw new UnsupportedOperationException(); }
 }
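The net effect on callers: rows handed out by a batch are now read-only. A small sketch of that contract, assuming the 2.3-era constructors visible in this diff (e.g. `ColumnarBatch(schema, vectors, capacity)`); `ImmutableRowSketch` is a hypothetical harness:

```java
import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
import org.apache.spark.sql.execution.vectorized.ColumnarRow;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class ImmutableRowSketch {
  public static void main(String[] args) {
    StructType schema = new StructType().add("v", DataTypes.IntegerType);
    OnHeapColumnVector[] vectors = OnHeapColumnVector.allocateColumns(4, schema);
    vectors[0].putInt(0, 1);

    ColumnarBatch batch = new ColumnarBatch(schema, vectors, 4);
    batch.setNumRows(1);
    ColumnarRow row = batch.getRow(0);

    try {
      row.update(0, 2); // after this patch: throws, ColumnarRow is read-only
    } catch (UnsupportedOperationException e) {
      System.out.println("ColumnarRow is immutable: " + e);
    }
  }
}
```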
sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java (new file)

@@ -0,0 +1,278 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.vectorized;

import java.math.BigDecimal;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.catalyst.util.MapData;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;

/**
 * A mutable version of {@link ColumnarRow}, which is used in the vectorized hash map for hash
 * aggregate.
 *
 * Note that this class intentionally has a lot of duplicated code with {@link ColumnarRow}, to
 * avoid java polymorphism overhead by keeping {@link ColumnarRow} and this class final classes.
 */
public final class MutableColumnarRow extends InternalRow {
  public int rowId;
  private final WritableColumnVector[] columns;

  public MutableColumnarRow(WritableColumnVector[] columns) {
    this.columns = columns;
  }

  @Override
  public int numFields() { return columns.length; }

  @Override
  public InternalRow copy() {
    GenericInternalRow row = new GenericInternalRow(columns.length);
    for (int i = 0; i < numFields(); i++) {
      if (isNullAt(i)) {
        row.setNullAt(i);
      } else {
        DataType dt = columns[i].dataType();
        if (dt instanceof BooleanType) {
          row.setBoolean(i, getBoolean(i));
        } else if (dt instanceof ByteType) {
          row.setByte(i, getByte(i));
        } else if (dt instanceof ShortType) {
          row.setShort(i, getShort(i));
        } else if (dt instanceof IntegerType) {
          row.setInt(i, getInt(i));
        } else if (dt instanceof LongType) {
          row.setLong(i, getLong(i));
        } else if (dt instanceof FloatType) {
          row.setFloat(i, getFloat(i));
        } else if (dt instanceof DoubleType) {
          row.setDouble(i, getDouble(i));
        } else if (dt instanceof StringType) {
          row.update(i, getUTF8String(i).copy());
        } else if (dt instanceof BinaryType) {
          row.update(i, getBinary(i));
        } else if (dt instanceof DecimalType) {
          DecimalType t = (DecimalType) dt;
          row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision());
        } else if (dt instanceof DateType) {
          row.setInt(i, getInt(i));
        } else if (dt instanceof TimestampType) {
          row.setLong(i, getLong(i));
        } else {
          throw new RuntimeException("Not implemented. " + dt);
        }
      }
    }
    return row;
  }

  @Override
  public boolean anyNull() {
    throw new UnsupportedOperationException();
  }

  @Override
  public boolean isNullAt(int ordinal) { return columns[ordinal].isNullAt(rowId); }

  @Override
  public boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); }

  @Override
  public byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); }

  @Override
  public short getShort(int ordinal) { return columns[ordinal].getShort(rowId); }

  @Override
  public int getInt(int ordinal) { return columns[ordinal].getInt(rowId); }

  @Override
  public long getLong(int ordinal) { return columns[ordinal].getLong(rowId); }

  @Override
  public float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); }

  @Override
  public double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); }

  @Override
  public Decimal getDecimal(int ordinal, int precision, int scale) {
    if (columns[ordinal].isNullAt(rowId)) return null;
    return columns[ordinal].getDecimal(rowId, precision, scale);
  }

  @Override
  public UTF8String getUTF8String(int ordinal) {
    if (columns[ordinal].isNullAt(rowId)) return null;
    return columns[ordinal].getUTF8String(rowId);
  }

  @Override
  public byte[] getBinary(int ordinal) {
    if (columns[ordinal].isNullAt(rowId)) return null;
    return columns[ordinal].getBinary(rowId);
  }

  @Override
  public CalendarInterval getInterval(int ordinal) {
    if (columns[ordinal].isNullAt(rowId)) return null;
    final int months = columns[ordinal].getChildColumn(0).getInt(rowId);
    final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId);
    return new CalendarInterval(months, microseconds);
  }

  @Override
  public ColumnarRow getStruct(int ordinal, int numFields) {
    if (columns[ordinal].isNullAt(rowId)) return null;
    return columns[ordinal].getStruct(rowId);
  }

  @Override
  public ColumnarArray getArray(int ordinal) {
    if (columns[ordinal].isNullAt(rowId)) return null;
    return columns[ordinal].getArray(rowId);
  }

  @Override
  public MapData getMap(int ordinal) {
    throw new UnsupportedOperationException();
  }

  @Override
  public Object get(int ordinal, DataType dataType) {
    if (dataType instanceof BooleanType) {
      return getBoolean(ordinal);
    } else if (dataType instanceof ByteType) {
      return getByte(ordinal);
    } else if (dataType instanceof ShortType) {
      return getShort(ordinal);
    } else if (dataType instanceof IntegerType) {
      return getInt(ordinal);
    } else if (dataType instanceof LongType) {
      return getLong(ordinal);
    } else if (dataType instanceof FloatType) {
      return getFloat(ordinal);
    } else if (dataType instanceof DoubleType) {
      return getDouble(ordinal);
    } else if (dataType instanceof StringType) {
      return getUTF8String(ordinal);
    } else if (dataType instanceof BinaryType) {
      return getBinary(ordinal);
    } else if (dataType instanceof DecimalType) {
      DecimalType t = (DecimalType) dataType;
      return getDecimal(ordinal, t.precision(), t.scale());
    } else if (dataType instanceof DateType) {
      return getInt(ordinal);
    } else if (dataType instanceof TimestampType) {
      return getLong(ordinal);
    } else if (dataType instanceof ArrayType) {
      return getArray(ordinal);
    } else if (dataType instanceof StructType) {
      return getStruct(ordinal, ((StructType) dataType).fields().length);
    } else if (dataType instanceof MapType) {
      return getMap(ordinal);
    } else {
      throw new UnsupportedOperationException("Datatype not supported " + dataType);
    }
  }

  @Override
  public void update(int ordinal, Object value) {
    if (value == null) {
      setNullAt(ordinal);
    } else {
      DataType dt = columns[ordinal].dataType();
      if (dt instanceof BooleanType) {
        setBoolean(ordinal, (boolean) value);
      } else if (dt instanceof IntegerType) {
        setInt(ordinal, (int) value);
      } else if (dt instanceof ShortType) {
        setShort(ordinal, (short) value);
      } else if (dt instanceof LongType) {
        setLong(ordinal, (long) value);
      } else if (dt instanceof FloatType) {
        setFloat(ordinal, (float) value);
      } else if (dt instanceof DoubleType) {
        setDouble(ordinal, (double) value);
      } else if (dt instanceof DecimalType) {
        DecimalType t = (DecimalType) dt;
        Decimal d = Decimal.apply((BigDecimal) value, t.precision(), t.scale());
        setDecimal(ordinal, d, t.precision());
      } else {
        throw new UnsupportedOperationException("Datatype not supported " + dt);
      }
    }
  }

  @Override
  public void setNullAt(int ordinal) {
    columns[ordinal].putNull(rowId);
  }

  @Override
  public void setBoolean(int ordinal, boolean value) {
    columns[ordinal].putNotNull(rowId);
    columns[ordinal].putBoolean(rowId, value);
  }

  @Override
  public void setByte(int ordinal, byte value) {
    columns[ordinal].putNotNull(rowId);
    columns[ordinal].putByte(rowId, value);
  }

  @Override
  public void setShort(int ordinal, short value) {
    columns[ordinal].putNotNull(rowId);
    columns[ordinal].putShort(rowId, value);
  }

  @Override
  public void setInt(int ordinal, int value) {
    columns[ordinal].putNotNull(rowId);
    columns[ordinal].putInt(rowId, value);
  }

  @Override
  public void setLong(int ordinal, long value) {
    columns[ordinal].putNotNull(rowId);
    columns[ordinal].putLong(rowId, value);
  }

  @Override
  public void setFloat(int ordinal, float value) {
    columns[ordinal].putNotNull(rowId);
    columns[ordinal].putFloat(rowId, value);
  }

  @Override
  public void setDouble(int ordinal, double value) {
    columns[ordinal].putNotNull(rowId);
    columns[ordinal].putDouble(rowId, value);
  }

  @Override
  public void setDecimal(int ordinal, Decimal value, int precision) {
    columns[ordinal].putNotNull(rowId);
    columns[ordinal].putDecimal(rowId, value, precision);
  }
}
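Because a single `MutableColumnarRow` is reused and merely repositioned through `rowId`, a caller that needs to retain a row must materialize it first, which is what `copy()` is for. A short sketch under the same assumptions as above (`CopySketch` is hypothetical):

```java
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.execution.vectorized.MutableColumnarRow;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class CopySketch {
  public static void main(String[] args) {
    StructType schema = new StructType().add("v", DataTypes.IntegerType);
    OnHeapColumnVector[] vectors = OnHeapColumnVector.allocateColumns(2, schema);
    vectors[0].putInt(0, 10);
    vectors[0].putInt(1, 20);

    MutableColumnarRow cursor = new MutableColumnarRow(vectors);
    cursor.rowId = 0;
    InternalRow snapshot = cursor.copy(); // materialized as a GenericInternalRow

    cursor.rowId = 1;                     // repositioning does not affect the copy
    System.out.println(snapshot.getInt(0) + " vs " + cursor.getInt(0)); // 10 vs 20
  }
}
```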
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.vectorized.MutableColumnarRow
 import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
 import org.apache.spark.unsafe.KVIterator
 import org.apache.spark.util.Utils

@@ -894,7 +895,7 @@ case class HashAggregateExec(
     ${
       if (isVectorizedHashMapEnabled) {
         s"""
-           | org.apache.spark.sql.execution.vectorized.ColumnarRow $fastRowBuffer = null;
+           | ${classOf[MutableColumnarRow].getName} $fastRowBuffer = null;
         """.stripMargin
       } else {
         s"""
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala

@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate

 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnarRow, MutableColumnarRow, OnHeapColumnVector}
 import org.apache.spark.sql.types._

 /**

@@ -76,10 +77,9 @@ class VectorizedHashMapGenerator(
     }.mkString("\n").concat(";")

     s"""
-       |  private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[] batchVectors;
-       |  private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[] bufferVectors;
-       |  private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch;
-       |  private org.apache.spark.sql.execution.vectorized.ColumnarBatch aggregateBufferBatch;
+       |  private ${classOf[OnHeapColumnVector].getName}[] vectors;
+       |  private ${classOf[ColumnarBatch].getName} batch;
+       |  private ${classOf[MutableColumnarRow].getName} aggBufferRow;
       |  private int[] buckets;
       |  private int capacity = 1 << 16;
       |  private double loadFactor = 0.5;

@@ -91,19 +91,16 @@ class VectorizedHashMapGenerator(
       |  $generatedAggBufferSchema
       |
       |  public $generatedClassName() {
-      |    batchVectors = org.apache.spark.sql.execution.vectorized
-      |      .OnHeapColumnVector.allocateColumns(capacity, schema);
-      |    batch = new org.apache.spark.sql.execution.vectorized.ColumnarBatch(
-      |      schema, batchVectors, capacity);
+      |    vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, schema);
+      |    batch = new ${classOf[ColumnarBatch].getName}(schema, vectors, capacity);
       |
-      |    bufferVectors = new org.apache.spark.sql.execution.vectorized
-      |      .OnHeapColumnVector[aggregateBufferSchema.fields().length];
+      |    // Generates a projection to return the aggregate buffer only.
+      |    ${classOf[OnHeapColumnVector].getName}[] aggBufferVectors =
+      |      new ${classOf[OnHeapColumnVector].getName}[aggregateBufferSchema.fields().length];
       |    for (int i = 0; i < aggregateBufferSchema.fields().length; i++) {
-      |      bufferVectors[i] = batchVectors[i + ${groupingKeys.length}];
+      |      aggBufferVectors[i] = vectors[i + ${groupingKeys.length}];
       |    }
-      |    // TODO: Possibly generate this projection in HashAggregate directly
-      |    aggregateBufferBatch = new org.apache.spark.sql.execution.vectorized.ColumnarBatch(
-      |      aggregateBufferSchema, bufferVectors, capacity);
+      |    aggBufferRow = new ${classOf[MutableColumnarRow].getName}(aggBufferVectors);
       |
       |    buckets = new int[numBuckets];
       |    java.util.Arrays.fill(buckets, -1);

@@ -114,13 +111,13 @@ class VectorizedHashMapGenerator(

   /**
    * Generates a method that returns true if the group-by keys exist at a given index in the
-   * associated [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
-   * have 2 long group-by keys, the generated function would be of the form:
+   * associated [[org.apache.spark.sql.execution.vectorized.OnHeapColumnVector]]. For instance,
+   * if we have 2 long group-by keys, the generated function would be of the form:
    *
    * {{{
    * private boolean equals(int idx, long agg_key, long agg_key1) {
-   *   return batchVectors[0].getLong(buckets[idx]) == agg_key &&
-   *     batchVectors[1].getLong(buckets[idx]) == agg_key1;
+   *   return vectors[0].getLong(buckets[idx]) == agg_key &&
+   *     vectors[1].getLong(buckets[idx]) == agg_key1;
    * }
    * }}}
    */

@@ -128,7 +125,7 @@ class VectorizedHashMapGenerator(

   def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = {
     groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
-      s"""(${ctx.genEqual(key.dataType, ctx.getValue(s"batchVectors[$ordinal]", "buckets[idx]",
+      s"""(${ctx.genEqual(key.dataType, ctx.getValue(s"vectors[$ordinal]", "buckets[idx]",
         key.dataType), key.name)})"""
     }.mkString(" && ")
   }

@@ -141,29 +138,35 @@ class VectorizedHashMapGenerator(
   }

   /**
-   * Generates a method that returns a mutable
-   * [[org.apache.spark.sql.execution.vectorized.ColumnarRow]] which keeps track of the
+   * Generates a method that returns a
+   * [[org.apache.spark.sql.execution.vectorized.MutableColumnarRow]] which keeps track of the
    * aggregate value(s) for a given set of keys. If the corresponding row doesn't exist, the
    * generated method adds the corresponding row in the associated
-   * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
+   * [[org.apache.spark.sql.execution.vectorized.OnHeapColumnVector]]. For instance, if we
    * have 2 long group-by keys, the generated function would be of the form:
    *
    * {{{
-   * public org.apache.spark.sql.execution.vectorized.ColumnarRow findOrInsert(
-   *     long agg_key, long agg_key1) {
+   * public MutableColumnarRow findOrInsert(long agg_key, long agg_key1) {
    *   long h = hash(agg_key, agg_key1);
    *   int step = 0;
    *   int idx = (int) h & (numBuckets - 1);
    *   while (step < maxSteps) {
    *     // Return bucket index if it's either an empty slot or already contains the key
    *     if (buckets[idx] == -1) {
-   *       batchVectors[0].putLong(numRows, agg_key);
-   *       batchVectors[1].putLong(numRows, agg_key1);
-   *       batchVectors[2].putLong(numRows, 0);
-   *       buckets[idx] = numRows++;
-   *       return batch.getRow(buckets[idx]);
+   *       if (numRows < capacity) {
+   *         vectors[0].putLong(numRows, agg_key);
+   *         vectors[1].putLong(numRows, agg_key1);
+   *         vectors[2].putLong(numRows, 0);
+   *         buckets[idx] = numRows++;
+   *         aggBufferRow.rowId = buckets[idx];
+   *         return aggBufferRow;
+   *       } else {
+   *         // No more space
+   *         return null;
+   *       }
    *     } else if (equals(idx, agg_key, agg_key1)) {
-   *       return batch.getRow(buckets[idx]);
+   *       aggBufferRow.rowId = buckets[idx];
+   *       return aggBufferRow;
    *     }
    *     idx = (idx + 1) & (numBuckets - 1);
    *     step++;

@@ -177,20 +180,19 @@ class VectorizedHashMapGenerator(

   def genCodeToSetKeys(groupingKeys: Seq[Buffer]): Seq[String] = {
     groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
-      ctx.setValue(s"batchVectors[$ordinal]", "numRows", key.dataType, key.name)
+      ctx.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name)
     }
   }

   def genCodeToSetAggBuffers(bufferValues: Seq[Buffer]): Seq[String] = {
     bufferValues.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
-      ctx.updateColumn(s"batchVectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType,
+      ctx.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType,
         buffVars(ordinal), nullable = true)
     }
   }

   s"""
-    |public org.apache.spark.sql.execution.vectorized.ColumnarRow findOrInsert(${
-    groupingKeySignature}) {
+    |public ${classOf[MutableColumnarRow].getName} findOrInsert($groupingKeySignature) {
     |  long h = hash(${groupingKeys.map(_.name).mkString(", ")});
     |  int step = 0;
     |  int idx = (int) h & (numBuckets - 1);

@@ -208,15 +210,15 @@ class VectorizedHashMapGenerator(
     |      ${genCodeToSetAggBuffers(bufferValues).mkString("\n")}
     |
     |      buckets[idx] = numRows++;
-    |      batch.setNumRows(numRows);
-    |      aggregateBufferBatch.setNumRows(numRows);
-    |      return aggregateBufferBatch.getRow(buckets[idx]);
+    |      aggBufferRow.rowId = buckets[idx];
+    |      return aggBufferRow;
     |    } else {
     |      // No more space
     |      return null;
     |    }
     |  } else if (equals(idx, ${groupingKeys.map(_.name).mkString(", ")})) {
-    |    return aggregateBufferBatch.getRow(buckets[idx]);
+    |    aggBufferRow.rowId = buckets[idx];
+    |    return aggBufferRow;
     |  }
     |  idx = (idx + 1) & (numBuckets - 1);
     |  step++;

@@ -229,8 +231,8 @@ class VectorizedHashMapGenerator(

   protected def generateRowIterator(): String = {
     s"""
-      |public java.util.Iterator<org.apache.spark.sql.execution.vectorized.ColumnarRow>
-      |    rowIterator() {
+      |public java.util.Iterator<${classOf[ColumnarRow].getName}> rowIterator() {
+      |  batch.setNumRows(numRows);
       |  return batch.rowIterator();
       |}
     """.stripMargin
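For readers who don't want to trace the generated strings, here is a hand-written sketch of the probing scheme the generated `findOrInsert` follows, simplified to one long key and one long buffer. The fields mirror the generated ones (`vectors`, `buckets`, `numRows`, `capacity`, `numBuckets`, `maxSteps`, `aggBufferRow`), but `ProbeSketch` itself is illustrative, not part of the patch:

```java
import org.apache.spark.sql.execution.vectorized.MutableColumnarRow;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class ProbeSketch {
  private final int capacity = 1 << 4;
  private final int numBuckets = 1 << 5; // power of two, so `& (numBuckets - 1)` wraps
  private final int maxSteps = 3;
  private final OnHeapColumnVector[] vectors;
  private final int[] buckets;
  private final MutableColumnarRow aggBufferRow;
  private int numRows = 0;

  public ProbeSketch() {
    StructType schema = new StructType()
      .add("key", DataTypes.LongType)
      .add("sum", DataTypes.LongType);
    vectors = OnHeapColumnVector.allocateColumns(capacity, schema);
    buckets = new int[numBuckets];
    java.util.Arrays.fill(buckets, -1);
    // The aggregate-buffer projection: only the buffer column, skipping the key.
    aggBufferRow = new MutableColumnarRow(new OnHeapColumnVector[] { vectors[1] });
  }

  public MutableColumnarRow findOrInsert(long key) {
    long h = Long.hashCode(key);
    int step = 0;
    int idx = (int) h & (numBuckets - 1);
    while (step < maxSteps) {
      if (buckets[idx] == -1) {            // empty slot: insert if there is room
        if (numRows < capacity) {
          vectors[0].putLong(numRows, key);
          vectors[1].putLong(numRows, 0L); // zero-initialize the aggregate buffer
          buckets[idx] = numRows++;
          aggBufferRow.rowId = buckets[idx];
          return aggBufferRow;
        }
        return null;                       // map full: caller falls back to the regular path
      } else if (vectors[0].getLong(buckets[idx]) == key) {
        aggBufferRow.rowId = buckets[idx]; // existing group: just reposition the cursor
        return aggBufferRow;
      }
      idx = (idx + 1) & (numBuckets - 1);  // linear probing
      step++;
    }
    return null;                           // too many collisions: caller falls back
  }
}
```

Note the design point the repositionable `rowId` buys: both the hit and the insert paths return the same long-lived row object, so the per-group buffer lookup allocates nothing. (The generated doc example's `aggBufferRow.rowId = numRows;` is corrected above to `buckets[idx]`, matching the actually generated code in the `@@ -208,15 +210,15 @@` hunk.)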
sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala

@@ -163,6 +163,18 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
     }
   }

+  testVectors("mutable ColumnarRow", 10, IntegerType) { testVector =>
+    val mutableRow = new MutableColumnarRow(Array(testVector))
+    (0 until 10).foreach { i =>
+      mutableRow.rowId = i
+      mutableRow.setInt(0, 10 - i)
+    }
+    (0 until 10).foreach { i =>
+      mutableRow.rowId = i
+      assert(mutableRow.getInt(0) === (10 - i))
+    }
+  }
+
   val arrayType: ArrayType = ArrayType(IntegerType, containsNull = true)
   testVectors("array", 10, arrayType) { testVector =>
sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala

@@ -1129,29 +1129,6 @@ class ColumnarBatchSuite extends SparkFunSuite {
     testRandomRows(false, 30)
   }

-  test("mutable ColumnarBatch rows") {
-    val NUM_ITERS = 10
-    val types = Array(
-      BooleanType, FloatType, DoubleType, IntegerType, LongType, ShortType,
-      DecimalType.ShortDecimal, DecimalType.IntDecimal, DecimalType.ByteDecimal,
-      DecimalType.FloatDecimal, DecimalType.LongDecimal, new DecimalType(5, 2),
-      new DecimalType(12, 2), new DecimalType(30, 10))
-    for (i <- 0 to NUM_ITERS) {
-      val random = new Random(System.nanoTime())
-      val schema = RandomDataGenerator.randomSchema(random, numFields = 20, types)
-      val oldRow = RandomDataGenerator.randomRow(random, schema)
-      val newRow = RandomDataGenerator.randomRow(random, schema)
-
-      (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode =>
-        val batch = ColumnVectorUtils.toBatch(schema, memMode, (oldRow :: Nil).iterator.asJava)
-        val columnarBatchRow = batch.getRow(0)
-        newRow.toSeq.zipWithIndex.foreach(i => columnarBatchRow.update(i._2, i._1))
-        compareStruct(schema, columnarBatchRow, newRow, 0)
-        batch.close()
-      }
-    }
-  }
-
   test("exceeding maximum capacity should throw an error") {
     (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode =>
       val column = allocate(1, ByteType, memMode)