[SPARK-34786][SQL] Read Parquet unsigned int64 logical type that stored as signed int64 physical type to decimal(20, 0)
### What changes were proposed in this pull request? A companion PR for SPARK-34817, when we handle the unsigned int(<=32) logical types. In this PR, we map the unsigned int64 to decimal(20, 0) for better compatibility. ### Why are the changes needed? Spark won't have unsigned types, but spark should be able to read existing parquet files written by other systems that support unsigned types for better compatibility. ### Does this PR introduce _any_ user-facing change? yes, we can read parquet uint64 now ### How was this patch tested? new unit tests Closes #31960 from yaooqinn/SPARK-34786-2. Authored-by: Kent Yao <yao@apache.org> Signed-off-by: Kent Yao <yao@apache.org>
This commit is contained in:
parent
5ffc3897e0
commit
820b465886
|
@ -1268,10 +1268,6 @@ private[spark] object QueryCompilationErrors {
|
|||
s"createTableColumnTypes option column $col not found in schema ${schema.catalogString}")
|
||||
}
|
||||
|
||||
def parquetTypeUnsupportedError(parquetType: String): Throwable = {
|
||||
new AnalysisException(s"Parquet type not supported: $parquetType")
|
||||
}
|
||||
|
||||
def parquetTypeUnsupportedYetError(parquetType: String): Throwable = {
|
||||
new AnalysisException(s"Parquet type not yet supported: $parquetType")
|
||||
}
|
||||
|
|
|
@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources.parquet;
|
|||
|
||||
import org.apache.spark.sql.execution.vectorized.Dictionary;
|
||||
|
||||
import java.math.BigInteger;
|
||||
|
||||
public final class ParquetDictionary implements Dictionary {
|
||||
private org.apache.parquet.column.Dictionary dictionary;
|
||||
private boolean needTransform = false;
|
||||
|
@ -61,6 +63,14 @@ public final class ParquetDictionary implements Dictionary {
|
|||
|
||||
@Override
|
||||
public byte[] decodeToBinary(int id) {
|
||||
return dictionary.decodeToBinary(id).getBytes();
|
||||
if (needTransform) {
|
||||
// For unsigned int64, it stores as dictionary encoded signed int64 in Parquet
|
||||
// whenever dictionary is available.
|
||||
// Here we lazily decode it to the original signed long value then convert to decimal(20, 0).
|
||||
long signed = dictionary.decodeToLong(id);
|
||||
return new BigInteger(Long.toUnsignedString(signed)).toByteArray();
|
||||
} else {
|
||||
return dictionary.decodeToBinary(id).getBytes();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.sql.execution.datasources.parquet;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.math.BigInteger;
|
||||
import java.time.ZoneId;
|
||||
import java.time.ZoneOffset;
|
||||
import java.util.Arrays;
|
||||
|
@ -290,8 +291,12 @@ public class VectorizedColumnReader {
|
|||
// signed int first
|
||||
boolean isUnsignedInt32 = primitiveType.getOriginalType() == OriginalType.UINT_32;
|
||||
|
||||
column.setDictionary(
|
||||
new ParquetDictionary(dictionary, castLongToInt || isUnsignedInt32));
|
||||
// We require a decimal value, but we need to use dictionary to decode the original
|
||||
// signed long first
|
||||
boolean isUnsignedInt64 = primitiveType.getOriginalType() == OriginalType.UINT_64;
|
||||
|
||||
boolean needTransform = castLongToInt || isUnsignedInt32 || isUnsignedInt64;
|
||||
column.setDictionary(new ParquetDictionary(dictionary, needTransform));
|
||||
} else {
|
||||
decodeDictionaryIds(rowId, num, column, dictionaryIds);
|
||||
}
|
||||
|
@ -420,6 +425,19 @@ public class VectorizedColumnReader {
|
|||
column.putLong(i, dictionary.decodeToLong(dictionaryIds.getDictId(i)));
|
||||
}
|
||||
}
|
||||
} else if (originalType == OriginalType.UINT_64) {
|
||||
// In `ParquetToSparkSchemaConverter`, we map parquet UINT64 to our Decimal(20, 0).
|
||||
// For unsigned int64, it stores as dictionary encoded signed int64 in Parquet
|
||||
// whenever dictionary is available.
|
||||
// Here we eagerly decode it to the original signed int64(long) value then convert to
|
||||
// BigInteger.
|
||||
for (int i = rowId; i < rowId + num; ++i) {
|
||||
if (!column.isNullAt(i)) {
|
||||
long signed = dictionary.decodeToLong(dictionaryIds.getDictId(i));
|
||||
byte[] unsigned = new BigInteger(Long.toUnsignedString(signed)).toByteArray();
|
||||
column.putByteArray(i, unsigned);
|
||||
}
|
||||
}
|
||||
} else if (originalType == OriginalType.TIMESTAMP_MILLIS) {
|
||||
if ("CORRECTED".equals(datetimeRebaseMode)) {
|
||||
for (int i = rowId; i < rowId + num; ++i) {
|
||||
|
@ -582,7 +600,7 @@ public class VectorizedColumnReader {
|
|||
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
|
||||
} else if (column.dataType() == DataTypes.LongType) {
|
||||
// In `ParquetToSparkSchemaConverter`, we map parquet UINT32 to our LongType.
|
||||
// For unsigned int32, it stores as plain signed int32 in Parquet when dictionary fall backs.
|
||||
// For unsigned int32, it stores as plain signed int32 in Parquet when dictionary fallbacks.
|
||||
// We read them as long values.
|
||||
defColumn.readUnsignedIntegers(
|
||||
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
|
||||
|
@ -613,6 +631,12 @@ public class VectorizedColumnReader {
|
|||
defColumn.readLongs(
|
||||
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn,
|
||||
DecimalType.is32BitDecimalType(column.dataType()));
|
||||
} else if (originalType == OriginalType.UINT_64) {
|
||||
// In `ParquetToSparkSchemaConverter`, we map parquet UINT64 to our Decimal(20, 0).
|
||||
// For unsigned int64, it stores as plain signed int64 in Parquet when dictionary fallbacks.
|
||||
// We read them as decimal values.
|
||||
defColumn.readUnsignedLongs(
|
||||
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
|
||||
} else if (originalType == OriginalType.TIMESTAMP_MICROS) {
|
||||
if ("CORRECTED".equals(datetimeRebaseMode)) {
|
||||
defColumn.readLongs(
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.apache.spark.sql.execution.datasources.parquet;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.math.BigInteger;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
|
||||
|
@ -139,6 +140,16 @@ public class VectorizedPlainValuesReader extends ValuesReader implements Vectori
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public final void readUnsignedLongs(int total, WritableColumnVector c, int rowId) {
|
||||
int requiredBytes = total * 8;
|
||||
ByteBuffer buffer = getBuffer(requiredBytes);
|
||||
for (int i = 0; i < total; i += 1) {
|
||||
c.putByteArray(
|
||||
rowId + i, new BigInteger(Long.toUnsignedString(buffer.getLong())).toByteArray());
|
||||
}
|
||||
}
|
||||
|
||||
// A fork of `readLongs` to rebase the timestamp values. For performance reasons, this method
|
||||
// iterates the values twice: check if we need to rebase first, then go to the optimized branch
|
||||
// if rebase is not needed.
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.sql.execution.datasources.parquet;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.math.BigInteger;
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
import org.apache.parquet.Preconditions;
|
||||
|
@ -433,6 +434,41 @@ public final class VectorizedRleValuesReader extends ValuesReader
|
|||
}
|
||||
}
|
||||
|
||||
public void readUnsignedLongs(
|
||||
int total,
|
||||
WritableColumnVector c,
|
||||
int rowId,
|
||||
int level,
|
||||
VectorizedValuesReader data) throws IOException {
|
||||
int left = total;
|
||||
while (left > 0) {
|
||||
if (this.currentCount == 0) this.readNextGroup();
|
||||
int n = Math.min(left, this.currentCount);
|
||||
switch (mode) {
|
||||
case RLE:
|
||||
if (currentValue == level) {
|
||||
data.readUnsignedLongs(n, c, rowId);
|
||||
} else {
|
||||
c.putNulls(rowId, n);
|
||||
}
|
||||
break;
|
||||
case PACKED:
|
||||
for (int i = 0; i < n; ++i) {
|
||||
if (currentBuffer[currentBufferIdx++] == level) {
|
||||
byte[] bytes = new BigInteger(Long.toUnsignedString(data.readLong())).toByteArray();
|
||||
c.putByteArray(rowId + i, bytes);
|
||||
} else {
|
||||
c.putNull(rowId + i);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
rowId += n;
|
||||
left -= n;
|
||||
currentCount -= n;
|
||||
}
|
||||
}
|
||||
|
||||
// A fork of `readLongs`, which rebases the timestamp long value (microseconds) before filling
|
||||
// the Spark column vector.
|
||||
public void readLongsWithRebase(
|
||||
|
@ -642,6 +678,11 @@ public final class VectorizedRleValuesReader extends ValuesReader
|
|||
throw new UnsupportedOperationException("only readInts is valid.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void readUnsignedLongs(int total, WritableColumnVector c, int rowId) {
|
||||
throw new UnsupportedOperationException("only readInts is valid.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void readIntegersWithRebase(
|
||||
int total, WritableColumnVector c, int rowId, boolean failIfRebase) {
|
||||
|
|
|
@ -42,6 +42,7 @@ public interface VectorizedValuesReader {
|
|||
void readIntegers(int total, WritableColumnVector c, int rowId);
|
||||
void readIntegersWithRebase(int total, WritableColumnVector c, int rowId, boolean failIfRebase);
|
||||
void readUnsignedIntegers(int total, WritableColumnVector c, int rowId);
|
||||
void readUnsignedLongs(int total, WritableColumnVector c, int rowId);
|
||||
void readLongs(int total, WritableColumnVector c, int rowId);
|
||||
void readLongsWithRebase(int total, WritableColumnVector c, int rowId, boolean failIfRebase);
|
||||
void readFloats(int total, WritableColumnVector c, int rowId);
|
||||
|
|
|
@ -285,6 +285,14 @@ private[parquet] class ParquetRowConverter(
|
|||
metadata.getPrecision, metadata.getScale, updater)
|
||||
}
|
||||
|
||||
// For unsigned int64
|
||||
case _: DecimalType if parquetType.getOriginalType == OriginalType.UINT_64 =>
|
||||
new ParquetPrimitiveConverter(updater) {
|
||||
override def addLong(value: Long): Unit = {
|
||||
updater.set(Decimal(java.lang.Long.toUnsignedString(value)))
|
||||
}
|
||||
}
|
||||
|
||||
// For INT64 backed decimals
|
||||
case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 =>
|
||||
val metadata = parquetType.asPrimitiveType().getDecimalMetadata
|
||||
|
|
|
@ -98,9 +98,6 @@ class ParquetToSparkSchemaConverter(
|
|||
def typeString =
|
||||
if (originalType == null) s"$typeName" else s"$typeName ($originalType)"
|
||||
|
||||
def typeNotSupported() =
|
||||
throw QueryCompilationErrors.parquetTypeUnsupportedError(typeString)
|
||||
|
||||
def typeNotImplemented() =
|
||||
throw QueryCompilationErrors.parquetTypeUnsupportedYetError(typeString)
|
||||
|
||||
|
@ -144,7 +141,7 @@ class ParquetToSparkSchemaConverter(
|
|||
originalType match {
|
||||
case INT_64 | null => LongType
|
||||
case DECIMAL => makeDecimalType(Decimal.MAX_LONG_DIGITS)
|
||||
case UINT_64 => typeNotSupported()
|
||||
case UINT_64 => DecimalType.LongDecimal
|
||||
case TIMESTAMP_MICROS => TimestampType
|
||||
case TIMESTAMP_MILLIS => TimestampType
|
||||
case _ => illegalType()
|
||||
|
|
|
@ -24,6 +24,7 @@ import scala.collection.mutable
|
|||
import scala.reflect.ClassTag
|
||||
import scala.reflect.runtime.universe.TypeTag
|
||||
|
||||
import com.google.common.primitives.UnsignedLong
|
||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||
import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}
|
||||
import org.apache.parquet.column.{Encoding, ParquetProperties}
|
||||
|
@ -295,10 +296,11 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
|
|||
| required INT32 a(UINT_8);
|
||||
| required INT32 b(UINT_16);
|
||||
| required INT32 c(UINT_32);
|
||||
| required INT64 d(UINT_64);
|
||||
|}
|
||||
""".stripMargin)
|
||||
|
||||
val expectedSparkTypes = Seq(ShortType, IntegerType, LongType)
|
||||
val expectedSparkTypes = Seq(ShortType, IntegerType, LongType, DecimalType.LongDecimal)
|
||||
|
||||
withTempPath { location =>
|
||||
val path = new Path(location.getCanonicalPath)
|
||||
|
@ -459,6 +461,40 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
|
|||
}
|
||||
}
|
||||
|
||||
test("SPARK-34817: Read UINT_64 as Decimal from parquet") {
|
||||
Seq(true, false).foreach { dictionaryEnabled =>
|
||||
def makeRawParquetFile(path: Path): Unit = {
|
||||
val schemaStr =
|
||||
"""message root {
|
||||
| required INT64 a(UINT_64);
|
||||
|}
|
||||
""".stripMargin
|
||||
val schema = MessageTypeParser.parseMessageType(schemaStr)
|
||||
|
||||
val writer = createParquetWriter(schema, path, dictionaryEnabled)
|
||||
|
||||
val factory = new SimpleGroupFactory(schema)
|
||||
(-500 until 500).foreach { i =>
|
||||
val group = factory.newGroup()
|
||||
.append("a", i % 100L)
|
||||
writer.write(group)
|
||||
}
|
||||
writer.close()
|
||||
}
|
||||
|
||||
withTempDir { dir =>
|
||||
val path = new Path(dir.toURI.toString, "part-r-0.parquet")
|
||||
makeRawParquetFile(path)
|
||||
readParquetFile(path.toString) { df =>
|
||||
checkAnswer(df, (-500 until 500).map { i =>
|
||||
val bi = UnsignedLong.fromLongBits(i % 100L).bigIntegerValue()
|
||||
Row(new java.math.BigDecimal(bi))
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("write metadata") {
|
||||
val hadoopConf = spark.sessionState.newHadoopConf()
|
||||
withTempPath { file =>
|
||||
|
|
Loading…
Reference in a new issue