[SPARK-22003][SQL] support array column in vectorized reader with UDF
## What changes were proposed in this pull request?

A UDF needs to deserialize the `UnsafeRow` it is given. When a column's type is Array, this calls the `get` method in `ColumnVector` (the class used by the vectorized reader), but that method was not implemented and threw `UnsupportedOperationException`. This patch implements it by dispatching on the requested data type.

## How was this patch tested?

(Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests)
(If this patch involves UI changes, please attach a screenshot; otherwise, remove this)

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Feng Liu <fengliu@databricks.com>

Closes #19230 from liufengdb/fix_array_open.
This commit is contained in:
parent
894a7561de
commit
3b049abf10
|
@ -100,72 +100,16 @@ public abstract class ColumnVector implements AutoCloseable {
|
|||
public Object[] array() {
|
||||
DataType dt = data.dataType();
|
||||
Object[] list = new Object[length];
|
||||
|
||||
if (dt instanceof BooleanType) {
|
||||
try {
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (!data.isNullAt(offset + i)) {
|
||||
list[i] = data.getBoolean(offset + i);
|
||||
list[i] = get(i, dt);
|
||||
}
|
||||
}
|
||||
} else if (dt instanceof ByteType) {
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (!data.isNullAt(offset + i)) {
|
||||
list[i] = data.getByte(offset + i);
|
||||
}
|
||||
}
|
||||
} else if (dt instanceof ShortType) {
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (!data.isNullAt(offset + i)) {
|
||||
list[i] = data.getShort(offset + i);
|
||||
}
|
||||
}
|
||||
} else if (dt instanceof IntegerType) {
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (!data.isNullAt(offset + i)) {
|
||||
list[i] = data.getInt(offset + i);
|
||||
}
|
||||
}
|
||||
} else if (dt instanceof FloatType) {
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (!data.isNullAt(offset + i)) {
|
||||
list[i] = data.getFloat(offset + i);
|
||||
}
|
||||
}
|
||||
} else if (dt instanceof DoubleType) {
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (!data.isNullAt(offset + i)) {
|
||||
list[i] = data.getDouble(offset + i);
|
||||
}
|
||||
}
|
||||
} else if (dt instanceof LongType) {
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (!data.isNullAt(offset + i)) {
|
||||
list[i] = data.getLong(offset + i);
|
||||
}
|
||||
}
|
||||
} else if (dt instanceof DecimalType) {
|
||||
DecimalType decType = (DecimalType)dt;
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (!data.isNullAt(offset + i)) {
|
||||
list[i] = getDecimal(i, decType.precision(), decType.scale());
|
||||
}
|
||||
}
|
||||
} else if (dt instanceof StringType) {
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (!data.isNullAt(offset + i)) {
|
||||
list[i] = getUTF8String(i).toString();
|
||||
}
|
||||
}
|
||||
} else if (dt instanceof CalendarIntervalType) {
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (!data.isNullAt(offset + i)) {
|
||||
list[i] = getInterval(i);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw new UnsupportedOperationException("Type " + dt);
|
||||
return list;
|
||||
} catch(Exception e) {
|
||||
throw new RuntimeException("Could not get the array", e);
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -237,7 +181,42 @@ public abstract class ColumnVector implements AutoCloseable {
|
|||
|
||||
@Override
|
||||
public Object get(int ordinal, DataType dataType) {
|
||||
throw new UnsupportedOperationException();
|
||||
if (dataType instanceof BooleanType) {
|
||||
return getBoolean(ordinal);
|
||||
} else if (dataType instanceof ByteType) {
|
||||
return getByte(ordinal);
|
||||
} else if (dataType instanceof ShortType) {
|
||||
return getShort(ordinal);
|
||||
} else if (dataType instanceof IntegerType) {
|
||||
return getInt(ordinal);
|
||||
} else if (dataType instanceof LongType) {
|
||||
return getLong(ordinal);
|
||||
} else if (dataType instanceof FloatType) {
|
||||
return getFloat(ordinal);
|
||||
} else if (dataType instanceof DoubleType) {
|
||||
return getDouble(ordinal);
|
||||
} else if (dataType instanceof StringType) {
|
||||
return getUTF8String(ordinal);
|
||||
} else if (dataType instanceof BinaryType) {
|
||||
return getBinary(ordinal);
|
||||
} else if (dataType instanceof DecimalType) {
|
||||
DecimalType t = (DecimalType) dataType;
|
||||
return getDecimal(ordinal, t.precision(), t.scale());
|
||||
} else if (dataType instanceof DateType) {
|
||||
return getInt(ordinal);
|
||||
} else if (dataType instanceof TimestampType) {
|
||||
return getLong(ordinal);
|
||||
} else if (dataType instanceof ArrayType) {
|
||||
return getArray(ordinal);
|
||||
} else if (dataType instanceof StructType) {
|
||||
return getStruct(ordinal, ((StructType)dataType).fields().length);
|
||||
} else if (dataType instanceof MapType) {
|
||||
return getMap(ordinal);
|
||||
} else if (dataType instanceof CalendarIntervalType) {
|
||||
return getInterval(ordinal);
|
||||
} else {
|
||||
throw new UnsupportedOperationException("Datatype not supported " + dataType);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -0,0 +1,201 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution.vectorized
|
||||
|
||||
import org.scalatest.BeforeAndAfterEach
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.util.ArrayData
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
||||
/**
 * Exercises `ColumnVector.Array.get(ordinal, dataType)` for each data type covered below.
 *
 * Every test allocates an on-heap vector into `testVector`, fills it, wraps it in a
 * `ColumnVector.Array`, and reads the values back through the typed `get`.
 */
class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {

  // Vector under test; assigned by each test and released in afterEach().
  var testVector: WritableColumnVector = _

  /** Allocates an on-heap writable vector of the given capacity and type. */
  private def allocate(capacity: Int, dt: DataType): WritableColumnVector = {
    new OnHeapColumnVector(capacity, dt)
  }

  override def afterEach(): Unit = {
    try {
      // A test may fail before it allocates, leaving nothing (or an already closed
      // vector from an earlier test) to release.
      if (testVector != null) {
        testVector.close()
        testVector = null
      }
    } finally {
      // Always let the rest of the afterEach chain run, even if close() throws.
      super.afterEach()
    }
  }

  test("boolean") {
    testVector = allocate(10, BooleanType)
    (0 until 10).foreach { i =>
      testVector.appendBoolean(i % 2 == 0)
    }

    val array = new ColumnVector.Array(testVector)

    (0 until 10).foreach { i =>
      assert(array.get(i, BooleanType) === (i % 2 == 0))
    }
  }

  test("byte") {
    testVector = allocate(10, ByteType)
    (0 until 10).foreach { i =>
      testVector.appendByte(i.toByte)
    }

    val array = new ColumnVector.Array(testVector)

    (0 until 10).foreach { i =>
      assert(array.get(i, ByteType) === (i.toByte))
    }
  }

  test("short") {
    testVector = allocate(10, ShortType)
    (0 until 10).foreach { i =>
      testVector.appendShort(i.toShort)
    }

    val array = new ColumnVector.Array(testVector)

    (0 until 10).foreach { i =>
      assert(array.get(i, ShortType) === (i.toShort))
    }
  }

  test("int") {
    testVector = allocate(10, IntegerType)
    (0 until 10).foreach { i =>
      testVector.appendInt(i)
    }

    val array = new ColumnVector.Array(testVector)

    (0 until 10).foreach { i =>
      assert(array.get(i, IntegerType) === i)
    }
  }

  test("long") {
    testVector = allocate(10, LongType)
    (0 until 10).foreach { i =>
      testVector.appendLong(i)
    }

    val array = new ColumnVector.Array(testVector)

    (0 until 10).foreach { i =>
      assert(array.get(i, LongType) === i)
    }
  }

  test("float") {
    testVector = allocate(10, FloatType)
    (0 until 10).foreach { i =>
      testVector.appendFloat(i.toFloat)
    }

    val array = new ColumnVector.Array(testVector)

    (0 until 10).foreach { i =>
      assert(array.get(i, FloatType) === i.toFloat)
    }
  }

  test("double") {
    testVector = allocate(10, DoubleType)
    (0 until 10).foreach { i =>
      testVector.appendDouble(i.toDouble)
    }

    val array = new ColumnVector.Array(testVector)

    (0 until 10).foreach { i =>
      assert(array.get(i, DoubleType) === i.toDouble)
    }
  }

  test("string") {
    testVector = allocate(10, StringType)
    (0 until 10).foreach { i =>
      val utf8 = s"str$i".getBytes(java.nio.charset.StandardCharsets.UTF_8)
      testVector.appendByteArray(utf8, 0, utf8.length)
    }

    val array = new ColumnVector.Array(testVector)

    (0 until 10).foreach { i =>
      assert(array.get(i, StringType) === UTF8String.fromString(s"str$i"))
    }
  }

  test("binary") {
    testVector = allocate(10, BinaryType)
    (0 until 10).foreach { i =>
      val utf8 = s"str$i".getBytes(java.nio.charset.StandardCharsets.UTF_8)
      testVector.appendByteArray(utf8, 0, utf8.length)
    }

    val array = new ColumnVector.Array(testVector)

    (0 until 10).foreach { i =>
      val utf8 = s"str$i".getBytes(java.nio.charset.StandardCharsets.UTF_8)
      assert(array.get(i, BinaryType) === utf8)
    }
  }

  test("array") {
    val arrayType = ArrayType(IntegerType, true)
    testVector = allocate(10, arrayType)

    // Fill the shared element vector with 0..5; the arrays below are slices of it.
    val data = testVector.arrayData()
    (0 until 6).foreach { i =>
      data.putInt(i, i)
    }

    // Populate it with arrays [0], [1, 2], [], [3, 4, 5]
    testVector.putArray(0, 0, 1)
    testVector.putArray(1, 1, 2)
    testVector.putArray(2, 3, 0)
    testVector.putArray(3, 3, 3)

    val array = new ColumnVector.Array(testVector)

    assert(array.get(0, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(0))
    assert(array.get(1, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(1, 2))
    assert(array.get(2, arrayType).asInstanceOf[ArrayData].toIntArray() === Array.empty[Int])
    assert(array.get(3, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(3, 4, 5))
  }

  test("struct") {
    val schema = new StructType().add("int", IntegerType).add("double", DoubleType)
    testVector = allocate(10, schema)
    val c1 = testVector.getChildColumn(0)
    val c2 = testVector.getChildColumn(1)
    c1.putInt(0, 123)
    c2.putDouble(0, 3.45)
    c1.putInt(1, 456)
    c2.putDouble(1, 5.67)

    val array = new ColumnVector.Array(testVector)

    assert(array.get(0, schema).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 123)
    assert(array.get(0, schema).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 3.45)
    assert(array.get(1, schema).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 456)
    assert(array.get(1, schema).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 5.67)
  }
}
|
Loading…
Reference in a new issue