[SPARK-34786][SQL] Read Parquet unsigned int64 logical type that stored as signed int64 physical type to decimal(20, 0)

### What changes were proposed in this pull request?

A companion PR for SPARK-34817, where we handled the unsigned int logical types of 32 bits or less (UINT_8/16/32). In this PR, we map the Parquet unsigned int64 logical type to decimal(20, 0) for better compatibility.
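As a quick sanity check (illustration only, not part of the change), decimal(20, 0) is the narrowest decimal type that covers the full uint64 range, because the maximum uint64 value has 20 decimal digits:

```scala
// Sketch: 2^64 - 1 (all 64 bits set, i.e. -1L reinterpreted as unsigned) has 20 digits.
val maxUint64 = java.lang.Long.toUnsignedString(-1L)
assert(maxUint64 == "18446744073709551615")
assert(maxUint64.length == 20)
```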

### Why are the changes needed?

Spark has no unsigned integer types, but it should be able to read existing Parquet files written by other systems that do support unsigned types, for better compatibility.

### Does this PR introduce _any_ user-facing change?

Yes, Spark can now read Parquet uint64 columns, as decimal(20, 0).
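For illustration, a hedged sketch of the new user-facing behavior (the path and column name below are hypothetical):

```scala
// Hypothetical file written by another system with a uint64 column `c`.
// Before this change the read failed with an AnalysisException
// ("Parquet type not supported: INT64 (UINT_64)"); now the column maps to decimal(20, 0).
val df = spark.read.parquet("/path/to/uint64.parquet")
df.printSchema()
// root
//  |-- c: decimal(20,0) (nullable = true)
```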

### How was this patch tested?

New unit tests in `ParquetIOSuite`.

Closes #31960 from yaooqinn/SPARK-34786-2.

Authored-by: Kent Yao <yao@apache.org>
Signed-off-by: Kent Yao <yao@apache.org>
Commit 820b465886 (parent 5ffc3897e0), authored by Kent Yao, 2021-03-26 09:54:19 +08:00
9 changed files with 137 additions and 13 deletions


@@ -1268,10 +1268,6 @@ private[spark] object QueryCompilationErrors {
s"createTableColumnTypes option column $col not found in schema ${schema.catalogString}")
}
def parquetTypeUnsupportedError(parquetType: String): Throwable = {
new AnalysisException(s"Parquet type not supported: $parquetType")
}
def parquetTypeUnsupportedYetError(parquetType: String): Throwable = {
new AnalysisException(s"Parquet type not yet supported: $parquetType")
}


@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources.parquet;
import org.apache.spark.sql.execution.vectorized.Dictionary;
import java.math.BigInteger;
public final class ParquetDictionary implements Dictionary {
private org.apache.parquet.column.Dictionary dictionary;
private boolean needTransform = false;
@@ -61,6 +63,14 @@ public final class ParquetDictionary implements Dictionary {
@Override
public byte[] decodeToBinary(int id) {
return dictionary.decodeToBinary(id).getBytes();
if (needTransform) {
// For unsigned int64, it stores as dictionary encoded signed int64 in Parquet
// whenever dictionary is available.
// Here we lazily decode it to the original signed long value then convert to decimal(20, 0).
long signed = dictionary.decodeToLong(id);
return new BigInteger(Long.toUnsignedString(signed)).toByteArray();
} else {
return dictionary.decodeToBinary(id).getBytes();
}
}
}

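For reference, a minimal Scala sketch (not part of the patch) of the reinterpretation `decodeToBinary` performs above: the signed long bits are rendered as an unsigned decimal string and then turned into the unscaled `BigInteger` bytes that the vectorized reader stores for a decimal(20, 0) column.

```scala
import java.math.BigInteger

// Mirrors `new BigInteger(Long.toUnsignedString(signed)).toByteArray()` above.
def unsignedInt64Bytes(signedBits: Long): Array[Byte] =
  new BigInteger(java.lang.Long.toUnsignedString(signedBits)).toByteArray

// 0xFFFFFFFFFFFFFFFF is stored as -1L and decodes to 18446744073709551615.
assert(new BigInteger(unsignedInt64Bytes(-1L)).toString == "18446744073709551615")
```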

@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.datasources.parquet;
import java.io.IOException;
import java.math.BigInteger;
import java.time.ZoneId;
import java.time.ZoneOffset;
import java.util.Arrays;
@@ -290,8 +291,12 @@ public class VectorizedColumnReader {
// signed int first
boolean isUnsignedInt32 = primitiveType.getOriginalType() == OriginalType.UINT_32;
column.setDictionary(
new ParquetDictionary(dictionary, castLongToInt || isUnsignedInt32));
// We require a decimal value, but we need to use dictionary to decode the original
// signed long first
boolean isUnsignedInt64 = primitiveType.getOriginalType() == OriginalType.UINT_64;
boolean needTransform = castLongToInt || isUnsignedInt32 || isUnsignedInt64;
column.setDictionary(new ParquetDictionary(dictionary, needTransform));
} else {
decodeDictionaryIds(rowId, num, column, dictionaryIds);
}
@@ -420,6 +425,19 @@ public class VectorizedColumnReader {
column.putLong(i, dictionary.decodeToLong(dictionaryIds.getDictId(i)));
}
}
} else if (originalType == OriginalType.UINT_64) {
// In `ParquetToSparkSchemaConverter`, we map parquet UINT64 to our Decimal(20, 0).
// For unsigned int64, it stores as dictionary encoded signed int64 in Parquet
// whenever dictionary is available.
// Here we eagerly decode it to the original signed int64(long) value then convert to
// BigInteger.
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
long signed = dictionary.decodeToLong(dictionaryIds.getDictId(i));
byte[] unsigned = new BigInteger(Long.toUnsignedString(signed)).toByteArray();
column.putByteArray(i, unsigned);
}
}
} else if (originalType == OriginalType.TIMESTAMP_MILLIS) {
if ("CORRECTED".equals(datetimeRebaseMode)) {
for (int i = rowId; i < rowId + num; ++i) {
@@ -582,7 +600,7 @@ public class VectorizedColumnReader {
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
} else if (column.dataType() == DataTypes.LongType) {
// In `ParquetToSparkSchemaConverter`, we map parquet UINT32 to our LongType.
// For unsigned int32, it stores as plain signed int32 in Parquet when dictionary fall backs.
// For unsigned int32, it stores as plain signed int32 in Parquet when dictionary fallbacks.
// We read them as long values.
defColumn.readUnsignedIntegers(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
@@ -613,6 +631,12 @@ public class VectorizedColumnReader {
defColumn.readLongs(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn,
DecimalType.is32BitDecimalType(column.dataType()));
} else if (originalType == OriginalType.UINT_64) {
// In `ParquetToSparkSchemaConverter`, we map parquet UINT64 to our Decimal(20, 0).
// For unsigned int64, it stores as plain signed int64 in Parquet when dictionary fallbacks.
// We read them as decimal values.
defColumn.readUnsignedLongs(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
} else if (originalType == OriginalType.TIMESTAMP_MICROS) {
if ("CORRECTED".equals(datetimeRebaseMode)) {
defColumn.readLongs(

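To see why the reader calls `putByteArray` here, a hedged sketch of how those bytes are consumed (assumption: driving the on-heap column vector directly, which is not how the reader normally runs): for precisions above 18 digits Spark's writable column vectors store the unscaled `BigInteger` bytes, and `getDecimal` rebuilds the value.

```scala
import java.math.BigInteger
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.types.DecimalType

// Sketch: store 2^64 - 1 as its unscaled bytes and read it back as decimal(20, 0).
val vector = new OnHeapColumnVector(1, DecimalType(20, 0))
val bytes = new BigInteger(java.lang.Long.toUnsignedString(-1L)).toByteArray
vector.putByteArray(0, bytes)
assert(vector.getDecimal(0, 20, 0).toString == "18446744073709551615")
vector.close()
```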

@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.datasources.parquet;
import java.io.IOException;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
@@ -139,6 +140,16 @@ public class VectorizedPlainValuesReader extends ValuesReader implements Vectori
}
}
@Override
public final void readUnsignedLongs(int total, WritableColumnVector c, int rowId) {
int requiredBytes = total * 8;
ByteBuffer buffer = getBuffer(requiredBytes);
for (int i = 0; i < total; i += 1) {
c.putByteArray(
rowId + i, new BigInteger(Long.toUnsignedString(buffer.getLong())).toByteArray());
}
}
// A fork of `readLongs` to rebase the timestamp values. For performance reasons, this method
// iterates the values twice: check if we need to rebase first, then go to the optimized branch
// if rebase is not needed.


@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.datasources.parquet;
import java.io.IOException;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import org.apache.parquet.Preconditions;
@@ -433,6 +434,41 @@ public final class VectorizedRleValuesReader extends ValuesReader
}
}
public void readUnsignedLongs(
int total,
WritableColumnVector c,
int rowId,
int level,
VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
int n = Math.min(left, this.currentCount);
switch (mode) {
case RLE:
if (currentValue == level) {
data.readUnsignedLongs(n, c, rowId);
} else {
c.putNulls(rowId, n);
}
break;
case PACKED:
for (int i = 0; i < n; ++i) {
if (currentBuffer[currentBufferIdx++] == level) {
byte[] bytes = new BigInteger(Long.toUnsignedString(data.readLong())).toByteArray();
c.putByteArray(rowId + i, bytes);
} else {
c.putNull(rowId + i);
}
}
break;
}
rowId += n;
left -= n;
currentCount -= n;
}
}
// A fork of `readLongs`, which rebases the timestamp long value (microseconds) before filling
// the Spark column vector.
public void readLongsWithRebase(
@@ -642,6 +678,11 @@ public final class VectorizedRleValuesReader extends ValuesReader
throw new UnsupportedOperationException("only readInts is valid.");
}
@Override
public void readUnsignedLongs(int total, WritableColumnVector c, int rowId) {
throw new UnsupportedOperationException("only readInts is valid.");
}
@Override
public void readIntegersWithRebase(
int total, WritableColumnVector c, int rowId, boolean failIfRebase) {


@@ -42,6 +42,7 @@ public interface VectorizedValuesReader {
void readIntegers(int total, WritableColumnVector c, int rowId);
void readIntegersWithRebase(int total, WritableColumnVector c, int rowId, boolean failIfRebase);
void readUnsignedIntegers(int total, WritableColumnVector c, int rowId);
void readUnsignedLongs(int total, WritableColumnVector c, int rowId);
void readLongs(int total, WritableColumnVector c, int rowId);
void readLongsWithRebase(int total, WritableColumnVector c, int rowId, boolean failIfRebase);
void readFloats(int total, WritableColumnVector c, int rowId);


@@ -285,6 +285,14 @@ private[parquet] class ParquetRowConverter(
metadata.getPrecision, metadata.getScale, updater)
}
// For unsigned int64
case _: DecimalType if parquetType.getOriginalType == OriginalType.UINT_64 =>
new ParquetPrimitiveConverter(updater) {
override def addLong(value: Long): Unit = {
updater.set(Decimal(java.lang.Long.toUnsignedString(value)))
}
}
// For INT64 backed decimals
case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 =>
val metadata = parquetType.asPrimitiveType().getDecimalMetadata

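A hedged illustration (not from the patch) of what the UINT_64 converter case above produces: the signed long bits are reinterpreted as the unsigned value before the `Decimal` is built.

```scala
import org.apache.spark.sql.types.Decimal

// Mirrors `updater.set(Decimal(java.lang.Long.toUnsignedString(value)))` above.
def toUnsignedDecimal(value: Long): Decimal =
  Decimal(java.lang.Long.toUnsignedString(value))

assert(toUnsignedDecimal(100L).toString == "100")
assert(toUnsignedDecimal(-1L).toString == "18446744073709551615")
```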

@@ -98,9 +98,6 @@ class ParquetToSparkSchemaConverter(
def typeString =
if (originalType == null) s"$typeName" else s"$typeName ($originalType)"
def typeNotSupported() =
throw QueryCompilationErrors.parquetTypeUnsupportedError(typeString)
def typeNotImplemented() =
throw QueryCompilationErrors.parquetTypeUnsupportedYetError(typeString)
@@ -144,7 +141,7 @@
originalType match {
case INT_64 | null => LongType
case DECIMAL => makeDecimalType(Decimal.MAX_LONG_DIGITS)
case UINT_64 => typeNotSupported()
case UINT_64 => DecimalType.LongDecimal
case TIMESTAMP_MICROS => TimestampType
case TIMESTAMP_MILLIS => TimestampType
case _ => illegalType()

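For context, `DecimalType.LongDecimal` is simply the decimal(20, 0) type referenced throughout this PR, so the schema mapping above matches what the vectorized and row converters produce; a quick check:

```scala
import org.apache.spark.sql.types.DecimalType

// DecimalType.LongDecimal is the decimal(20, 0) used for UINT_64 above.
assert(DecimalType.LongDecimal == DecimalType(20, 0))
```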

@@ -24,6 +24,7 @@ import scala.collection.mutable
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
import com.google.common.primitives.UnsignedLong
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}
import org.apache.parquet.column.{Encoding, ParquetProperties}
@@ -295,10 +296,11 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
| required INT32 a(UINT_8);
| required INT32 b(UINT_16);
| required INT32 c(UINT_32);
| required INT64 d(UINT_64);
|}
""".stripMargin)
val expectedSparkTypes = Seq(ShortType, IntegerType, LongType)
val expectedSparkTypes = Seq(ShortType, IntegerType, LongType, DecimalType.LongDecimal)
withTempPath { location =>
val path = new Path(location.getCanonicalPath)
@@ -459,6 +461,40 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
}
}
test("SPARK-34817: Read UINT_64 as Decimal from parquet") {
Seq(true, false).foreach { dictionaryEnabled =>
def makeRawParquetFile(path: Path): Unit = {
val schemaStr =
"""message root {
| required INT64 a(UINT_64);
|}
""".stripMargin
val schema = MessageTypeParser.parseMessageType(schemaStr)
val writer = createParquetWriter(schema, path, dictionaryEnabled)
val factory = new SimpleGroupFactory(schema)
(-500 until 500).foreach { i =>
val group = factory.newGroup()
.append("a", i % 100L)
writer.write(group)
}
writer.close()
}
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "part-r-0.parquet")
makeRawParquetFile(path)
readParquetFile(path.toString) { df =>
checkAnswer(df, (-500 until 500).map { i =>
val bi = UnsignedLong.fromLongBits(i % 100L).bigIntegerValue()
Row(new java.math.BigDecimal(bi))
})
}
}
}
}
test("write metadata") {
val hadoopConf = spark.sessionState.newHadoopConf()
withTempPath { file =>