From 820b465886b440b90f994dc587680b802f197915 Mon Sep 17 00:00:00 2001
From: Kent Yao
Date: Fri, 26 Mar 2021 09:54:19 +0800
Subject: [PATCH] [SPARK-34786][SQL] Read the Parquet unsigned int64 logical
 type, which is stored as the signed int64 physical type, as decimal(20, 0)

### What changes were proposed in this pull request?

A companion PR to SPARK-34817, where we handle the unsigned int (<= 32 bit) logical types. In this PR, we map the Parquet unsigned int64 logical type to decimal(20, 0) for better compatibility.
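
To make the mapping concrete, here is a minimal sketch (illustrative only, not part of the patch; the object name and sample values are invented) of the conversion the new code paths share: the signed int64 bits read from Parquet are reinterpreted as an unsigned value via `java.lang.Long.toUnsignedString` and wrapped in Spark's `Decimal`, while `ParquetToSparkSchemaConverter` maps `UINT_64` to `DecimalType.LongDecimal`, i.e. decimal(20, 0).

```scala
import org.apache.spark.sql.types.Decimal

object UnsignedInt64Sketch {
  // Reinterpret the signed long bits as an unsigned value, mirroring what the
  // updated ParquetRowConverter does via Long.toUnsignedString.
  def toUnsignedDecimal(signed: Long): Decimal =
    Decimal(java.lang.Long.toUnsignedString(signed))

  def main(args: Array[String]): Unit = {
    println(toUnsignedDecimal(42L))           // 42
    println(toUnsignedDecimal(Long.MinValue)) // 9223372036854775808 (2^63)
    println(toUnsignedDecimal(-1L))           // 18446744073709551615 (2^64 - 1)
  }
}
```

The vectorized reader stores the same value as the bytes of `new BigInteger(Long.toUnsignedString(v)).toByteArray()`, which the column vector then exposes as a decimal(20, 0) value.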

### Why are the changes needed?

Spark does not have unsigned types, but it should be able to read existing Parquet files written by other systems that do support unsigned types, for better compatibility.

### Does this PR introduce _any_ user-facing change?

Yes. Parquet columns of the unsigned int64 (UINT_64) logical type can now be read, and they are mapped to decimal(20, 0).
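
As a rough illustration (the file path and column name below are hypothetical, and the "before" message is reconstructed from the error helper this patch removes):

```scala
// A Parquet file whose column `a` was written with the UINT_64 logical type by another system.
val df = spark.read.parquet("/path/to/uint64.parquet")

// Before this patch, reading such a file failed while converting the Parquet schema:
//   org.apache.spark.sql.AnalysisException: Parquet type not supported: INT64 (UINT_64)

// With this patch, the column is read as decimal(20, 0):
df.printSchema()
// root
//  |-- a: decimal(20,0) (nullable = true)
```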

### How was this patch tested?

New unit tests.

Closes #31960 from yaooqinn/SPARK-34786-2.

Authored-by: Kent Yao
Signed-off-by: Kent Yao
---
 .../sql/errors/QueryCompilationErrors.scala   |  4 --
 .../parquet/ParquetDictionary.java            | 12 +++++-
 .../parquet/VectorizedColumnReader.java       | 30 ++++++++++++--
 .../parquet/VectorizedPlainValuesReader.java  | 11 +++++
 .../parquet/VectorizedRleValuesReader.java    | 41 +++++++++++++++++++
 .../parquet/VectorizedValuesReader.java       |  1 +
 .../parquet/ParquetRowConverter.scala         |  8 ++++
 .../parquet/ParquetSchemaConverter.scala      |  5 +--
 .../datasources/parquet/ParquetIOSuite.scala  | 38 ++++++++++++++++-
 9 files changed, 137 insertions(+), 13 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 4819715e48..e485a205f0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -1268,10 +1268,6 @@ private[spark] object QueryCompilationErrors {
       s"createTableColumnTypes option column $col not found in schema ${schema.catalogString}")
   }
 
-  def parquetTypeUnsupportedError(parquetType: String): Throwable = {
-    new AnalysisException(s"Parquet type not supported: $parquetType")
-  }
-
   def parquetTypeUnsupportedYetError(parquetType: String): Throwable = {
     new AnalysisException(s"Parquet type not yet supported: $parquetType")
   }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetDictionary.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetDictionary.java
index ea15607bb8..6626f3fee9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetDictionary.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetDictionary.java
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources.parquet;
 
 import org.apache.spark.sql.execution.vectorized.Dictionary;
 
+import java.math.BigInteger;
+
 public final class ParquetDictionary implements Dictionary {
   private org.apache.parquet.column.Dictionary dictionary;
   private boolean needTransform = false;
@@ -61,6 +63,14 @@ public final class ParquetDictionary implements Dictionary {
 
   @Override
   public byte[] decodeToBinary(int id) {
-    return dictionary.decodeToBinary(id).getBytes();
+    if (needTransform) {
+      // Unsigned int64 values are stored as dictionary-encoded signed int64 in Parquet
+      // whenever a dictionary is available. Here we lazily decode the original signed long
+      // value and convert it to the unsigned representation backing decimal(20, 0).
+      long signed = dictionary.decodeToLong(id);
+      return new BigInteger(Long.toUnsignedString(signed)).toByteArray();
+    } else {
+      return dictionary.decodeToBinary(id).getBytes();
+    }
   }
 }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
index e091cbac41..672b73e94c 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.datasources.parquet;
 
 import java.io.IOException;
+import java.math.BigInteger;
 import java.time.ZoneId;
 import java.time.ZoneOffset;
 import java.util.Arrays;
@@ -290,8 +291,12 @@ public class VectorizedColumnReader {
           // signed int first
           boolean isUnsignedInt32 = primitiveType.getOriginalType() == OriginalType.UINT_32;
 
-          column.setDictionary(
-            new ParquetDictionary(dictionary, castLongToInt || isUnsignedInt32));
+          // We require a decimal value, but we need to use the dictionary to decode the original
+          // signed long value first.
+          boolean isUnsignedInt64 = primitiveType.getOriginalType() == OriginalType.UINT_64;
+
+          boolean needTransform = castLongToInt || isUnsignedInt32 || isUnsignedInt64;
+          column.setDictionary(new ParquetDictionary(dictionary, needTransform));
         } else {
           decodeDictionaryIds(rowId, num, column, dictionaryIds);
         }
@@ -420,6 +425,19 @@ public class VectorizedColumnReader {
            column.putLong(i, dictionary.decodeToLong(dictionaryIds.getDictId(i)));
          }
        }
+      } else if (originalType == OriginalType.UINT_64) {
+        // In `ParquetToSparkSchemaConverter`, we map parquet UINT64 to our Decimal(20, 0).
+        // Unsigned int64 values are stored as dictionary-encoded signed int64 in Parquet
+        // whenever a dictionary is available.
+        // Here we eagerly decode each value to the original signed int64 (long) and then
+        // convert it to an unsigned BigInteger.
+        for (int i = rowId; i < rowId + num; ++i) {
+          if (!column.isNullAt(i)) {
+            long signed = dictionary.decodeToLong(dictionaryIds.getDictId(i));
+            byte[] unsigned = new BigInteger(Long.toUnsignedString(signed)).toByteArray();
+            column.putByteArray(i, unsigned);
+          }
+        }
       } else if (originalType == OriginalType.TIMESTAMP_MILLIS) {
         if ("CORRECTED".equals(datetimeRebaseMode)) {
           for (int i = rowId; i < rowId + num; ++i) {
@@ -582,7 +600,7 @@ public class VectorizedColumnReader {
          num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
     } else if (column.dataType() == DataTypes.LongType) {
       // In `ParquetToSparkSchemaConverter`, we map parquet UINT32 to our LongType.
-      // For unsigned int32, it stores as plain signed int32 in Parquet when dictionary fall backs.
+      // Unsigned int32 is stored as plain signed int32 in Parquet when dictionary encoding falls back.
       // We read them as long values.
       defColumn.readUnsignedIntegers(
          num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
@@ -613,6 +631,12 @@ public class VectorizedColumnReader {
       defColumn.readLongs(
          num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn,
          DecimalType.is32BitDecimalType(column.dataType()));
+    } else if (originalType == OriginalType.UINT_64) {
+      // In `ParquetToSparkSchemaConverter`, we map parquet UINT64 to our Decimal(20, 0).
+      // Unsigned int64 is stored as plain signed int64 in Parquet when dictionary encoding falls back.
+      // We read them as decimal values.
+      defColumn.readUnsignedLongs(
+          num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
     } else if (originalType == OriginalType.TIMESTAMP_MICROS) {
       if ("CORRECTED".equals(datetimeRebaseMode)) {
         defColumn.readLongs(
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
index 99beb0250a..595da20ad5 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.execution.datasources.parquet;
 
 import java.io.IOException;
+import java.math.BigInteger;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 
@@ -139,6 +140,16 @@ public class VectorizedPlainValuesReader extends ValuesReader implements Vectori
     }
   }
 
+  @Override
+  public final void readUnsignedLongs(int total, WritableColumnVector c, int rowId) {
+    int requiredBytes = total * 8;
+    ByteBuffer buffer = getBuffer(requiredBytes);
+    for (int i = 0; i < total; i += 1) {
+      c.putByteArray(
+        rowId + i, new BigInteger(Long.toUnsignedString(buffer.getLong())).toByteArray());
+    }
+  }
+
   // A fork of `readLongs` to rebase the timestamp values. For performance reasons, this method
   // iterates the values twice: check if we need to rebase first, then go to the optimized branch
   // if rebase is not needed.
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
index 384bcb30a1..2eed66278b 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.datasources.parquet;
 
 import java.io.IOException;
+import java.math.BigInteger;
 import java.nio.ByteBuffer;
 
 import org.apache.parquet.Preconditions;
@@ -433,6 +434,41 @@ public final class VectorizedRleValuesReader extends ValuesReader
     }
   }
 
+  public void readUnsignedLongs(
+      int total,
+      WritableColumnVector c,
+      int rowId,
+      int level,
+      VectorizedValuesReader data) throws IOException {
+    int left = total;
+    while (left > 0) {
+      if (this.currentCount == 0) this.readNextGroup();
+      int n = Math.min(left, this.currentCount);
+      switch (mode) {
+        case RLE:
+          if (currentValue == level) {
+            data.readUnsignedLongs(n, c, rowId);
+          } else {
+            c.putNulls(rowId, n);
+          }
+          break;
+        case PACKED:
+          for (int i = 0; i < n; ++i) {
+            if (currentBuffer[currentBufferIdx++] == level) {
+              byte[] bytes = new BigInteger(Long.toUnsignedString(data.readLong())).toByteArray();
+              c.putByteArray(rowId + i, bytes);
+            } else {
+              c.putNull(rowId + i);
+            }
+          }
+          break;
+      }
+      rowId += n;
+      left -= n;
+      currentCount -= n;
+    }
+  }
+
   // A fork of `readLongs`, which rebases the timestamp long value (microseconds) before filling
   // the Spark column vector.
   public void readLongsWithRebase(
@@ -642,6 +678,11 @@ public final class VectorizedRleValuesReader extends ValuesReader
     throw new UnsupportedOperationException("only readInts is valid.");
   }
 
+  @Override
+  public void readUnsignedLongs(int total, WritableColumnVector c, int rowId) {
+    throw new UnsupportedOperationException("only readInts is valid.");
+  }
+
   @Override
   public void readIntegersWithRebase(
       int total, WritableColumnVector c, int rowId, boolean failIfRebase) {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
index 9f5d944329..d09f750beb 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
@@ -42,6 +42,7 @@ public interface VectorizedValuesReader {
   void readIntegers(int total, WritableColumnVector c, int rowId);
   void readIntegersWithRebase(int total, WritableColumnVector c, int rowId, boolean failIfRebase);
   void readUnsignedIntegers(int total, WritableColumnVector c, int rowId);
+  void readUnsignedLongs(int total, WritableColumnVector c, int rowId);
   void readLongs(int total, WritableColumnVector c, int rowId);
   void readLongsWithRebase(int total, WritableColumnVector c, int rowId, boolean failIfRebase);
   void readFloats(int total, WritableColumnVector c, int rowId);
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
index 2c610ec539..0a1cca7ed0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
@@ -285,6 +285,14 @@ private[parquet] class ParquetRowConverter(
           metadata.getPrecision, metadata.getScale, updater)
       }
 
+    // For unsigned int64
+    case _: DecimalType if parquetType.getOriginalType == OriginalType.UINT_64 =>
+      new ParquetPrimitiveConverter(updater) {
+        override def addLong(value: Long): Unit = {
+          updater.set(Decimal(java.lang.Long.toUnsignedString(value)))
+        }
+      }
+
     // For INT64 backed decimals
     case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 =>
       val metadata = parquetType.asPrimitiveType().getDecimalMetadata
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
index ef094bdca0..8c4e0881e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
@@ -98,9 +98,6 @@ class ParquetToSparkSchemaConverter(
     def typeString =
       if (originalType == null) s"$typeName" else s"$typeName ($originalType)"
 
-    def typeNotSupported() =
-      throw QueryCompilationErrors.parquetTypeUnsupportedError(typeString)
-
     def typeNotImplemented() =
       throw QueryCompilationErrors.parquetTypeUnsupportedYetError(typeString)
 
@@ -144,7 +141,7 @@ class ParquetToSparkSchemaConverter(
         originalType match {
           case INT_64 | null => LongType
           case DECIMAL => makeDecimalType(Decimal.MAX_LONG_DIGITS)
-          case UINT_64 => typeNotSupported()
+          case UINT_64 => DecimalType.LongDecimal
           case TIMESTAMP_MICROS => TimestampType
           case TIMESTAMP_MILLIS => TimestampType
           case _ => illegalType()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index 82e605fc9f..c787a4eff7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -24,6 +24,7 @@ import scala.collection.mutable
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag
 
+import com.google.common.primitives.UnsignedLong
 import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}
 import org.apache.parquet.column.{Encoding, ParquetProperties}
@@ -295,10 +296,11 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
         |  required INT32 a(UINT_8);
         |  required INT32 b(UINT_16);
         |  required INT32 c(UINT_32);
+        |  required INT64 d(UINT_64);
         |}
       """.stripMargin)
 
-    val expectedSparkTypes = Seq(ShortType, IntegerType, LongType)
+    val expectedSparkTypes = Seq(ShortType, IntegerType, LongType, DecimalType.LongDecimal)
 
     withTempPath { location =>
       val path = new Path(location.getCanonicalPath)
@@ -459,6 +461,40 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
     }
   }
 
+  test("SPARK-34817: Read UINT_64 as Decimal from parquet") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      def makeRawParquetFile(path: Path): Unit = {
+        val schemaStr =
+          """message root {
+            |  required INT64 a(UINT_64);
+            |}
+          """.stripMargin
+        val schema = MessageTypeParser.parseMessageType(schemaStr)
+
+        val writer = createParquetWriter(schema, path, dictionaryEnabled)
+
+        val factory = new SimpleGroupFactory(schema)
+        (-500 until 500).foreach { i =>
+          val group = factory.newGroup()
+            .append("a", i % 100L)
+          writer.write(group)
+        }
+        writer.close()
+      }
+
+      withTempDir { dir =>
+        val path = new Path(dir.toURI.toString, "part-r-0.parquet")
+        makeRawParquetFile(path)
+        readParquetFile(path.toString) { df =>
+          checkAnswer(df, (-500 until 500).map { i =>
+            val bi = UnsignedLong.fromLongBits(i % 100L).bigIntegerValue()
+            Row(new java.math.BigDecimal(bi))
+          })
+        }
+      }
+    }
+  }
+
   test("write metadata") {
     val hadoopConf = spark.sessionState.newHadoopConf()
     withTempPath { file =>