[SPARK-22703][SQL] make ColumnarRow an immutable view

## What changes were proposed in this pull request?

Similar to https://github.com/apache/spark/pull/19842, this PR makes `ColumnarRow` an immutable view over the data in a `ColumnVector`, as another step toward making `ColumnVector` public.

## How was this patch tested?

Existing tests.

The performance concerns are the same as in https://github.com/apache/spark/pull/19842.
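To make the shape of the change concrete, here is a minimal sketch of the before/after pattern using hypothetical stand-in types (`Column`, `RowView`), not the actual Spark classes: previously `getStruct` repositioned and returned one shared mutable row, so the result was only valid until the next call; now each call returns a fresh view whose position is fixed at construction.

```java
// Minimal sketch of the "immutable view" pattern; Column and RowView are
// hypothetical stand-ins, not the real ColumnVector/ColumnarRow classes.
final class Column {
  private final int[] values;

  Column(int[] values) {
    this.values = values;
  }

  int getInt(int rowId) {
    return values[rowId];
  }

  // After this patch: a new immutable view per call, instead of one shared
  // row whose rowId is mutated on every call.
  RowView getStruct(int rowId) {
    return new RowView(this, rowId);
  }
}

final class RowView {
  private final Column data;
  private final int rowId; // final: the view never moves

  RowView(Column data, int rowId) {
    this.data = data;
    this.rowId = rowId;
  }

  int getInt() {
    return data.getInt(rowId); // delegate, don't copy
  }
}
```

The reuse optimization that `ColumnarRow` previously provided does not disappear; it moves into `MutableColumnarRow`, which keeps a public mutable `rowId` for the hot paths (`ColumnarBatch.rowIterator` and the vectorized hash map), as the diff below shows.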

Author: Wenchen Fan <wenchen@databricks.com>

Closes #19898 from cloud-fan/row-id.
Committed by Wenchen Fan on 2017-12-07 20:45:11 +08:00
parent c1e5688d1a
commit e103adf45a
12 changed files with 88 additions and 98 deletions

AggregateHashMap.java

@@ -41,7 +41,7 @@ import static org.apache.spark.sql.types.DataTypes.LongType;
 public class AggregateHashMap {
   private OnHeapColumnVector[] columnVectors;
-  private ColumnarBatch batch;
+  private MutableColumnarRow aggBufferRow;
   private int[] buckets;
   private int numBuckets;
   private int numRows = 0;
@ -63,7 +63,7 @@ public class AggregateHashMap {
this.maxSteps = maxSteps;
numBuckets = (int) (capacity / loadFactor);
columnVectors = OnHeapColumnVector.allocateColumns(capacity, schema);
batch = new ColumnarBatch(schema, columnVectors, capacity);
aggBufferRow = new MutableColumnarRow(columnVectors);
buckets = new int[numBuckets];
Arrays.fill(buckets, -1);
}
@@ -72,14 +72,15 @@ public class AggregateHashMap {
     this(schema, DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_MAX_STEPS);
   }

-  public ColumnarRow findOrInsert(long key) {
+  public MutableColumnarRow findOrInsert(long key) {
     int idx = find(key);
     if (idx != -1 && buckets[idx] == -1) {
       columnVectors[0].putLong(numRows, key);
       columnVectors[1].putLong(numRows, 0);
       buckets[idx] = numRows++;
     }
-    return batch.getRow(buckets[idx]);
+    aggBufferRow.rowId = buckets[idx];
+    return aggBufferRow;
   }

   @VisibleForTesting
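`findOrInsert` now returns a single reused `MutableColumnarRow` that is repositioned onto the bucket for `key`, instead of going through `ColumnarBatch.getRow`. The sketch below (simplified, hypothetical types, not the real `AggregateHashMap`) illustrates the reused-cursor contract this implies: each returned row is only valid until the next lookup.

```java
// Simplified, hypothetical sketch of the reused-cursor pattern above:
// findOrInsert returns the same row object every call, repositioned onto
// the bucket for `key`, so results must be consumed before the next call.
final class TinyAggMap {
  private final long[] keys = new long[64];
  private final long[] counts = new long[64];
  private final boolean[] used = new boolean[64];
  private final Cursor cursor = new Cursor();

  final class Cursor {
    int slot; // repositioned by findOrInsert

    long getCount() {
      return counts[slot];
    }

    void addCount(long delta) {
      counts[slot] += delta;
    }
  }

  Cursor findOrInsert(long key) {
    int slot = (int) (key & 63);  // hash to a bucket
    while (used[slot] && keys[slot] != key) {
      slot = (slot + 1) & 63;     // linear probing
    }
    if (!used[slot]) {            // first time we see this key: insert
      used[slot] = true;
      keys[slot] = key;
    }
    cursor.slot = slot;           // reposition the cursor
    return cursor;                // same object every call
  }
}
```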

ArrowColumnVector.java

@@ -323,7 +323,6 @@ public final class ArrowColumnVector extends ColumnVector {
       for (int i = 0; i < childColumns.length; ++i) {
         childColumns[i] = new ArrowColumnVector(mapVector.getVectorById(i));
       }
-      resultStruct = new ColumnarRow(childColumns);
     } else {
       throw new UnsupportedOperationException();
     }

ColumnVector.java

@@ -157,18 +157,16 @@ public abstract class ColumnVector implements AutoCloseable {
   /**
    * Returns a utility object to get structs.
    */
-  public ColumnarRow getStruct(int rowId) {
-    resultStruct.rowId = rowId;
-    return resultStruct;
+  public final ColumnarRow getStruct(int rowId) {
+    return new ColumnarRow(this, rowId);
   }

   /**
    * Returns a utility object to get structs.
    * provided to keep API compatibility with InternalRow for code generation
    */
-  public ColumnarRow getStruct(int rowId, int size) {
-    resultStruct.rowId = rowId;
-    return resultStruct;
+  public final ColumnarRow getStruct(int rowId, int size) {
+    return getStruct(rowId);
   }
/**
@@ -216,11 +214,6 @@ public abstract class ColumnVector implements AutoCloseable {
    */
   protected DataType type;

-  /**
-   * Reusable Struct holder for getStruct().
-   */
-  protected ColumnarRow resultStruct;
-
   /**
    * The Dictionary for this column.
    *

ColumnarBatch.java

@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.vectorized;
 import java.util.*;

+import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.types.StructType;
/**
@@ -40,10 +41,10 @@ public final class ColumnarBatch {
   private final StructType schema;
   private final int capacity;
   private int numRows;
-  final ColumnVector[] columns;
+  private final ColumnVector[] columns;

-  // Staging row returned from getRow.
-  final ColumnarRow row;
+  // Staging row returned from `getRow`.
+  private final MutableColumnarRow row;

   /**
    * Called to close all the columns in this batch. It is not valid to access the data after
@@ -58,10 +59,10 @@ public final class ColumnarBatch {
   /**
    * Returns an iterator over the rows in this batch. This skips rows that are filtered out.
    */
-  public Iterator<ColumnarRow> rowIterator() {
+  public Iterator<InternalRow> rowIterator() {
     final int maxRows = numRows;
-    final ColumnarRow row = new ColumnarRow(columns);
-    return new Iterator<ColumnarRow>() {
+    final MutableColumnarRow row = new MutableColumnarRow(columns);
+    return new Iterator<InternalRow>() {
       int rowId = 0;

       @Override
@@ -70,7 +71,7 @@ public final class ColumnarBatch {
       }

       @Override
-      public ColumnarRow next() {
+      public InternalRow next() {
         if (rowId >= maxRows) {
           throw new NoSuchElementException();
         }
@@ -133,9 +134,8 @@ public final class ColumnarBatch {
   /**
    * Returns the row in this batch at `rowId`. Returned row is reused across calls.
    */
-  public ColumnarRow getRow(int rowId) {
-    assert(rowId >= 0);
-    assert(rowId < numRows);
+  public InternalRow getRow(int rowId) {
+    assert(rowId >= 0 && rowId < numRows);
     row.rowId = rowId;
     return row;
   }
@@ -144,6 +144,6 @@ public final class ColumnarBatch {
     this.schema = schema;
     this.columns = columns;
     this.capacity = capacity;
-    this.row = new ColumnarRow(columns);
+    this.row = new MutableColumnarRow(columns);
   }
 }

ColumnarRow.java

@@ -28,30 +28,32 @@ import org.apache.spark.unsafe.types.UTF8String;
  * to be reused, callers should copy the data out if it needs to be stored.
  */
 public final class ColumnarRow extends InternalRow {
-  protected int rowId;
-  private final ColumnVector[] columns;
+  // The data for this row. E.g. the value of 3rd int field is `data.getChildColumn(3).getInt(rowId)`.
+  private final ColumnVector data;
+  private final int rowId;
+  private final int numFields;

-  // Ctor used if this is a struct.
-  ColumnarRow(ColumnVector[] columns) {
-    this.columns = columns;
+  ColumnarRow(ColumnVector data, int rowId) {
+    assert (data.dataType() instanceof StructType);
+    this.data = data;
+    this.rowId = rowId;
+    this.numFields = ((StructType) data.dataType()).size();
   }

-  public ColumnVector[] columns() { return columns; }
-
   @Override
-  public int numFields() { return columns.length; }
+  public int numFields() { return numFields; }

   /**
    * Revisit this. This is expensive. This is currently only used in test paths.
    */
   @Override
   public InternalRow copy() {
-    GenericInternalRow row = new GenericInternalRow(columns.length);
+    GenericInternalRow row = new GenericInternalRow(numFields);
     for (int i = 0; i < numFields(); i++) {
       if (isNullAt(i)) {
         row.setNullAt(i);
       } else {
-        DataType dt = columns[i].dataType();
+        DataType dt = data.getChildColumn(i).dataType();
         if (dt instanceof BooleanType) {
           row.setBoolean(i, getBoolean(i));
         } else if (dt instanceof ByteType) {
@@ -91,65 +93,65 @@ public final class ColumnarRow extends InternalRow {
   }

   @Override
-  public boolean isNullAt(int ordinal) { return columns[ordinal].isNullAt(rowId); }
+  public boolean isNullAt(int ordinal) { return data.getChildColumn(ordinal).isNullAt(rowId); }

   @Override
-  public boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); }
+  public boolean getBoolean(int ordinal) { return data.getChildColumn(ordinal).getBoolean(rowId); }

   @Override
-  public byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); }
+  public byte getByte(int ordinal) { return data.getChildColumn(ordinal).getByte(rowId); }

   @Override
-  public short getShort(int ordinal) { return columns[ordinal].getShort(rowId); }
+  public short getShort(int ordinal) { return data.getChildColumn(ordinal).getShort(rowId); }

   @Override
-  public int getInt(int ordinal) { return columns[ordinal].getInt(rowId); }
+  public int getInt(int ordinal) { return data.getChildColumn(ordinal).getInt(rowId); }

   @Override
-  public long getLong(int ordinal) { return columns[ordinal].getLong(rowId); }
+  public long getLong(int ordinal) { return data.getChildColumn(ordinal).getLong(rowId); }

   @Override
-  public float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); }
+  public float getFloat(int ordinal) { return data.getChildColumn(ordinal).getFloat(rowId); }

   @Override
-  public double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); }
+  public double getDouble(int ordinal) { return data.getChildColumn(ordinal).getDouble(rowId); }

   @Override
   public Decimal getDecimal(int ordinal, int precision, int scale) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
-    return columns[ordinal].getDecimal(rowId, precision, scale);
+    if (data.getChildColumn(ordinal).isNullAt(rowId)) return null;
+    return data.getChildColumn(ordinal).getDecimal(rowId, precision, scale);
   }

   @Override
   public UTF8String getUTF8String(int ordinal) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
-    return columns[ordinal].getUTF8String(rowId);
+    if (data.getChildColumn(ordinal).isNullAt(rowId)) return null;
+    return data.getChildColumn(ordinal).getUTF8String(rowId);
   }

   @Override
   public byte[] getBinary(int ordinal) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
-    return columns[ordinal].getBinary(rowId);
+    if (data.getChildColumn(ordinal).isNullAt(rowId)) return null;
+    return data.getChildColumn(ordinal).getBinary(rowId);
   }

   @Override
   public CalendarInterval getInterval(int ordinal) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
-    final int months = columns[ordinal].getChildColumn(0).getInt(rowId);
-    final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId);
+    if (data.getChildColumn(ordinal).isNullAt(rowId)) return null;
+    final int months = data.getChildColumn(ordinal).getChildColumn(0).getInt(rowId);
+    final long microseconds = data.getChildColumn(ordinal).getChildColumn(1).getLong(rowId);
     return new CalendarInterval(months, microseconds);
   }

   @Override
   public ColumnarRow getStruct(int ordinal, int numFields) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
-    return columns[ordinal].getStruct(rowId);
+    if (data.getChildColumn(ordinal).isNullAt(rowId)) return null;
+    return data.getChildColumn(ordinal).getStruct(rowId);
   }

   @Override
   public ColumnarArray getArray(int ordinal) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
-    return columns[ordinal].getArray(rowId);
+    if (data.getChildColumn(ordinal).isNullAt(rowId)) return null;
+    return data.getChildColumn(ordinal).getArray(rowId);
   }

   @Override
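With the final `rowId` and the accessors delegating through `data.getChildColumn(ordinal)`, a `ColumnarRow` is now a fixed, read-only window onto one struct value. Rows that must outlive the current iteration still need an explicit `copy()`, since `ColumnarBatch.rowIterator` (see above) hands out one reused `MutableColumnarRow`. A hypothetical usage snippet against the classes in this patch:

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.execution.vectorized.ColumnarBatch;

class RowCopyExample {
  // Collects every row of a batch. next() returns a reused staging row, so
  // copy() is required before storing it; without the copy, the list would
  // hold N references to the same mutable object.
  static List<InternalRow> collect(ColumnarBatch batch) {
    List<InternalRow> out = new ArrayList<>();
    Iterator<InternalRow> it = batch.rowIterator();
    while (it.hasNext()) {
      out.add(it.next().copy());
    }
    return out;
  }
}
```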

MutableColumnarRow.java

@@ -28,17 +28,24 @@ import org.apache.spark.unsafe.types.UTF8String;
 /**
  * A mutable version of {@link ColumnarRow}, which is used in the vectorized hash map for hash
- * aggregate.
+ * aggregate, and {@link ColumnarBatch} to save object creation.
+ *
+ * Note that this class intentionally has a lot of duplicated code with {@link ColumnarRow}, to
+ * avoid java polymorphism overhead by keeping {@link ColumnarRow} and this class final classes.
  */
 public final class MutableColumnarRow extends InternalRow {
   public int rowId;
-  private final WritableColumnVector[] columns;
+  private final ColumnVector[] columns;
+  private final WritableColumnVector[] writableColumns;

-  public MutableColumnarRow(WritableColumnVector[] columns) {
+  public MutableColumnarRow(ColumnVector[] columns) {
     this.columns = columns;
+    this.writableColumns = null;
+  }
+
+  public MutableColumnarRow(WritableColumnVector[] writableColumns) {
+    this.columns = writableColumns;
+    this.writableColumns = writableColumns;
   }

   @Override
@@ -225,54 +232,54 @@ public final class MutableColumnarRow extends InternalRow {
   @Override
   public void setNullAt(int ordinal) {
-    columns[ordinal].putNull(rowId);
+    writableColumns[ordinal].putNull(rowId);
   }

   @Override
   public void setBoolean(int ordinal, boolean value) {
-    columns[ordinal].putNotNull(rowId);
-    columns[ordinal].putBoolean(rowId, value);
+    writableColumns[ordinal].putNotNull(rowId);
+    writableColumns[ordinal].putBoolean(rowId, value);
   }

   @Override
   public void setByte(int ordinal, byte value) {
-    columns[ordinal].putNotNull(rowId);
-    columns[ordinal].putByte(rowId, value);
+    writableColumns[ordinal].putNotNull(rowId);
+    writableColumns[ordinal].putByte(rowId, value);
   }

   @Override
   public void setShort(int ordinal, short value) {
-    columns[ordinal].putNotNull(rowId);
-    columns[ordinal].putShort(rowId, value);
+    writableColumns[ordinal].putNotNull(rowId);
+    writableColumns[ordinal].putShort(rowId, value);
   }

   @Override
   public void setInt(int ordinal, int value) {
-    columns[ordinal].putNotNull(rowId);
-    columns[ordinal].putInt(rowId, value);
+    writableColumns[ordinal].putNotNull(rowId);
+    writableColumns[ordinal].putInt(rowId, value);
   }

   @Override
   public void setLong(int ordinal, long value) {
-    columns[ordinal].putNotNull(rowId);
-    columns[ordinal].putLong(rowId, value);
+    writableColumns[ordinal].putNotNull(rowId);
+    writableColumns[ordinal].putLong(rowId, value);
   }

   @Override
   public void setFloat(int ordinal, float value) {
-    columns[ordinal].putNotNull(rowId);
-    columns[ordinal].putFloat(rowId, value);
+    writableColumns[ordinal].putNotNull(rowId);
+    writableColumns[ordinal].putFloat(rowId, value);
   }

   @Override
   public void setDouble(int ordinal, double value) {
-    columns[ordinal].putNotNull(rowId);
-    columns[ordinal].putDouble(rowId, value);
+    writableColumns[ordinal].putNotNull(rowId);
+    writableColumns[ordinal].putDouble(rowId, value);
   }

   @Override
   public void setDecimal(int ordinal, Decimal value, int precision) {
-    columns[ordinal].putNotNull(rowId);
-    columns[ordinal].putDecimal(rowId, value, precision);
+    writableColumns[ordinal].putNotNull(rowId);
+    writableColumns[ordinal].putDecimal(rowId, value, precision);
   }
 }
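`MutableColumnarRow` now serves two roles: a read-only staging row over plain `ColumnVector`s (used by `ColumnarBatch`) and a writable aggregation-buffer row over `WritableColumnVector`s (used by the vectorized hash map). All setters go through `writableColumns`, which is null on the read-only path, so a setter call on a row built from the `ColumnVector[]` constructor fails fast with a `NullPointerException` rather than silently writing. A hypothetical sketch of the same two-constructor design:

```java
// Hypothetical stand-ins for ColumnVector / WritableColumnVector.
interface Col {
  int get(int rowId);
}

interface WritableCol extends Col {
  void put(int rowId, int value);
}

// One row class, two construction paths: reads always work, writes are
// only wired up when the backing columns are writable.
final class TwoRoleRow {
  int rowId; // mutable cursor, as in MutableColumnarRow

  private final Col[] columns;                 // always set, used for reads
  private final WritableCol[] writableColumns; // null on the read-only path

  TwoRoleRow(Col[] columns) {
    this.columns = columns;
    this.writableColumns = null;
  }

  TwoRoleRow(WritableCol[] writableColumns) {
    this.columns = writableColumns;
    this.writableColumns = writableColumns;
  }

  int getInt(int ordinal) {
    return columns[ordinal].get(rowId);
  }

  // Throws NullPointerException when built via the read-only constructor,
  // mirroring MutableColumnarRow's setters.
  void setInt(int ordinal, int value) {
    writableColumns[ordinal].put(rowId, value);
  }
}
```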

OffHeapColumnVector.java

@@ -547,7 +547,7 @@ public final class OffHeapColumnVector extends WritableColumnVector {
     } else if (type instanceof LongType || type instanceof DoubleType ||
         DecimalType.is64BitDecimalType(type) || type instanceof TimestampType) {
       this.data = Platform.reallocateMemory(data, oldCapacity * 8, newCapacity * 8);
-    } else if (resultStruct != null) {
+    } else if (childColumns != null) {
       // Nothing to store.
     } else {
       throw new RuntimeException("Unhandled " + type);

OnHeapColumnVector.java

@@ -558,7 +558,7 @@ public final class OnHeapColumnVector extends WritableColumnVector {
         if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, capacity);
         doubleData = newData;
       }
-    } else if (resultStruct != null) {
+    } else if (childColumns != null) {
       // Nothing to store.
     } else {
       throw new RuntimeException("Unhandled " + type);

WritableColumnVector.java

@@ -74,7 +74,6 @@ public abstract class WritableColumnVector extends ColumnVector {
       dictionaryIds = null;
     }
     dictionary = null;
-    resultStruct = null;
   }

   public void reserve(int requiredCapacity) {
@@ -673,23 +672,19 @@ public abstract class WritableColumnVector extends ColumnVector {
       }
       this.childColumns = new WritableColumnVector[1];
       this.childColumns[0] = reserveNewColumn(childCapacity, childType);
-      this.resultStruct = null;
     } else if (type instanceof StructType) {
       StructType st = (StructType)type;
       this.childColumns = new WritableColumnVector[st.fields().length];
       for (int i = 0; i < childColumns.length; ++i) {
         this.childColumns[i] = reserveNewColumn(capacity, st.fields()[i].dataType());
       }
-      this.resultStruct = new ColumnarRow(this.childColumns);
     } else if (type instanceof CalendarIntervalType) {
       // Two columns. Months as int. Microseconds as Long.
       this.childColumns = new WritableColumnVector[2];
       this.childColumns[0] = reserveNewColumn(capacity, DataTypes.IntegerType);
       this.childColumns[1] = reserveNewColumn(capacity, DataTypes.LongType);
-      this.resultStruct = new ColumnarRow(this.childColumns);
     } else {
       this.childColumns = null;
-      this.resultStruct = null;
     }
   }
 }

HashAggregateExec.scala

@@ -595,9 +595,7 @@ case class HashAggregateExec(
       ctx.addMutableState(fastHashMapClassName, fastHashMapTerm,
         s"$fastHashMapTerm = new $fastHashMapClassName();")
-      ctx.addMutableState(
-        s"java.util.Iterator<${classOf[ColumnarRow].getName}>",
-        iterTermForFastHashMap)
+      ctx.addMutableState(s"java.util.Iterator<InternalRow>", iterTermForFastHashMap)
     } else {
       val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions,
         fastHashMapClassName, groupingKeySchema, bufferSchema).generate()
@@ -674,7 +672,7 @@ case class HashAggregateExec(
       """.stripMargin
     }

-    // Iterate over the aggregate rows and convert them from ColumnarRow to UnsafeRow
+    // Iterate over the aggregate rows and convert them from InternalRow to UnsafeRow
     def outputFromVectorizedMap: String = {
       val row = ctx.freshName("fastHashMapRow")
       ctx.currentVars = null
@@ -687,10 +685,9 @@ case class HashAggregateExec(
         bufferSchema.toAttributes.zipWithIndex.map { case (attr, i) =>
           BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable)
         })
-      val columnarRowCls = classOf[ColumnarRow].getName
       s"""
          |while ($iterTermForFastHashMap.hasNext()) {
-         |  $columnarRowCls $row = ($columnarRowCls) $iterTermForFastHashMap.next();
+         |  InternalRow $row = (InternalRow) $iterTermForFastHashMap.next();
          |  ${generateKeyRow.code}
          |  ${generateBufferRow.code}
          |  $outputFunc(${generateKeyRow.value}, ${generateBufferRow.value});

VectorizedHashMapGenerator.scala

@@ -17,9 +17,10 @@
 package org.apache.spark.sql.execution.aggregate

+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
-import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnarRow, MutableColumnarRow, OnHeapColumnVector}
+import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, MutableColumnarRow, OnHeapColumnVector}
 import org.apache.spark.sql.types._
/**
@@ -231,7 +232,7 @@ class VectorizedHashMapGenerator(
   protected def generateRowIterator(): String = {
     s"""
-       |public java.util.Iterator<${classOf[ColumnarRow].getName}> rowIterator() {
+       |public java.util.Iterator<${classOf[InternalRow].getName}> rowIterator() {
        |  batch.setNumRows(numRows);
        |  return batch.rowIterator();
        |}

ColumnarBatchSuite.scala

@@ -751,11 +751,6 @@ class ColumnarBatchSuite extends SparkFunSuite {
       c2.putDouble(1, 5.67)

       val s = column.getStruct(0)
-      assert(s.columns()(0).getInt(0) == 123)
-      assert(s.columns()(0).getInt(1) == 456)
-      assert(s.columns()(1).getDouble(0) == 3.45)
-      assert(s.columns()(1).getDouble(1) == 5.67)
       assert(s.getInt(0) == 123)
       assert(s.getDouble(1) == 3.45)