diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java
index 98a9073227..cabb747952 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java
@@ -16,8 +16,6 @@
  */
 package org.apache.spark.sql.execution.vectorized;
 
-import java.math.BigDecimal;
-
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
 import org.apache.spark.sql.catalyst.util.MapData;
@@ -32,17 +30,10 @@ import org.apache.spark.unsafe.types.UTF8String;
 public final class ColumnarRow extends InternalRow {
   protected int rowId;
   private final ColumnVector[] columns;
-  private final WritableColumnVector[] writableColumns;
 
   // Ctor used if this is a struct.
   ColumnarRow(ColumnVector[] columns) {
     this.columns = columns;
-    this.writableColumns = new WritableColumnVector[this.columns.length];
-    for (int i = 0; i < this.columns.length; i++) {
-      if (this.columns[i] instanceof WritableColumnVector) {
-        this.writableColumns[i] = (WritableColumnVector) this.columns[i];
-      }
-    }
   }
 
   public ColumnVector[] columns() { return columns; }
@@ -205,97 +196,8 @@ public final class ColumnarRow extends InternalRow {
   }
 
   @Override
-  public void update(int ordinal, Object value) {
-    if (value == null) {
-      setNullAt(ordinal);
-    } else {
-      DataType dt = columns[ordinal].dataType();
-      if (dt instanceof BooleanType) {
-        setBoolean(ordinal, (boolean) value);
-      } else if (dt instanceof IntegerType) {
-        setInt(ordinal, (int) value);
-      } else if (dt instanceof ShortType) {
-        setShort(ordinal, (short) value);
-      } else if (dt instanceof LongType) {
-        setLong(ordinal, (long) value);
-      } else if (dt instanceof FloatType) {
-        setFloat(ordinal, (float) value);
-      } else if (dt instanceof DoubleType) {
-        setDouble(ordinal, (double) value);
-      } else if (dt instanceof DecimalType) {
-        DecimalType t = (DecimalType) dt;
-        setDecimal(ordinal, Decimal.apply((BigDecimal) value, t.precision(), t.scale()),
-          t.precision());
-      } else {
-        throw new UnsupportedOperationException("Datatype not supported " + dt);
-      }
-    }
-  }
+  public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); }
 
   @Override
-  public void setNullAt(int ordinal) {
-    getWritableColumn(ordinal).putNull(rowId);
-  }
-
-  @Override
-  public void setBoolean(int ordinal, boolean value) {
-    WritableColumnVector column = getWritableColumn(ordinal);
-    column.putNotNull(rowId);
-    column.putBoolean(rowId, value);
-  }
-
-  @Override
-  public void setByte(int ordinal, byte value) {
-    WritableColumnVector column = getWritableColumn(ordinal);
-    column.putNotNull(rowId);
-    column.putByte(rowId, value);
-  }
-
-  @Override
-  public void setShort(int ordinal, short value) {
-    WritableColumnVector column = getWritableColumn(ordinal);
-    column.putNotNull(rowId);
-    column.putShort(rowId, value);
-  }
-
-  @Override
-  public void setInt(int ordinal, int value) {
-    WritableColumnVector column = getWritableColumn(ordinal);
-    column.putNotNull(rowId);
-    column.putInt(rowId, value);
-  }
-
-  @Override
-  public void setLong(int ordinal, long value) {
-    WritableColumnVector column = getWritableColumn(ordinal);
-    column.putNotNull(rowId);
-    column.putLong(rowId, value);
-  }
-
-  @Override
-  public void setFloat(int ordinal, float value) {
-    WritableColumnVector column = getWritableColumn(ordinal);
-    column.putNotNull(rowId);
-    column.putFloat(rowId, value);
-  }
-
-  @Override
-  public void setDouble(int ordinal, double value) {
-    WritableColumnVector column = getWritableColumn(ordinal);
-    column.putNotNull(rowId);
-    column.putDouble(rowId, value);
-  }
-
-  @Override
-  public void setDecimal(int ordinal, Decimal value, int precision) {
-    WritableColumnVector column = getWritableColumn(ordinal);
-    column.putNotNull(rowId);
-    column.putDecimal(rowId, value, precision);
-  }
-
-  private WritableColumnVector getWritableColumn(int ordinal) {
-    WritableColumnVector column = writableColumns[ordinal];
-    assert (!column.isConstant);
-    return column;
-  }
+  public void setNullAt(int ordinal) { throw new UnsupportedOperationException(); }
 }
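With its mutators removed, ColumnarRow becomes a read-only view; callers that still need to write do so through a WritableColumnVector directly (or through the new MutableColumnarRow introduced below). A minimal sketch of the write-through-vector pattern; the capacity and values here are illustrative, not taken from the patch:

```java
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.types.DataTypes;

class ReadOnlyRowSketch {
  static void example() {
    // Writes now target the writable vector itself...
    OnHeapColumnVector col = new OnHeapColumnVector(16, DataTypes.LongType);
    col.putLong(0, 42L);
    // ...while ColumnarRow and its getters remain a read-only view over it.
    long v = col.getLong(0);
  }
}
```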
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
new file mode 100644
index 0000000000..f272cc1636
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
@@ -0,0 +1,278 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.vectorized;
+
+import java.math.BigDecimal;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
+import org.apache.spark.sql.catalyst.util.MapData;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A mutable version of {@link ColumnarRow}, which is used in the vectorized hash map for hash
+ * aggregate.
+ *
+ * Note that this class intentionally has a lot of duplicated code with {@link ColumnarRow}, to
+ * avoid java polymorphism overhead by keeping {@link ColumnarRow} and this class final classes.
+ */
+public final class MutableColumnarRow extends InternalRow {
+  public int rowId;
+  private final WritableColumnVector[] columns;
+
+  public MutableColumnarRow(WritableColumnVector[] columns) {
+    this.columns = columns;
+  }
+
+  @Override
+  public int numFields() { return columns.length; }
+
+  @Override
+  public InternalRow copy() {
+    GenericInternalRow row = new GenericInternalRow(columns.length);
+    for (int i = 0; i < numFields(); i++) {
+      if (isNullAt(i)) {
+        row.setNullAt(i);
+      } else {
+        DataType dt = columns[i].dataType();
+        if (dt instanceof BooleanType) {
+          row.setBoolean(i, getBoolean(i));
+        } else if (dt instanceof ByteType) {
+          row.setByte(i, getByte(i));
+        } else if (dt instanceof ShortType) {
+          row.setShort(i, getShort(i));
+        } else if (dt instanceof IntegerType) {
+          row.setInt(i, getInt(i));
+        } else if (dt instanceof LongType) {
+          row.setLong(i, getLong(i));
+        } else if (dt instanceof FloatType) {
+          row.setFloat(i, getFloat(i));
+        } else if (dt instanceof DoubleType) {
+          row.setDouble(i, getDouble(i));
+        } else if (dt instanceof StringType) {
+          row.update(i, getUTF8String(i).copy());
+        } else if (dt instanceof BinaryType) {
+          row.update(i, getBinary(i));
+        } else if (dt instanceof DecimalType) {
+          DecimalType t = (DecimalType)dt;
+          row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision());
+        } else if (dt instanceof DateType) {
+          row.setInt(i, getInt(i));
+        } else if (dt instanceof TimestampType) {
+          row.setLong(i, getLong(i));
+        } else {
+          throw new RuntimeException("Not implemented. " + dt);
+        }
+      }
+    }
+    return row;
+  }
+
+  @Override
+  public boolean anyNull() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public boolean isNullAt(int ordinal) { return columns[ordinal].isNullAt(rowId); }
+
+  @Override
+  public boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); }
+
+  @Override
+  public byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); }
+
+  @Override
+  public short getShort(int ordinal) { return columns[ordinal].getShort(rowId); }
+
+  @Override
+  public int getInt(int ordinal) { return columns[ordinal].getInt(rowId); }
+
+  @Override
+  public long getLong(int ordinal) { return columns[ordinal].getLong(rowId); }
+
+  @Override
+  public float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); }
+
+  @Override
+  public double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); }
+
+  @Override
+  public Decimal getDecimal(int ordinal, int precision, int scale) {
+    if (columns[ordinal].isNullAt(rowId)) return null;
+    return columns[ordinal].getDecimal(rowId, precision, scale);
+  }
+
+  @Override
+  public UTF8String getUTF8String(int ordinal) {
+    if (columns[ordinal].isNullAt(rowId)) return null;
+    return columns[ordinal].getUTF8String(rowId);
+  }
+
+  @Override
+  public byte[] getBinary(int ordinal) {
+    if (columns[ordinal].isNullAt(rowId)) return null;
+    return columns[ordinal].getBinary(rowId);
+  }
+
+  @Override
+  public CalendarInterval getInterval(int ordinal) {
+    if (columns[ordinal].isNullAt(rowId)) return null;
+    final int months = columns[ordinal].getChildColumn(0).getInt(rowId);
+    final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId);
+    return new CalendarInterval(months, microseconds);
+  }
+
+  @Override
+  public ColumnarRow getStruct(int ordinal, int numFields) {
+    if (columns[ordinal].isNullAt(rowId)) return null;
+    return columns[ordinal].getStruct(rowId);
+  }
+
+  @Override
+  public ColumnarArray getArray(int ordinal) {
+    if (columns[ordinal].isNullAt(rowId)) return null;
+    return columns[ordinal].getArray(rowId);
+  }
+
+  @Override
+  public MapData getMap(int ordinal) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public Object get(int ordinal, DataType dataType) {
+    if (dataType instanceof BooleanType) {
+      return getBoolean(ordinal);
+    } else if (dataType instanceof ByteType) {
+      return getByte(ordinal);
+    } else if (dataType instanceof ShortType) {
+      return getShort(ordinal);
+    } else if (dataType instanceof IntegerType) {
+      return getInt(ordinal);
+    } else if (dataType instanceof LongType) {
+      return getLong(ordinal);
+    } else if (dataType instanceof FloatType) {
+      return getFloat(ordinal);
+    } else if (dataType instanceof DoubleType) {
+      return getDouble(ordinal);
+    } else if (dataType instanceof StringType) {
+      return getUTF8String(ordinal);
+    } else if (dataType instanceof BinaryType) {
+      return getBinary(ordinal);
+    } else if (dataType instanceof DecimalType) {
+      DecimalType t = (DecimalType) dataType;
+      return getDecimal(ordinal, t.precision(), t.scale());
+    } else if (dataType instanceof DateType) {
+      return getInt(ordinal);
+    } else if (dataType instanceof TimestampType) {
+      return getLong(ordinal);
+    } else if (dataType instanceof ArrayType) {
+      return getArray(ordinal);
+    } else if (dataType instanceof StructType) {
+      return getStruct(ordinal, ((StructType)dataType).fields().length);
+    } else if (dataType instanceof MapType) {
+      return getMap(ordinal);
+    } else {
+      throw new UnsupportedOperationException("Datatype not supported " + dataType);
+    }
+  }
+
+  @Override
+  public void update(int ordinal, Object value) {
+    if (value == null) {
+      setNullAt(ordinal);
+    } else {
+      DataType dt = columns[ordinal].dataType();
+      if (dt instanceof BooleanType) {
+        setBoolean(ordinal, (boolean) value);
+      } else if (dt instanceof IntegerType) {
+        setInt(ordinal, (int) value);
+      } else if (dt instanceof ShortType) {
+        setShort(ordinal, (short) value);
+      } else if (dt instanceof LongType) {
+        setLong(ordinal, (long) value);
+      } else if (dt instanceof FloatType) {
+        setFloat(ordinal, (float) value);
+      } else if (dt instanceof DoubleType) {
+        setDouble(ordinal, (double) value);
+      } else if (dt instanceof DecimalType) {
+        DecimalType t = (DecimalType) dt;
+        Decimal d = Decimal.apply((BigDecimal) value, t.precision(), t.scale());
+        setDecimal(ordinal, d, t.precision());
+      } else {
+        throw new UnsupportedOperationException("Datatype not supported " + dt);
+      }
+    }
+  }
+
+  @Override
+  public void setNullAt(int ordinal) {
+    columns[ordinal].putNull(rowId);
+  }
+
+  @Override
+  public void setBoolean(int ordinal, boolean value) {
+    columns[ordinal].putNotNull(rowId);
+    columns[ordinal].putBoolean(rowId, value);
+  }
+
+  @Override
+  public void setByte(int ordinal, byte value) {
+    columns[ordinal].putNotNull(rowId);
+    columns[ordinal].putByte(rowId, value);
+  }
+
+  @Override
+  public void setShort(int ordinal, short value) {
+    columns[ordinal].putNotNull(rowId);
+    columns[ordinal].putShort(rowId, value);
+  }
+
+  @Override
+  public void setInt(int ordinal, int value) {
+    columns[ordinal].putNotNull(rowId);
+    columns[ordinal].putInt(rowId, value);
+  }
+
+  @Override
+  public void setLong(int ordinal, long value) {
+    columns[ordinal].putNotNull(rowId);
+    columns[ordinal].putLong(rowId, value);
+  }
+
+  @Override
+  public void setFloat(int ordinal, float value) {
+    columns[ordinal].putNotNull(rowId);
+    columns[ordinal].putFloat(rowId, value);
+  }
+
+  @Override
+  public void setDouble(int ordinal, double value) {
+    columns[ordinal].putNotNull(rowId);
+    columns[ordinal].putDouble(rowId, value);
+  }
+
+  @Override
+  public void setDecimal(int ordinal, Decimal value, int precision) {
+    columns[ordinal].putNotNull(rowId);
+    columns[ordinal].putDecimal(rowId, value, precision);
+  }
+}
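The new class is exercised by the ColumnVectorSuite test added further below: one row object is re-pointed at successive row ids instead of allocating a row per entry. A minimal usage sketch, with illustrative capacity and values:

```java
import org.apache.spark.sql.execution.vectorized.MutableColumnarRow;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.types.DataTypes;

class MutableRowSketch {
  static void example() {
    OnHeapColumnVector[] vectors =
        { new OnHeapColumnVector(16, DataTypes.IntegerType) };
    MutableColumnarRow row = new MutableColumnarRow(vectors);
    for (int i = 0; i < 10; i++) {
      row.rowId = i;          // re-point the shared row; no per-row allocation
      row.setInt(0, 10 - i);  // writes go straight to the backing vector
    }
    row.rowId = 3;
    assert row.getInt(0) == 7;  // reads see what was written at that row id
  }
}
```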
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index dc8aecf185..913978892c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.vectorized.MutableColumnarRow
 import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
 import org.apache.spark.unsafe.KVIterator
 import org.apache.spark.util.Utils
@@ -894,7 +895,7 @@ case class HashAggregateExec(
       ${
         if (isVectorizedHashMapEnabled) {
           s"""
-             | org.apache.spark.sql.execution.vectorized.ColumnarRow $fastRowBuffer = null;
+             | ${classOf[MutableColumnarRow].getName} $fastRowBuffer = null;
           """.stripMargin
         } else {
           s"""
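After this change, the generated aggregate code declares its fast-path buffer with the new type. Roughly, the emitted Java takes the shape below; `fastRowBuffer`, `fastHashMap`, and `agg_key` are illustrative stand-ins for codegen-assigned names, and the setLong/getLong update stands in for whatever aggregate expressions were compiled:

```java
// Sketch of the emitted fast path, not actual codegen output:
org.apache.spark.sql.execution.vectorized.MutableColumnarRow fastRowBuffer = null;
fastRowBuffer = fastHashMap.findOrInsert(agg_key);
if (fastRowBuffer != null) {
  // Update the aggregate buffer in place; writes hit the backing vectors.
  fastRowBuffer.setLong(0, fastRowBuffer.getLong(0) + 1);
}
```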
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
index fd783d905b..44ba539ebf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnarRow, MutableColumnarRow, OnHeapColumnVector}
 import org.apache.spark.sql.types._
 
 /**
@@ -76,10 +77,9 @@ class VectorizedHashMapGenerator(
     }.mkString("\n").concat(";")
 
     s"""
-       |  private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[] batchVectors;
-       |  private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[] bufferVectors;
-       |  private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch;
-       |  private org.apache.spark.sql.execution.vectorized.ColumnarBatch aggregateBufferBatch;
+       |  private ${classOf[OnHeapColumnVector].getName}[] vectors;
+       |  private ${classOf[ColumnarBatch].getName} batch;
+       |  private ${classOf[MutableColumnarRow].getName} aggBufferRow;
        |  private int[] buckets;
        |  private int capacity = 1 << 16;
        |  private double loadFactor = 0.5;
@@ -91,19 +91,16 @@
        |  $generatedAggBufferSchema
        |
        |  public $generatedClassName() {
-       |    batchVectors = org.apache.spark.sql.execution.vectorized
-       |      .OnHeapColumnVector.allocateColumns(capacity, schema);
-       |    batch = new org.apache.spark.sql.execution.vectorized.ColumnarBatch(
-       |      schema, batchVectors, capacity);
+       |    vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, schema);
+       |    batch = new ${classOf[ColumnarBatch].getName}(schema, vectors, capacity);
        |
-       |    bufferVectors = new org.apache.spark.sql.execution.vectorized
-       |      .OnHeapColumnVector[aggregateBufferSchema.fields().length];
+       |    // Generates a projection to return the aggregate buffer only.
+       |    ${classOf[OnHeapColumnVector].getName}[] aggBufferVectors =
+       |      new ${classOf[OnHeapColumnVector].getName}[aggregateBufferSchema.fields().length];
        |    for (int i = 0; i < aggregateBufferSchema.fields().length; i++) {
-       |      bufferVectors[i] = batchVectors[i + ${groupingKeys.length}];
+       |      aggBufferVectors[i] = vectors[i + ${groupingKeys.length}];
        |    }
-       |    // TODO: Possibly generate this projection in HashAggregate directly
-       |    aggregateBufferBatch = new org.apache.spark.sql.execution.vectorized.ColumnarBatch(
-       |      aggregateBufferSchema, bufferVectors, capacity);
+       |    aggBufferRow = new ${classOf[MutableColumnarRow].getName}(aggBufferVectors);
        |
        |    buckets = new int[numBuckets];
        |    java.util.Arrays.fill(buckets, -1);
@@ -114,13 +111,13 @@
 
   /**
    * Generates a method that returns true if the group-by keys exist at a given index in the
-   * associated [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
-   * have 2 long group-by keys, the generated function would be of the form:
+   * associated [[org.apache.spark.sql.execution.vectorized.OnHeapColumnVector]]. For instance,
+   * if we have 2 long group-by keys, the generated function would be of the form:
    *
    * {{{
    * private boolean equals(int idx, long agg_key, long agg_key1) {
-   *   return batchVectors[0].getLong(buckets[idx]) == agg_key &&
-   *     batchVectors[1].getLong(buckets[idx]) == agg_key1;
+   *   return vectors[0].getLong(buckets[idx]) == agg_key &&
+   *     vectors[1].getLong(buckets[idx]) == agg_key1;
    * }
    * }}}
    */
@@ -128,7 +125,7 @@
     def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = {
       groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
-        s"""(${ctx.genEqual(key.dataType, ctx.getValue(s"batchVectors[$ordinal]", "buckets[idx]",
+        s"""(${ctx.genEqual(key.dataType, ctx.getValue(s"vectors[$ordinal]", "buckets[idx]",
           key.dataType), key.name)})"""
       }.mkString(" && ")
     }
@@ -141,29 +138,35 @@
     }
 
   /**
-   * Generates a method that returns a mutable
-   * [[org.apache.spark.sql.execution.vectorized.ColumnarRow]] which keeps track of the
+   * Generates a method that returns a
+   * [[org.apache.spark.sql.execution.vectorized.MutableColumnarRow]] which keeps track of the
    * aggregate value(s) for a given set of keys. If the corresponding row doesn't exist, the
    * generated method adds the corresponding row in the associated
-   * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
+   * [[org.apache.spark.sql.execution.vectorized.OnHeapColumnVector]]. For instance, if we
    * have 2 long group-by keys, the generated function would be of the form:
    *
    * {{{
-   * public org.apache.spark.sql.execution.vectorized.ColumnarRow findOrInsert(
-   *     long agg_key, long agg_key1) {
+   * public MutableColumnarRow findOrInsert(long agg_key, long agg_key1) {
    *   long h = hash(agg_key, agg_key1);
    *   int step = 0;
    *   int idx = (int) h & (numBuckets - 1);
    *   while (step < maxSteps) {
    *     // Return bucket index if it's either an empty slot or already contains the key
    *     if (buckets[idx] == -1) {
-   *       batchVectors[0].putLong(numRows, agg_key);
-   *       batchVectors[1].putLong(numRows, agg_key1);
-   *       batchVectors[2].putLong(numRows, 0);
-   *       buckets[idx] = numRows++;
-   *       return batch.getRow(buckets[idx]);
+   *       if (numRows < capacity) {
+   *         vectors[0].putLong(numRows, agg_key);
+   *         vectors[1].putLong(numRows, agg_key1);
+   *         vectors[2].putLong(numRows, 0);
+   *         buckets[idx] = numRows++;
+   *         aggBufferRow.rowId = buckets[idx];
+   *         return aggBufferRow;
+   *       } else {
+   *         // No more space
+   *         return null;
+   *       }
    *     } else if (equals(idx, agg_key, agg_key1)) {
-   *       return batch.getRow(buckets[idx]);
+   *       aggBufferRow.rowId = buckets[idx];
+   *       return aggBufferRow;
    *     }
    *     idx = (idx + 1) & (numBuckets - 1);
    *     step++;
@@ -177,20 +180,19 @@
 
     def genCodeToSetKeys(groupingKeys: Seq[Buffer]): Seq[String] = {
       groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
-        ctx.setValue(s"batchVectors[$ordinal]", "numRows", key.dataType, key.name)
+        ctx.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name)
       }
     }
 
     def genCodeToSetAggBuffers(bufferValues: Seq[Buffer]): Seq[String] = {
       bufferValues.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
-        ctx.updateColumn(s"batchVectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType,
+        ctx.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType,
           buffVars(ordinal), nullable = true)
       }
     }
 
     s"""
-       |public org.apache.spark.sql.execution.vectorized.ColumnarRow findOrInsert(${
-          groupingKeySignature}) {
+       |public ${classOf[MutableColumnarRow].getName} findOrInsert($groupingKeySignature) {
        |  long h = hash(${groupingKeys.map(_.name).mkString(", ")});
        |  int step = 0;
        |  int idx = (int) h & (numBuckets - 1);
@@ -208,15 +210,15 @@
        |        ${genCodeToSetAggBuffers(bufferValues).mkString("\n")}
        |
        |        buckets[idx] = numRows++;
-       |        batch.setNumRows(numRows);
-       |        aggregateBufferBatch.setNumRows(numRows);
-       |        return aggregateBufferBatch.getRow(buckets[idx]);
+       |        aggBufferRow.rowId = buckets[idx];
+       |        return aggBufferRow;
        |      } else {
        |        // No more space
        |        return null;
        |      }
        |    } else if (equals(idx, ${groupingKeys.map(_.name).mkString(", ")})) {
-       |      return aggregateBufferBatch.getRow(buckets[idx]);
+       |      aggBufferRow.rowId = buckets[idx];
+       |      return aggBufferRow;
        |    }
        |    idx = (idx + 1) & (numBuckets - 1);
        |    step++;
@@ -229,8 +231,8 @@
 
   protected def generateRowIterator(): String = {
     s"""
-       |public java.util.Iterator<org.apache.spark.sql.execution.vectorized.ColumnarRow>
-       |    rowIterator() {
+       |public java.util.Iterator<${classOf[ColumnarRow].getName}> rowIterator() {
+       |  batch.setNumRows(numRows);
        |  return batch.rowIterator();
        |}
     """.stripMargin
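One detail worth noting in the probing code above: `(int) h & (numBuckets - 1)` only yields a valid bucket index because the table size is a power of two, so the mask keeps the low bits of the hash (a branch-free, division-free equivalent of `h % numBuckets` for the non-negative case). A small sketch of that invariant:

```java
class BucketMaskSketch {
  // First probe position for hash h, assuming numBuckets is a power of two.
  static int firstBucket(long h, int numBuckets) {
    assert Integer.bitCount(numBuckets) == 1;  // power-of-two table size
    return (int) h & (numBuckets - 1);         // keep only the low bits of h
  }
}
```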
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
index 3c76ca79f5..e28ab710f5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
@@ -163,6 +163,18 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
     }
   }
 
+  testVectors("mutable ColumnarRow", 10, IntegerType) { testVector =>
+    val mutableRow = new MutableColumnarRow(Array(testVector))
+    (0 until 10).foreach { i =>
+      mutableRow.rowId = i
+      mutableRow.setInt(0, 10 - i)
+    }
+    (0 until 10).foreach { i =>
+      mutableRow.rowId = i
+      assert(mutableRow.getInt(0) === (10 - i))
+    }
+  }
+
   val arrayType: ArrayType = ArrayType(IntegerType, containsNull = true)
 
   testVectors("array", 10, arrayType) { testVector =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 80a50866aa..1b4e2bad09 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -1129,29 +1129,6 @@ class ColumnarBatchSuite extends SparkFunSuite {
     testRandomRows(false, 30)
   }
 
-  test("mutable ColumnarBatch rows") {
-    val NUM_ITERS = 10
-    val types = Array(
-      BooleanType, FloatType, DoubleType, IntegerType, LongType, ShortType,
-      DecimalType.ShortDecimal, DecimalType.IntDecimal, DecimalType.ByteDecimal,
-      DecimalType.FloatDecimal, DecimalType.LongDecimal, new DecimalType(5, 2),
-      new DecimalType(12, 2), new DecimalType(30, 10))
-    for (i <- 0 to NUM_ITERS) {
-      val random = new Random(System.nanoTime())
-      val schema = RandomDataGenerator.randomSchema(random, numFields = 20, types)
-      val oldRow = RandomDataGenerator.randomRow(random, schema)
-      val newRow = RandomDataGenerator.randomRow(random, schema)
-
-      (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode =>
-        val batch = ColumnVectorUtils.toBatch(schema, memMode, (oldRow :: Nil).iterator.asJava)
-        val columnarBatchRow = batch.getRow(0)
-        newRow.toSeq.zipWithIndex.foreach(i => columnarBatchRow.update(i._2, i._1))
-        compareStruct(schema, columnarBatchRow, newRow, 0)
-        batch.close()
-      }
-    }
-  }
-
   test("exceeding maximum capacity should throw an error") {
     (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode =>
       val column = allocate(1, ByteType, memMode)
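On the consumption side, `batch.setNumRows(numRows)` now happens once inside the generated `rowIterator()` rather than on every insert, so the batch's row count is published only when iteration begins. A sketch of draining the map through that iterator; the iterator argument stands in for the result of the generated map's `rowIterator()`, and the column layout (grouping keys first, then aggregate buffers) follows the generator above:

```java
import java.util.Iterator;

import org.apache.spark.sql.execution.vectorized.ColumnarRow;

class DrainSketch {
  // Sums the aggregate buffer column of a two-column (long key, long count) map.
  static long sumCounts(Iterator<ColumnarRow> rows) {
    long total = 0;
    while (rows.hasNext()) {
      ColumnarRow row = rows.next();
      total += row.getLong(1);  // column 0 = grouping key, column 1 = buffer
    }
    return total;
  }
}
```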