[SPARK-35640][SQL] Refactor Parquet vectorized reader to remove duplicated code paths

### What changes were proposed in this pull request?

1. Remove duplicated code in the form of the type-specific `readXXX` methods in `VectorizedRleValuesReader`. For instance:
```java
  public void readIntegers(
      int total,
      WritableColumnVector c,
      int rowId,
      int level,
      VectorizedValuesReader data) throws IOException {
    int left = total;
    while (left > 0) {
      if (this.currentCount == 0) this.readNextGroup();
      int n = Math.min(left, this.currentCount);
      switch (mode) {
        case RLE:
          if (currentValue == level) {
            data.readIntegers(n, c, rowId);
          } else {
            c.putNulls(rowId, n);
          }
          break;
        case PACKED:
          for (int i = 0; i < n; ++i) {
            if (currentBuffer[currentBufferIdx++] == level) {
              c.putInt(rowId + i, data.readInteger());
            } else {
              c.putNull(rowId + i);
            }
          }
          break;
      }
      rowId += n;
      left -= n;
      currentCount -= n;
    }
  }
```
and replace with:
```java
  public void readBatch(
      int total,
      int offset,
      WritableColumnVector values,
      int maxDefinitionLevel,
      VectorizedValuesReader valueReader,
      ParquetVectorUpdater updater) throws IOException {
    int left = total;
    while (left > 0) {
      if (this.currentCount == 0) this.readNextGroup();
      int n = Math.min(left, this.currentCount);
      switch (mode) {
        case RLE:
          if (currentValue == maxDefinitionLevel) {
            updater.updateBatch(n, offset, values, valueReader);
          } else {
            values.putNulls(offset, n);
          }
          break;
        case PACKED:
          for (int i = 0; i < n; ++i) {
            if (currentBuffer[currentBufferIdx++] == maxDefinitionLevel) {
              updater.update(offset + i, values, valueReader);
            } else {
              values.putNull(offset + i);
            }
          }
          break;
      }
      offset += n;
      left -= n;
      currentCount -= n;
    }
  }
```
where `ParquetVectorUpdater` is type-specific, with different implementations of `updateBatch` and `update` (an illustrative sketch of an integer updater is shown below). As part of this, the code paths handling timestamp types are also switched to the batch read API for decoding definition levels.
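
For illustration, here is a minimal sketch of what a type-specific updater could look like for plain 32-bit integers. The class name and method bodies are illustrative assumptions, not necessarily the exact code added by this PR:
```java
package org.apache.spark.sql.execution.datasources.parquet;

import org.apache.parquet.column.Dictionary;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;

// Hypothetical integer updater: the definition-level loop in `readBatch` stays
// type-agnostic, while the per-type work lives in these small methods.
class ExampleIntegerUpdater implements ParquetVectorUpdater {
  @Override
  public void updateBatch(
      int total,
      int offset,
      WritableColumnVector values,
      VectorizedValuesReader valuesReader) {
    // Bulk-copy `total` plain-encoded integers into the vector.
    valuesReader.readIntegers(total, values, offset);
  }

  @Override
  public void update(int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) {
    // Copy a single value; used on the PACKED path where nulls are interleaved.
    values.putInt(offset, valuesReader.readInteger());
  }

  @Override
  public void decodeSingleDictionaryId(
      int offset,
      WritableColumnVector values,
      WritableColumnVector dictionaryIds,
      Dictionary dictionary) {
    // Dictionary decoding for one non-null slot (see point 2 below).
    values.putInt(offset, dictionary.decodeToInt(dictionaryIds.getDictId(offset)));
  }
}
```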

2. Similar to the above, this removes the code duplication in `VectorizedColumnReader.decodeDictionaryIds`. The type-specific implementations now live under `ParquetVectorUpdater.decodeSingleDictionaryId`, so both decoding paths converge on the same updater (see the sketch below).
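
The following is a hedged sketch of the resulting control flow: with one updater per column type, both the dictionary and the non-dictionary paths reduce to a single call each. The helper class and variable names are illustrative; the real call sites live in `VectorizedColumnReader.readBatch` (see the diff below):
```java
package org.apache.spark.sql.execution.datasources.parquet;

import java.io.IOException;
import org.apache.parquet.column.Dictionary;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;

// Hypothetical helper showing the shape of the unified decoding logic.
final class ExampleBatchDecode {
  static void decode(
      boolean pageIsDictionaryEncoded,
      int num,
      int rowId,
      int maxDefLevel,
      WritableColumnVector column,
      WritableColumnVector dictionaryIds,
      Dictionary dictionary,
      VectorizedRleValuesReader defColumn,
      VectorizedValuesReader dataColumn,
      ParquetVectorUpdater updater) throws IOException {
    if (pageIsDictionaryEncoded) {
      // Null slots were already filled while reading the dictionary ids; the
      // default ParquetVectorUpdater.decodeDictionaryIds loops over the batch
      // and delegates each non-null slot to decodeSingleDictionaryId.
      updater.decodeDictionaryIds(num, rowId, column, dictionaryIds, dictionary);
    } else {
      // Definition levels and values are decoded together in one batch call.
      defColumn.readBatch(num, rowId, column, maxDefLevel, dataColumn, updater);
    }
  }
}
```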

### Why are the changes needed?

`VectorizedRleValuesReader` and `VectorizedColumnReader` are becoming increasingly hard to maintain, as any change that touches the above logic **needs to be replicated in 20+ places**. The issue becomes even more serious as we move on to implement column index support (for instance, see how the change [here](https://github.com/apache/spark/pull/32753/files#diff-a01e174e178366aadf07f64ee690d47d343b2ca416a4a2b2ea735887c22d5934R191) has to be replicated multiple times) and complex type support (in progress) for the vectorized path.

In addition, dictionary decoding (see `VectorizedColumnReader.decodeDictionaryIds`) and non-dictionary decoding are currently handled separately, so the same (very complicated) branching logic based on the input Spark and Parquet types has to be replicated in two places, which is another maintenance burden.

The original intent behind the duplication was performance. However, modern JIT compilers handle this well and aggressively inline virtual calls to eliminate the method invocation cost (see [this](https://shipilev.net/blog/2015/black-magic-method-dispatch/) and [this](http://insightfullogic.com/blog/2014/may/12/fast-and-megamorphic-what-influences-method-invoca/)). I've also run benchmarks using a modified `DataSourceReadBenchmark` and `DateTimeRebaseBenchmark`, and the results are almost exactly the same before and after the change. The results can be found [here](https://gist.github.com/sunchao/674afbf942ccc2370bdcfa33efb4471c), and [here's](https://github.com/sunchao/spark/tree/parquet-refactor) the source code.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing tests.

Closes #32777 from sunchao/SPARK-35640.

Authored-by: Chao Sun <sunchao@apple.com>
Signed-off-by: DB Tsai <d_tsai@apple.com>
Date: 2021-06-11 05:39:43 +00:00
parent 463daabd5a
commit e9ccf4a50c
5 changed files with 1144 additions and 999 deletions

ParquetVectorUpdater.java (new file)

@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.parquet;
import org.apache.parquet.column.Dictionary;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
public interface ParquetVectorUpdater {
/**
* Read a batch of `total` values from `valuesReader` into `values`, starting from `offset`.
*
* @param total total number of values to read
* @param offset starting offset in `values`
* @param values destination values vector
* @param valuesReader reader to read values from
*/
void updateBatch(
int total,
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader);
/**
* Read a single value from `valuesReader` into `values`, at `offset`.
*
* @param offset offset in `values` to put the new value
* @param values destination value vector
* @param valuesReader reader to read values from
*/
void update(int offset, WritableColumnVector values, VectorizedValuesReader valuesReader);
/**
* Process a batch of `total` values starting from `offset` in `values`, whose null slots
* should have already been filled, and fills the non-null slots using dictionary IDs from
* `dictionaryIds`, together with Parquet `dictionary`.
*
* @param total total number slots to process in `values`
* @param offset starting offset in `values`
* @param values destination value vector
* @param dictionaryIds vector storing the dictionary IDs
* @param dictionary Parquet dictionary used to decode a dictionary ID to its value
*/
default void decodeDictionaryIds(
int total,
int offset,
WritableColumnVector values,
WritableColumnVector dictionaryIds,
Dictionary dictionary) {
for (int i = offset; i < offset + total; i++) {
if (!values.isNullAt(i)) {
decodeSingleDictionaryId(i, values, dictionaryIds, dictionary);
}
}
}
/**
* Decode a single dictionary ID from `dictionaryIds` into `values` at `offset`, using
* `dictionary`.
*
* @param offset offset in `values` to put the decoded value
* @param values destination value vector
* @param dictionaryIds vector storing the dictionary IDs
* @param dictionary Parquet dictionary used to decode a dictionary ID to its value
*/
void decodeSingleDictionaryId(
int offset,
WritableColumnVector values,
WritableColumnVector dictionaryIds,
Dictionary dictionary);
}

VectorizedColumnReader.java

@ -18,10 +18,7 @@
package org.apache.spark.sql.execution.datasources.parquet;
import java.io.IOException;
import java.math.BigInteger;
import java.time.ZoneId;
import java.time.ZoneOffset;
import java.util.Arrays;
import org.apache.parquet.bytes.ByteBufferInputStream;
import org.apache.parquet.bytes.BytesInput;
@ -31,24 +28,14 @@ import org.apache.parquet.column.Dictionary;
import org.apache.parquet.column.Encoding;
import org.apache.parquet.column.page.*;
import org.apache.parquet.column.values.ValuesReader;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.parquet.schema.LogicalTypeAnnotation.IntLogicalTypeAnnotation;
import org.apache.parquet.schema.LogicalTypeAnnotation.DateLogicalTypeAnnotation;
import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation;
import org.apache.parquet.schema.LogicalTypeAnnotation.TimestampLogicalTypeAnnotation;
import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit;
import org.apache.parquet.schema.PrimitiveType;
import org.apache.spark.sql.catalyst.util.DateTimeUtils;
import org.apache.spark.sql.catalyst.util.RebaseDateTime;
import org.apache.spark.sql.execution.datasources.DataSourceUtils;
import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.DecimalType;
import static org.apache.parquet.column.ValuesType.REPETITION_LEVEL;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64;
@ -103,41 +90,15 @@ public class VectorizedColumnReader {
*/
private int pageValueCount;
/**
* Factory to get type-specific vector updater.
*/
private final ParquetVectorUpdaterFactory updaterFactory;
private final PageReader pageReader;
private final ColumnDescriptor descriptor;
private final LogicalTypeAnnotation logicalTypeAnnotation;
// The timezone conversion to apply to int96 timestamps. Null if no conversion.
private final ZoneId convertTz;
private static final ZoneId UTC = ZoneOffset.UTC;
private final String datetimeRebaseMode;
private final String int96RebaseMode;
private boolean isDecimalTypeMatched(DataType dt) {
DecimalType d = (DecimalType) dt;
LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
DecimalLogicalTypeAnnotation decimalType = (DecimalLogicalTypeAnnotation) typeAnnotation;
// It's OK if the required decimal precision is larger than or equal to the physical decimal
// precision in the Parquet metadata, as long as the decimal scale is the same.
return decimalType.getPrecision() <= d.precision() && decimalType.getScale() == d.scale();
}
return false;
}
private boolean canReadAsIntDecimal(DataType dt) {
if (!DecimalType.is32BitDecimalType(dt)) return false;
return isDecimalTypeMatched(dt);
}
private boolean canReadAsLongDecimal(DataType dt) {
if (!DecimalType.is64BitDecimalType(dt)) return false;
return isDecimalTypeMatched(dt);
}
private boolean canReadAsBinaryDecimal(DataType dt) {
if (!DecimalType.isByteArrayDecimalType(dt)) return false;
return isDecimalTypeMatched(dt);
}
public VectorizedColumnReader(
ColumnDescriptor descriptor,
@ -148,9 +109,10 @@ public class VectorizedColumnReader {
String int96RebaseMode) throws IOException {
this.descriptor = descriptor;
this.pageReader = pageReader;
this.convertTz = convertTz;
this.logicalTypeAnnotation = logicalTypeAnnotation;
this.maxDefLevel = descriptor.getMaxDefinitionLevel();
this.updaterFactory = new ParquetVectorUpdaterFactory(
logicalTypeAnnotation, convertTz, datetimeRebaseMode, int96RebaseMode);
DictionaryPage dictionaryPage = pageReader.readDictionaryPage();
if (dictionaryPage != null) {
@ -173,7 +135,6 @@ public class VectorizedColumnReader {
this.datetimeRebaseMode = datetimeRebaseMode;
assert "LEGACY".equals(int96RebaseMode) || "EXCEPTION".equals(int96RebaseMode) ||
"CORRECTED".equals(int96RebaseMode);
this.int96RebaseMode = int96RebaseMode;
}
private boolean isLazyDecodingSupported(PrimitiveType.PrimitiveTypeName typeName) {
@ -184,10 +145,10 @@ public class VectorizedColumnReader {
"CORRECTED".equals(datetimeRebaseMode);
break;
case INT64:
if (isTimestampTypeMatched(TimeUnit.MICROS)) {
if (updaterFactory.isTimestampTypeMatched(TimeUnit.MICROS)) {
isSupported = "CORRECTED".equals(datetimeRebaseMode);
} else {
isSupported = !isTimestampTypeMatched(TimeUnit.MILLIS);
isSupported = !updaterFactory.isTimestampTypeMatched(TimeUnit.MILLIS);
}
break;
case FLOAT:
@ -199,47 +160,14 @@ public class VectorizedColumnReader {
return isSupported;
}
static int rebaseDays(int julianDays, final boolean failIfRebase) {
if (failIfRebase) {
if (julianDays < RebaseDateTime.lastSwitchJulianDay()) {
throw DataSourceUtils.newRebaseExceptionInRead("Parquet");
} else {
return julianDays;
}
} else {
return RebaseDateTime.rebaseJulianToGregorianDays(julianDays);
}
}
private static long rebaseTimestamp(
long julianMicros,
final boolean failIfRebase,
final String format) {
if (failIfRebase) {
if (julianMicros < RebaseDateTime.lastSwitchJulianTs()) {
throw DataSourceUtils.newRebaseExceptionInRead(format);
} else {
return julianMicros;
}
} else {
return RebaseDateTime.rebaseJulianToGregorianMicros(julianMicros);
}
}
static long rebaseMicros(long julianMicros, final boolean failIfRebase) {
return rebaseTimestamp(julianMicros, failIfRebase, "Parquet");
}
static long rebaseInt96(long julianMicros, final boolean failIfRebase) {
return rebaseTimestamp(julianMicros, failIfRebase, "Parquet INT96");
}
/**
* Reads `total` values from this columnReader into column.
*/
void readBatch(int total, WritableColumnVector column) throws IOException {
int rowId = 0;
WritableColumnVector dictionaryIds = null;
ParquetVectorUpdater updater = updaterFactory.getUpdater(descriptor, column.dataType());
if (dictionary != null) {
// SPARK-16334: We only maintain a single dictionary per row batch, so that it can be used to
// decode all previous dictionary encoded pages if we ever encounter a non-dictionary encoded
@ -255,7 +183,7 @@ public class VectorizedColumnReader {
}
int num = Math.min(total, leftInPage);
PrimitiveType.PrimitiveTypeName typeName =
descriptor.getPrimitiveType().getPrimitiveTypeName();
descriptor.getPrimitiveType().getPrimitiveTypeName();
if (isCurrentPageDictionaryEncoded) {
// Read and decode dictionary ids.
defColumn.readIntegers(
@ -279,53 +207,26 @@ public class VectorizedColumnReader {
// We require a long value, but we need to use dictionary to decode the original
// signed int first
boolean isUnsignedInt32 = isUnsignedIntTypeMatched(32);
boolean isUnsignedInt32 = updaterFactory.isUnsignedIntTypeMatched(32);
// We require a decimal value, but we need to use dictionary to decode the original
// signed long first
boolean isUnsignedInt64 = isUnsignedIntTypeMatched(64);
boolean isUnsignedInt64 = updaterFactory.isUnsignedIntTypeMatched(64);
boolean needTransform = castLongToInt || isUnsignedInt32 || isUnsignedInt64;
column.setDictionary(new ParquetDictionary(dictionary, needTransform));
} else {
decodeDictionaryIds(rowId, num, column, dictionaryIds);
updater.decodeDictionaryIds(num, rowId, column, dictionaryIds, dictionary);
}
} else {
if (column.hasDictionary() && rowId != 0) {
// This batch already has dictionary encoded values but this new page is not. The batch
// does not support a mix of dictionary and not so we will decode the dictionary.
decodeDictionaryIds(0, rowId, column, column.getDictionaryIds());
updater.decodeDictionaryIds(rowId, 0, column, dictionaryIds, dictionary);
}
column.setDictionary(null);
switch (typeName) {
case BOOLEAN:
readBooleanBatch(rowId, num, column);
break;
case INT32:
readIntBatch(rowId, num, column);
break;
case INT64:
readLongBatch(rowId, num, column);
break;
case INT96:
readBinaryBatch(rowId, num, column);
break;
case FLOAT:
readFloatBatch(rowId, num, column);
break;
case DOUBLE:
readDoubleBatch(rowId, num, column);
break;
case BINARY:
readBinaryBatch(rowId, num, column);
break;
case FIXED_LEN_BYTE_ARRAY:
readFixedLenByteArrayBatch(
rowId, num, column, descriptor.getPrimitiveType().getTypeLength());
break;
default:
throw new IOException("Unsupported type: " + typeName);
}
VectorizedValuesReader valuesReader = (VectorizedValuesReader) dataColumn;
defColumn.readBatch(num, rowId, column, maxDefLevel, valuesReader, updater);
}
valuesRead += num;
@ -334,457 +235,6 @@ public class VectorizedColumnReader {
}
}
private boolean shouldConvertTimestamps() {
return convertTz != null && !convertTz.equals(UTC);
}
/**
* Helper function to construct exception for parquet schema mismatch.
*/
private SchemaColumnConvertNotSupportedException constructConvertNotSupportedException(
ColumnDescriptor descriptor,
WritableColumnVector column) {
return new SchemaColumnConvertNotSupportedException(
Arrays.toString(descriptor.getPath()),
descriptor.getPrimitiveType().getPrimitiveTypeName().toString(),
column.dataType().catalogString());
}
/**
* Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`.
*/
private void decodeDictionaryIds(
int rowId,
int num,
WritableColumnVector column,
WritableColumnVector dictionaryIds) {
switch (descriptor.getPrimitiveType().getPrimitiveTypeName()) {
case INT32:
if (column.dataType() == DataTypes.IntegerType ||
canReadAsIntDecimal(column.dataType()) ||
(column.dataType() == DataTypes.DateType && "CORRECTED".equals(datetimeRebaseMode))) {
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
column.putInt(i, dictionary.decodeToInt(dictionaryIds.getDictId(i)));
}
}
} else if (column.dataType() == DataTypes.LongType) {
// In `ParquetToSparkSchemaConverter`, we map parquet UINT32 to our LongType.
// For unsigned int32, it stores as dictionary encoded signed int32 in Parquet
// whenever dictionary is available.
// Here we eagerly decode it to the original signed int value then convert to
// long(unit32).
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
column.putLong(i,
Integer.toUnsignedLong(dictionary.decodeToInt(dictionaryIds.getDictId(i))));
}
}
} else if (column.dataType() == DataTypes.ByteType) {
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
column.putByte(i, (byte) dictionary.decodeToInt(dictionaryIds.getDictId(i)));
}
}
} else if (column.dataType() == DataTypes.ShortType) {
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
column.putShort(i, (short) dictionary.decodeToInt(dictionaryIds.getDictId(i)));
}
}
} else if (column.dataType() == DataTypes.DateType) {
final boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode);
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
int julianDays = dictionary.decodeToInt(dictionaryIds.getDictId(i));
column.putInt(i, rebaseDays(julianDays, failIfRebase));
}
}
} else {
throw constructConvertNotSupportedException(descriptor, column);
}
break;
case INT64:
if (column.dataType() == DataTypes.LongType ||
canReadAsLongDecimal(column.dataType()) ||
(isTimestampTypeMatched(TimeUnit.MICROS) &&
"CORRECTED".equals(datetimeRebaseMode))) {
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
column.putLong(i, dictionary.decodeToLong(dictionaryIds.getDictId(i)));
}
}
} else if (isUnsignedIntTypeMatched(64)) {
// In `ParquetToSparkSchemaConverter`, we map parquet UINT64 to our Decimal(20, 0).
// For unsigned int64, it stores as dictionary encoded signed int64 in Parquet
// whenever dictionary is available.
// Here we eagerly decode it to the original signed int64(long) value then convert to
// BigInteger.
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
long signed = dictionary.decodeToLong(dictionaryIds.getDictId(i));
byte[] unsigned = new BigInteger(Long.toUnsignedString(signed)).toByteArray();
column.putByteArray(i, unsigned);
}
}
} else if (isTimestampTypeMatched(TimeUnit.MILLIS)) {
if ("CORRECTED".equals(datetimeRebaseMode)) {
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
long gregorianMillis = dictionary.decodeToLong(dictionaryIds.getDictId(i));
column.putLong(i, DateTimeUtils.millisToMicros(gregorianMillis));
}
}
} else {
final boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode);
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
long julianMillis = dictionary.decodeToLong(dictionaryIds.getDictId(i));
long julianMicros = DateTimeUtils.millisToMicros(julianMillis);
column.putLong(i, rebaseMicros(julianMicros, failIfRebase));
}
}
}
} else if (isTimestampTypeMatched(TimeUnit.MICROS)) {
final boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode);
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
long julianMicros = dictionary.decodeToLong(dictionaryIds.getDictId(i));
column.putLong(i, rebaseMicros(julianMicros, failIfRebase));
}
}
} else {
throw constructConvertNotSupportedException(descriptor, column);
}
break;
case FLOAT:
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
column.putFloat(i, dictionary.decodeToFloat(dictionaryIds.getDictId(i)));
}
}
break;
case DOUBLE:
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
column.putDouble(i, dictionary.decodeToDouble(dictionaryIds.getDictId(i)));
}
}
break;
case INT96:
if (column.dataType() == DataTypes.TimestampType) {
final boolean failIfRebase = "EXCEPTION".equals(int96RebaseMode);
if (!shouldConvertTimestamps()) {
if ("CORRECTED".equals(int96RebaseMode)) {
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i));
column.putLong(i, ParquetRowConverter.binaryToSQLTimestamp(v));
}
}
} else {
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i));
long julianMicros = ParquetRowConverter.binaryToSQLTimestamp(v);
long gregorianMicros = rebaseInt96(julianMicros, failIfRebase);
column.putLong(i, gregorianMicros);
}
}
}
} else {
if ("CORRECTED".equals(int96RebaseMode)) {
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i));
long gregorianMicros = ParquetRowConverter.binaryToSQLTimestamp(v);
long adjTime = DateTimeUtils.convertTz(gregorianMicros, convertTz, UTC);
column.putLong(i, adjTime);
}
}
} else {
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i));
long julianMicros = ParquetRowConverter.binaryToSQLTimestamp(v);
long gregorianMicros = rebaseInt96(julianMicros, failIfRebase);
long adjTime = DateTimeUtils.convertTz(gregorianMicros, convertTz, UTC);
column.putLong(i, adjTime);
}
}
}
}
} else {
throw constructConvertNotSupportedException(descriptor, column);
}
break;
case BINARY:
// TODO: this is incredibly inefficient as it blows up the dictionary right here. We
// need to do this better. We should probably add the dictionary data to the ColumnVector
// and reuse it across batches. This should mean adding a ByteArray would just update
// the length and offset.
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i));
column.putByteArray(i, v.getBytes());
}
}
break;
case FIXED_LEN_BYTE_ARRAY:
// DecimalType written in the legacy mode
if (canReadAsIntDecimal(column.dataType())) {
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i));
column.putInt(i, (int) ParquetRowConverter.binaryToUnscaledLong(v));
}
}
} else if (canReadAsLongDecimal(column.dataType())) {
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i));
column.putLong(i, ParquetRowConverter.binaryToUnscaledLong(v));
}
}
} else if (canReadAsBinaryDecimal(column.dataType())) {
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i));
column.putByteArray(i, v.getBytes());
}
}
} else {
throw constructConvertNotSupportedException(descriptor, column);
}
break;
default:
throw new UnsupportedOperationException(
"Unsupported type: " + descriptor.getPrimitiveType().getPrimitiveTypeName());
}
}
/**
* For all the read*Batch functions, reads `num` values from this columnReader into column. It
* is guaranteed that num is smaller than the number of values left in the current page.
*/
private void readBooleanBatch(int rowId, int num, WritableColumnVector column)
throws IOException {
if (column.dataType() != DataTypes.BooleanType) {
throw constructConvertNotSupportedException(descriptor, column);
}
defColumn.readBooleans(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
}
private void readIntBatch(int rowId, int num, WritableColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
if (column.dataType() == DataTypes.IntegerType ||
canReadAsIntDecimal(column.dataType())) {
defColumn.readIntegers(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
} else if (column.dataType() == DataTypes.LongType) {
// In `ParquetToSparkSchemaConverter`, we map parquet UINT32 to our LongType.
// For unsigned int32, it stores as plain signed int32 in Parquet when dictionary fallbacks.
// We read them as long values.
defColumn.readUnsignedIntegers(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
} else if (column.dataType() == DataTypes.ByteType) {
defColumn.readBytes(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
} else if (column.dataType() == DataTypes.ShortType) {
defColumn.readShorts(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
} else if (column.dataType() == DataTypes.DateType ) {
if ("CORRECTED".equals(datetimeRebaseMode)) {
defColumn.readIntegers(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
} else {
boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode);
defColumn.readIntegersWithRebase(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn, failIfRebase);
}
} else {
throw constructConvertNotSupportedException(descriptor, column);
}
}
private void readLongBatch(int rowId, int num, WritableColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
if (column.dataType() == DataTypes.LongType ||
canReadAsLongDecimal(column.dataType())) {
defColumn.readLongs(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn,
DecimalType.is32BitDecimalType(column.dataType()));
} else if (isUnsignedIntTypeMatched(64)) {
// In `ParquetToSparkSchemaConverter`, we map parquet UINT64 to our Decimal(20, 0).
// For unsigned int64, it stores as plain signed int64 in Parquet when dictionary fallbacks.
// We read them as decimal values.
defColumn.readUnsignedLongs(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
} else if (isTimestampTypeMatched(TimeUnit.MICROS)) {
if ("CORRECTED".equals(datetimeRebaseMode)) {
defColumn.readLongs(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn, false);
} else {
boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode);
defColumn.readLongsWithRebase(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn, failIfRebase);
}
} else if (isTimestampTypeMatched(TimeUnit.MILLIS)) {
if ("CORRECTED".equals(datetimeRebaseMode)) {
for (int i = 0; i < num; i++) {
if (defColumn.readInteger() == maxDefLevel) {
column.putLong(rowId + i, DateTimeUtils.millisToMicros(dataColumn.readLong()));
} else {
column.putNull(rowId + i);
}
}
} else {
final boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode);
for (int i = 0; i < num; i++) {
if (defColumn.readInteger() == maxDefLevel) {
long julianMicros = DateTimeUtils.millisToMicros(dataColumn.readLong());
column.putLong(rowId + i, rebaseMicros(julianMicros, failIfRebase));
} else {
column.putNull(rowId + i);
}
}
}
} else {
throw constructConvertNotSupportedException(descriptor, column);
}
}
private void readFloatBatch(int rowId, int num, WritableColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
// TODO: support implicit cast to double?
if (column.dataType() == DataTypes.FloatType) {
defColumn.readFloats(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
} else {
throw constructConvertNotSupportedException(descriptor, column);
}
}
private void readDoubleBatch(int rowId, int num, WritableColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
if (column.dataType() == DataTypes.DoubleType) {
defColumn.readDoubles(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
} else {
throw constructConvertNotSupportedException(descriptor, column);
}
}
private void readBinaryBatch(int rowId, int num, WritableColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
if (column.dataType() == DataTypes.StringType || column.dataType() == DataTypes.BinaryType
|| canReadAsBinaryDecimal(column.dataType())) {
defColumn.readBinarys(num, column, rowId, maxDefLevel, data);
} else if (column.dataType() == DataTypes.TimestampType) {
final boolean failIfRebase = "EXCEPTION".equals(int96RebaseMode);
if (!shouldConvertTimestamps()) {
if ("CORRECTED".equals(int96RebaseMode)) {
for (int i = 0; i < num; i++) {
if (defColumn.readInteger() == maxDefLevel) {
// Read 12 bytes for INT96
long gregorianMicros = ParquetRowConverter.binaryToSQLTimestamp(data.readBinary(12));
column.putLong(rowId + i, gregorianMicros);
} else {
column.putNull(rowId + i);
}
}
} else {
for (int i = 0; i < num; i++) {
if (defColumn.readInteger() == maxDefLevel) {
// Read 12 bytes for INT96
long julianMicros = ParquetRowConverter.binaryToSQLTimestamp(data.readBinary(12));
long gregorianMicros = rebaseInt96(julianMicros, failIfRebase);
column.putLong(rowId + i, gregorianMicros);
} else {
column.putNull(rowId + i);
}
}
}
} else {
if ("CORRECTED".equals(int96RebaseMode)) {
for (int i = 0; i < num; i++) {
if (defColumn.readInteger() == maxDefLevel) {
// Read 12 bytes for INT96
long gregorianMicros = ParquetRowConverter.binaryToSQLTimestamp(data.readBinary(12));
long adjTime = DateTimeUtils.convertTz(gregorianMicros, convertTz, UTC);
column.putLong(rowId + i, adjTime);
} else {
column.putNull(rowId + i);
}
}
} else {
for (int i = 0; i < num; i++) {
if (defColumn.readInteger() == maxDefLevel) {
// Read 12 bytes for INT96
long julianMicros = ParquetRowConverter.binaryToSQLTimestamp(data.readBinary(12));
long gregorianMicros = rebaseInt96(julianMicros, failIfRebase);
long adjTime = DateTimeUtils.convertTz(gregorianMicros, convertTz, UTC);
column.putLong(rowId + i, adjTime);
} else {
column.putNull(rowId + i);
}
}
}
}
} else {
throw constructConvertNotSupportedException(descriptor, column);
}
}
private void readFixedLenByteArrayBatch(
int rowId,
int num,
WritableColumnVector column,
int arrayLen) {
VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
if (canReadAsIntDecimal(column.dataType())) {
for (int i = 0; i < num; i++) {
if (defColumn.readInteger() == maxDefLevel) {
column.putInt(rowId + i,
(int) ParquetRowConverter.binaryToUnscaledLong(data.readBinary(arrayLen)));
} else {
column.putNull(rowId + i);
}
}
} else if (canReadAsLongDecimal(column.dataType())) {
for (int i = 0; i < num; i++) {
if (defColumn.readInteger() == maxDefLevel) {
column.putLong(rowId + i,
ParquetRowConverter.binaryToUnscaledLong(data.readBinary(arrayLen)));
} else {
column.putNull(rowId + i);
}
}
} else if (canReadAsBinaryDecimal(column.dataType())) {
for (int i = 0; i < num; i++) {
if (defColumn.readInteger() == maxDefLevel) {
column.putByteArray(rowId + i, data.readBinary(arrayLen).getBytes());
} else {
column.putNull(rowId + i);
}
}
} else {
throw constructConvertNotSupportedException(descriptor, column);
}
}
private void readPage() {
DataPage page = pageReader.readPage();
// TODO: Why is this a visitor?
@ -881,15 +331,4 @@ public class VectorizedColumnReader {
throw new IOException("could not read page " + page + " in col " + descriptor, e);
}
}
private boolean isTimestampTypeMatched(TimeUnit unit) {
return logicalTypeAnnotation instanceof TimestampLogicalTypeAnnotation &&
((TimestampLogicalTypeAnnotation) logicalTypeAnnotation).getUnit() == unit;
}
private boolean isUnsignedIntTypeMatched(int bitWidth) {
return logicalTypeAnnotation instanceof IntLogicalTypeAnnotation &&
!((IntLogicalTypeAnnotation) logicalTypeAnnotation).isSigned() &&
((IntLogicalTypeAnnotation) logicalTypeAnnotation).getBitWidth() == bitWidth;
}
}

VectorizedRleValuesReader.java

@ -18,7 +18,6 @@
package org.apache.spark.sql.execution.datasources.parquet;
import java.io.IOException;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import org.apache.parquet.Preconditions;
@ -170,438 +169,36 @@ public final class VectorizedRleValuesReader extends ValuesReader
* c[rowId] = null;
* }
*/
public void readIntegers(
public void readBatch(
int total,
WritableColumnVector c,
int rowId,
int level,
VectorizedValuesReader data) throws IOException {
int offset,
WritableColumnVector values,
int maxDefinitionLevel,
VectorizedValuesReader valueReader,
ParquetVectorUpdater updater) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
int n = Math.min(left, this.currentCount);
switch (mode) {
case RLE:
if (currentValue == level) {
data.readIntegers(n, c, rowId);
if (currentValue == maxDefinitionLevel) {
updater.updateBatch(n, offset, values, valueReader);
} else {
c.putNulls(rowId, n);
values.putNulls(offset, n);
}
break;
case PACKED:
for (int i = 0; i < n; ++i) {
if (currentBuffer[currentBufferIdx++] == level) {
c.putInt(rowId + i, data.readInteger());
if (currentBuffer[currentBufferIdx++] == maxDefinitionLevel) {
updater.update(offset + i, values, valueReader);
} else {
c.putNull(rowId + i);
values.putNull(offset + i);
}
}
break;
}
rowId += n;
left -= n;
currentCount -= n;
}
}
// A fork of `readIntegers`, reading the signed integers as unsigned in long type
public void readUnsignedIntegers(
int total,
WritableColumnVector c,
int rowId,
int level,
VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
int n = Math.min(left, this.currentCount);
switch (mode) {
case RLE:
if (currentValue == level) {
data.readUnsignedIntegers(n, c, rowId);
} else {
c.putNulls(rowId, n);
}
break;
case PACKED:
for (int i = 0; i < n; ++i) {
if (currentBuffer[currentBufferIdx++] == level) {
c.putLong(rowId + i, Integer.toUnsignedLong(data.readInteger()));
} else {
c.putNull(rowId + i);
}
}
break;
}
rowId += n;
left -= n;
currentCount -= n;
}
}
// A fork of `readIntegers`, which rebases the date int value (days) before filling
// the Spark column vector.
public void readIntegersWithRebase(
int total,
WritableColumnVector c,
int rowId,
int level,
VectorizedValuesReader data,
final boolean failIfRebase) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
int n = Math.min(left, this.currentCount);
switch (mode) {
case RLE:
if (currentValue == level) {
data.readIntegersWithRebase(n, c, rowId, failIfRebase);
} else {
c.putNulls(rowId, n);
}
break;
case PACKED:
for (int i = 0; i < n; ++i) {
if (currentBuffer[currentBufferIdx++] == level) {
int julianDays = data.readInteger();
c.putInt(rowId + i, VectorizedColumnReader.rebaseDays(julianDays, failIfRebase));
} else {
c.putNull(rowId + i);
}
}
break;
}
rowId += n;
left -= n;
currentCount -= n;
}
}
// TODO: can this code duplication be removed without a perf penalty?
public void readBooleans(
int total,
WritableColumnVector c,
int rowId,
int level,
VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
int n = Math.min(left, this.currentCount);
switch (mode) {
case RLE:
if (currentValue == level) {
data.readBooleans(n, c, rowId);
} else {
c.putNulls(rowId, n);
}
break;
case PACKED:
for (int i = 0; i < n; ++i) {
if (currentBuffer[currentBufferIdx++] == level) {
c.putBoolean(rowId + i, data.readBoolean());
} else {
c.putNull(rowId + i);
}
}
break;
}
rowId += n;
left -= n;
currentCount -= n;
}
}
public void readBytes(
int total,
WritableColumnVector c,
int rowId,
int level,
VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
int n = Math.min(left, this.currentCount);
switch (mode) {
case RLE:
if (currentValue == level) {
data.readBytes(n, c, rowId);
} else {
c.putNulls(rowId, n);
}
break;
case PACKED:
for (int i = 0; i < n; ++i) {
if (currentBuffer[currentBufferIdx++] == level) {
c.putByte(rowId + i, data.readByte());
} else {
c.putNull(rowId + i);
}
}
break;
}
rowId += n;
left -= n;
currentCount -= n;
}
}
public void readShorts(
int total,
WritableColumnVector c,
int rowId,
int level,
VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
int n = Math.min(left, this.currentCount);
switch (mode) {
case RLE:
if (currentValue == level) {
data.readShorts(n, c, rowId);
} else {
c.putNulls(rowId, n);
}
break;
case PACKED:
for (int i = 0; i < n; ++i) {
if (currentBuffer[currentBufferIdx++] == level) {
c.putShort(rowId + i, data.readShort());
} else {
c.putNull(rowId + i);
}
}
break;
}
rowId += n;
left -= n;
currentCount -= n;
}
}
public void readLongs(
int total,
WritableColumnVector c,
int rowId,
int level,
VectorizedValuesReader data,
boolean downCastLongToInt) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
int n = Math.min(left, this.currentCount);
switch (mode) {
case RLE:
if (currentValue == level) {
if (downCastLongToInt) {
for (int i = 0; i < n; ++i) {
c.putInt(rowId + i, (int) data.readLong());
}
} else {
data.readLongs(n, c, rowId);
}
} else {
c.putNulls(rowId, n);
}
break;
case PACKED:
// code repeated for performance
if (downCastLongToInt) {
for (int i = 0; i < n; ++i) {
if (currentBuffer[currentBufferIdx++] == level) {
c.putInt(rowId + i, (int) data.readLong());
} else {
c.putNull(rowId + i);
}
}
} else {
for (int i = 0; i < n; ++i) {
if (currentBuffer[currentBufferIdx++] == level) {
c.putLong(rowId + i, data.readLong());
} else {
c.putNull(rowId + i);
}
}
}
break;
}
rowId += n;
left -= n;
currentCount -= n;
}
}
public void readUnsignedLongs(
int total,
WritableColumnVector c,
int rowId,
int level,
VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
int n = Math.min(left, this.currentCount);
switch (mode) {
case RLE:
if (currentValue == level) {
data.readUnsignedLongs(n, c, rowId);
} else {
c.putNulls(rowId, n);
}
break;
case PACKED:
for (int i = 0; i < n; ++i) {
if (currentBuffer[currentBufferIdx++] == level) {
byte[] bytes = new BigInteger(Long.toUnsignedString(data.readLong())).toByteArray();
c.putByteArray(rowId + i, bytes);
} else {
c.putNull(rowId + i);
}
}
break;
}
rowId += n;
left -= n;
currentCount -= n;
}
}
// A fork of `readLongs`, which rebases the timestamp long value (microseconds) before filling
// the Spark column vector.
public void readLongsWithRebase(
int total,
WritableColumnVector c,
int rowId,
int level,
VectorizedValuesReader data,
final boolean failIfRebase) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
int n = Math.min(left, this.currentCount);
switch (mode) {
case RLE:
if (currentValue == level) {
data.readLongsWithRebase(n, c, rowId, failIfRebase);
} else {
c.putNulls(rowId, n);
}
break;
case PACKED:
for (int i = 0; i < n; ++i) {
if (currentBuffer[currentBufferIdx++] == level) {
long julianMicros = data.readLong();
c.putLong(rowId + i, VectorizedColumnReader.rebaseMicros(julianMicros, failIfRebase));
} else {
c.putNull(rowId + i);
}
}
break;
}
rowId += n;
left -= n;
currentCount -= n;
}
}
public void readFloats(
int total,
WritableColumnVector c,
int rowId,
int level,
VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
int n = Math.min(left, this.currentCount);
switch (mode) {
case RLE:
if (currentValue == level) {
data.readFloats(n, c, rowId);
} else {
c.putNulls(rowId, n);
}
break;
case PACKED:
for (int i = 0; i < n; ++i) {
if (currentBuffer[currentBufferIdx++] == level) {
c.putFloat(rowId + i, data.readFloat());
} else {
c.putNull(rowId + i);
}
}
break;
}
rowId += n;
left -= n;
currentCount -= n;
}
}
public void readDoubles(
int total,
WritableColumnVector c,
int rowId,
int level,
VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
int n = Math.min(left, this.currentCount);
switch (mode) {
case RLE:
if (currentValue == level) {
data.readDoubles(n, c, rowId);
} else {
c.putNulls(rowId, n);
}
break;
case PACKED:
for (int i = 0; i < n; ++i) {
if (currentBuffer[currentBufferIdx++] == level) {
c.putDouble(rowId + i, data.readDouble());
} else {
c.putNull(rowId + i);
}
}
break;
}
rowId += n;
left -= n;
currentCount -= n;
}
}
public void readBinarys(
int total,
WritableColumnVector c,
int rowId,
int level,
VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
int n = Math.min(left, this.currentCount);
switch (mode) {
case RLE:
if (currentValue == level) {
data.readBinary(n, c, rowId);
} else {
c.putNulls(rowId, n);
}
break;
case PACKED:
for (int i = 0; i < n; ++i) {
if (currentBuffer[currentBufferIdx++] == level) {
data.readBinary(1, c, rowId + i);
} else {
c.putNull(rowId + i);
}
}
break;
}
rowId += n;
offset += n;
left -= n;
currentCount -= n;
}
@ -796,10 +393,6 @@ public final class VectorizedRleValuesReader extends ValuesReader
throw new RuntimeException("Unreachable");
}
private int ceil8(int value) {
return (value + 7) / 8;
}
/**
* Reads the next group.
*/

ParquetIOSuite.scala

@ -496,6 +496,28 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
}
}
test("SPARK-35640: read binary as timestamp should throw schema incompatible error") {
val data = (1 to 4).map(i => Tuple1(i.toString))
val readSchema = StructType(Seq(StructField("_1", DataTypes.TimestampType)))
withParquetFile(data) { path =>
val errMsg = intercept[Exception](spark.read.schema(readSchema).parquet(path).collect())
.getMessage
assert(errMsg.contains("Parquet column cannot be converted in file"))
}
}
test("SPARK-35640: int as long should throw schema incompatible error") {
val data = (1 to 4).map(i => Tuple1(i))
val readSchema = StructType(Seq(StructField("_1", DataTypes.LongType)))
withParquetFile(data) { path =>
val errMsg = intercept[Exception](spark.read.schema(readSchema).parquet(path).collect())
.getMessage
assert(errMsg.contains("Parquet column cannot be converted in file"))
}
}
test("write metadata") {
val hadoopConf = spark.sessionState.newHadoopConf()
withTempPath { file =>