[SPARK-36891][SQL] Refactor SpecificParquetRecordReaderBase and add more coverage on vectorized Parquet decoding

### What changes were proposed in this pull request?

Add a new test suite, `ParquetVectorizedSuite`, to provide more coverage of the vectorized Parquet decoding logic, exercising different combinations of column index, dictionary encoding, batch size, page size, etc.

To facilitate the tests, this also refactors `SpecificParquetRecordReaderBase` to make the Parquet row group reader pluggable, so a test can feed pre-built in-memory row groups directly to the vectorized reader.
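
As a rough sketch of how the new hook is exercised (this mirrors the `checkAnswer` helper in the new suite; `parquetSchema`, `pageReadStore`, `totalRowCount`, and `batchSize` are placeholders, while `TestParquetRowGroupReader` is the in-memory implementation added in `ParquetVectorizedSuite`):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils

// Initialize the vectorized reader against an in-memory row group instead of
// scanning an actual Parquet file.
val reader = new VectorizedParquetRecordReader(
  DateTimeUtils.getZoneId("EST"), "CORRECTED", "CORRECTED", true, batchSize)
reader.initialize(parquetSchema, parquetSchema,
  TestParquetRowGroupReader(Seq(pageReadStore)), totalRowCount)
try {
  while (reader.nextKeyValue()) {
    // Each value is an InternalRow view over the current columnar batch.
    val row = reader.getCurrentValue.asInstanceOf[InternalRow]
    // ... assert on the decoded values ...
  }
} finally {
  reader.close()
}
```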

### Why are the changes needed?

Currently `ParquetIOSuite` and `ParquetColumnIndexSuite` only exercise the high-level read API, which is insufficient, especially after the introduction of column index support, where we want to cover various combinations of row ranges, first row indexes, batch sizes, page sizes, etc.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added a new test suite, `ParquetVectorizedSuite`.
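It can be run locally with, e.g., `build/sbt "sql/testOnly *ParquetVectorizedSuite"` (standard Spark SBT invocation).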

Closes #34149 from sunchao/SPARK-36891-parquet-test.

Authored-by: Chao Sun <sunchao@apple.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
Chao Sun authored 2021-10-01 23:35:23 -07:00; committed by Dongjoon Hyun
commit 14d4ceeb73 (parent 25db6b45c7)
6 changed files with 584 additions and 12 deletions

pom.xml

@@ -2343,6 +2343,27 @@
<version>${parquet.version}</version>
<scope>${parquet.deps.scope}</scope>
</dependency>
<dependency>
<groupId>org.apache.parquet</groupId>
<artifactId>parquet-encoding</artifactId>
<version>${parquet.version}</version>
<scope>${parquet.test.deps.scope}</scope>
<classifier>tests</classifier>
</dependency>
<dependency>
<groupId>org.apache.parquet</groupId>
<artifactId>parquet-common</artifactId>
<version>${parquet.version}</version>
<scope>${parquet.test.deps.scope}</scope>
<classifier>tests</classifier>
</dependency>
<dependency>
<groupId>org.apache.parquet</groupId>
<artifactId>parquet-column</artifactId>
<version>${parquet.version}</version>
<scope>${parquet.test.deps.scope}</scope>
<classifier>tests</classifier>
</dependency>
<dependency>
<groupId>org.apache.parquet</groupId>
<artifactId>parquet-hadoop</artifactId>

pom.xml

@@ -112,6 +112,24 @@
<groupId>org.apache.parquet</groupId>
<artifactId>parquet-column</artifactId>
</dependency>
<dependency>
<groupId>org.apache.parquet</groupId>
<artifactId>parquet-encoding</artifactId>
<scope>test</scope>
<classifier>tests</classifier>
</dependency>
<dependency>
<groupId>org.apache.parquet</groupId>
<artifactId>parquet-common</artifactId>
<scope>test</scope>
<classifier>tests</classifier>
</dependency>
<dependency>
<groupId>org.apache.parquet</groupId>
<artifactId>parquet-column</artifactId>
<scope>test</scope>
<classifier>tests</classifier>
</dependency>
<dependency>
<groupId>org.apache.parquet</groupId>
<artifactId>parquet-hadoop</artifactId>

SpecificParquetRecordReaderBase.java

@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.datasources.parquet;
import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
@@ -29,6 +30,8 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import com.google.common.annotations.VisibleForTesting;
import org.apache.parquet.column.page.PageReadStore;
import scala.Option;
import org.apache.hadoop.conf.Configuration;
@@ -75,7 +78,7 @@ public abstract class SpecificParquetRecordReaderBase<T> extends RecordReader<Vo
*/
protected long totalRowCount;
protected ParquetFileReader reader;
protected ParquetRowGroupReader reader;
@Override
public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext)
@@ -88,18 +91,20 @@ public abstract class SpecificParquetRecordReaderBase<T> extends RecordReader<Vo
.builder(configuration)
.withRange(split.getStart(), split.getStart() + split.getLength())
.build();
this.reader = new ParquetFileReader(HadoopInputFile.fromPath(file, configuration), options);
this.fileSchema = reader.getFileMetaData().getSchema();
Map<String, String> fileMetadata = reader.getFileMetaData().getKeyValueMetaData();
ParquetFileReader fileReader = new ParquetFileReader(
HadoopInputFile.fromPath(file, configuration), options);
this.reader = new ParquetRowGroupReaderImpl(fileReader);
this.fileSchema = fileReader.getFileMetaData().getSchema();
Map<String, String> fileMetadata = fileReader.getFileMetaData().getKeyValueMetaData();
ReadSupport<T> readSupport = getReadSupportInstance(getReadSupportClass(configuration));
ReadSupport.ReadContext readContext = readSupport.init(new InitContext(
taskAttemptContext.getConfiguration(), toSetMultiMap(fileMetadata), fileSchema));
this.requestedSchema = readContext.getRequestedSchema();
reader.setRequestedSchema(requestedSchema);
fileReader.setRequestedSchema(requestedSchema);
String sparkRequestedSchemaString =
configuration.get(ParquetReadSupport$.MODULE$.SPARK_ROW_REQUESTED_SCHEMA());
this.sparkSchema = StructType$.MODULE$.fromString(sparkRequestedSchemaString);
this.totalRowCount = reader.getFilteredRecordCount();
this.totalRowCount = fileReader.getFilteredRecordCount();
// For test purpose.
// If the last external accumulator is `NumRowGroupsAccumulator`, the row group number to read
@@ -111,7 +116,7 @@ public abstract class SpecificParquetRecordReaderBase<T> extends RecordReader<Vo
if (accu.isDefined() && accu.get().getClass().getSimpleName().equals("NumRowGroupsAcc")) {
@SuppressWarnings("unchecked")
AccumulatorV2<Integer, Integer> intAccum = (AccumulatorV2<Integer, Integer>) accu.get();
intAccum.add(reader.getRowGroups().size());
intAccum.add(fileReader.getRowGroups().size());
}
}
}
@@ -155,8 +160,10 @@ public abstract class SpecificParquetRecordReaderBase<T> extends RecordReader<Vo
.builder(config)
.withRange(0, length)
.build();
this.reader = ParquetFileReader.open(HadoopInputFile.fromPath(file, config), options);
this.fileSchema = reader.getFooter().getFileMetaData().getSchema();
ParquetFileReader fileReader = ParquetFileReader.open(
HadoopInputFile.fromPath(file, config), options);
this.reader = new ParquetRowGroupReaderImpl(fileReader);
this.fileSchema = fileReader.getFooter().getFileMetaData().getSchema();
if (columns == null) {
this.requestedSchema = fileSchema;
@@ -175,9 +182,25 @@ public abstract class SpecificParquetRecordReaderBase<T> extends RecordReader<Vo
this.requestedSchema = ParquetSchemaConverter.EMPTY_MESSAGE();
}
}
reader.setRequestedSchema(requestedSchema);
fileReader.setRequestedSchema(requestedSchema);
this.sparkSchema = new ParquetToSparkSchemaConverter(config).convert(requestedSchema);
this.totalRowCount = reader.getFilteredRecordCount();
this.totalRowCount = fileReader.getFilteredRecordCount();
}
@VisibleForTesting
protected void initialize(
MessageType fileSchema,
MessageType requestedSchema,
ParquetRowGroupReader rowGroupReader,
int totalRowCount) throws IOException {
this.reader = rowGroupReader;
this.fileSchema = fileSchema;
this.requestedSchema = requestedSchema;
Configuration config = new Configuration();
config.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING().key() , false);
config.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP().key(), false);
this.sparkSchema = new ParquetToSparkSchemaConverter(config).convert(requestedSchema);
this.totalRowCount = totalRowCount;
}
@Override
@@ -222,4 +245,31 @@ public abstract class SpecificParquetRecordReaderBase<T> extends RecordReader<Vo
throw new BadConfigurationException("could not instantiate read support class", e);
}
}
interface ParquetRowGroupReader extends Closeable {
/**
* Reads the next row group from this reader. Returns null if there is no more row group.
*/
PageReadStore readNextRowGroup() throws IOException;
}
private static class ParquetRowGroupReaderImpl implements ParquetRowGroupReader {
private final ParquetFileReader reader;
ParquetRowGroupReaderImpl(ParquetFileReader reader) {
this.reader = reader;
}
@Override
public PageReadStore readNextRowGroup() throws IOException {
return reader.readNextFilteredRowGroup();
}
@Override
public void close() throws IOException {
if (reader != null) {
reader.close();
}
}
}
}

VectorizedParquetRecordReader.java

@@ -22,10 +22,12 @@ import java.time.ZoneId;
import java.util.Arrays;
import java.util.List;
import com.google.common.annotations.VisibleForTesting;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.column.page.PageReadStore;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.Type;
import org.apache.spark.memory.MemoryMode;
@@ -165,6 +167,17 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa
initializeInternal();
}
@VisibleForTesting
@Override
public void initialize(
MessageType fileSchema,
MessageType requestedSchema,
ParquetRowGroupReader rowGroupReader,
int totalRowCount) throws IOException {
super.initialize(fileSchema, requestedSchema, rowGroupReader, totalRowCount);
initializeInternal();
}
@Override
public void close() throws IOException {
if (columnarBatch != null) {
@@ -320,7 +333,7 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa
private void checkEndOfRowGroup() throws IOException {
if (rowsReturned != totalCountLoadedSoFar) return;
PageReadStore pages = reader.readNextFilteredRowGroup();
PageReadStore pages = reader.readNextRowGroup();
if (pages == null) {
throw new IOException("expecting more rows but reached last block. Read "
+ rowsReturned + " out of " + totalRowCount);

TestDataPage.java

@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.parquet.column.page;
import java.util.Optional;
/**
* A hack to create Parquet data pages with customized first row index. We have to put it under
* 'org.apache.parquet.column.page' since the constructor of `DataPage` is package-private.
*/
public class TestDataPage extends DataPage {
private final DataPage wrapped;
public TestDataPage(DataPage wrapped, long firstRowIndex) {
super(wrapped.getCompressedSize(), wrapped.getUncompressedSize(), wrapped.getValueCount(),
firstRowIndex);
this.wrapped = wrapped;
}
@Override
public Optional<Integer> getIndexRowCount() {
return Optional.empty();
}
@Override
public <T> T accept(Visitor<T> visitor) {
return wrapped.accept(visitor);
}
}

ParquetVectorizedSuite.scala

@@ -0,0 +1,426 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.parquet
import java.util.{Optional, PrimitiveIterator}
import scala.collection.mutable.ArrayBuffer
import scala.language.implicitConversions
import org.apache.parquet.column.{ColumnDescriptor, ParquetProperties}
import org.apache.parquet.column.impl.ColumnWriteStoreV1
import org.apache.parquet.column.page._
import org.apache.parquet.column.page.mem.MemPageStore
import org.apache.parquet.io.ParquetDecodingException
import org.apache.parquet.io.api.Binary
import org.apache.parquet.schema.{MessageType, MessageTypeParser}
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.apache.spark.memory.MemoryMode
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase.ParquetRowGroupReader
import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
/**
* A test suite on the vectorized Parquet reader. Unlike `ParquetIOSuite`, this focuses on
* low-level decoding logic covering column index, dictionary, different batch and page sizes, etc.
*/
class ParquetVectorizedSuite extends QueryTest with ParquetTest with SharedSparkSession {
private val VALUES: Seq[String] = ('a' to 'z').map(_.toString)
private val NUM_VALUES: Int = VALUES.length
private val BATCH_SIZE_CONFIGS: Seq[Int] = Seq(1, 3, 5, 7, 10, 20, 40)
private val PAGE_SIZE_CONFIGS: Seq[Seq[Int]] = Seq(Seq(6, 6, 7, 7), Seq(4, 9, 4, 9))
implicit def toStrings(ints: Seq[Int]): Seq[String] = ints.map(i => ('a' + i).toChar.toString)
test("primitive type - no column index") {
BATCH_SIZE_CONFIGS.foreach { batchSize =>
PAGE_SIZE_CONFIGS.foreach { pageSizes =>
Seq(true, false).foreach { dictionaryEnabled =>
testPrimitiveString(None, None, pageSizes, VALUES, batchSize,
dictionaryEnabled = dictionaryEnabled)
}
}
}
}
test("primitive type - column index with ranges") {
BATCH_SIZE_CONFIGS.foreach { batchSize =>
PAGE_SIZE_CONFIGS.foreach { pageSizes =>
Seq(true, false).foreach { dictionaryEnabled =>
var ranges = Seq((0L, 9L))
testPrimitiveString(None, Some(ranges), pageSizes, 0 to 9, batchSize,
dictionaryEnabled = dictionaryEnabled)
ranges = Seq((30, 50))
testPrimitiveString(None, Some(ranges), pageSizes, Seq.empty, batchSize,
dictionaryEnabled = dictionaryEnabled)
ranges = Seq((15, 25))
testPrimitiveString(None, Some(ranges), pageSizes, 15 to 19, batchSize,
dictionaryEnabled = dictionaryEnabled)
ranges = Seq((19, 20))
testPrimitiveString(None, Some(ranges), pageSizes, 19 to 20, batchSize,
dictionaryEnabled = dictionaryEnabled)
ranges = Seq((0, 3), (5, 7), (15, 18))
testPrimitiveString(None, Some(ranges), pageSizes,
toStrings(Seq(0, 1, 2, 3, 5, 6, 7, 15, 16, 17, 18)),
batchSize, dictionaryEnabled = dictionaryEnabled)
}
}
}
}
test("primitive type - column index with ranges and nulls") {
BATCH_SIZE_CONFIGS.foreach { batchSize =>
PAGE_SIZE_CONFIGS.foreach { pageSizes =>
Seq(true, false).foreach { dictionaryEnabled =>
val valuesWithNulls = VALUES.zipWithIndex.map {
case (v, i) => if (i % 2 == 0) null else v
}
testPrimitiveString(None, None, pageSizes, valuesWithNulls, batchSize, valuesWithNulls,
dictionaryEnabled)
val ranges = Seq((5L, 7L))
testPrimitiveString(None, Some(ranges), pageSizes, Seq("f", null, "h"),
batchSize, valuesWithNulls, dictionaryEnabled)
}
}
}
}
test("primitive type - column index with ranges and first row indexes") {
BATCH_SIZE_CONFIGS.foreach { batchSize =>
Seq(true, false).foreach { dictionaryEnabled =>
// Single page
val firstRowIndex = 10
var ranges = Seq((0L, 9L))
testPrimitiveString(Some(Seq(firstRowIndex)), Some(ranges), Seq(VALUES.length),
Seq.empty, batchSize, dictionaryEnabled = dictionaryEnabled)
ranges = Seq((15, 25))
testPrimitiveString(Some(Seq(firstRowIndex)), Some(ranges), Seq(VALUES.length),
5 to 15, batchSize, dictionaryEnabled = dictionaryEnabled)
ranges = Seq((15, 35))
testPrimitiveString(Some(Seq(firstRowIndex)), Some(ranges), Seq(VALUES.length),
5 to 19, batchSize, dictionaryEnabled = dictionaryEnabled)
ranges = Seq((15, 39))
testPrimitiveString(Some(Seq(firstRowIndex)), Some(ranges), Seq(VALUES.length),
5 to 19, batchSize, dictionaryEnabled = dictionaryEnabled)
// Row indexes: [ [10, 16), [20, 26), [30, 37), [40, 47) ]
// Values: [ [0, 6), [6, 12), [12, 19), [19, 26) ]
var pageSizes = Seq(6, 6, 7, 7)
var firstRowIndexes = Seq(10L, 20, 30, 40)
ranges = Seq((0L, 9L))
testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes,
Seq.empty, batchSize, dictionaryEnabled = dictionaryEnabled)
ranges = Seq((15, 25))
testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes,
5 to 9, batchSize, dictionaryEnabled = dictionaryEnabled)
ranges = Seq((15, 35))
testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes,
5 to 14, batchSize, dictionaryEnabled = dictionaryEnabled)
ranges = Seq((15, 60))
testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes,
5 to 19, batchSize, dictionaryEnabled = dictionaryEnabled)
ranges = Seq((12, 22), (28, 38))
testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes,
toStrings(Seq(2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18)), batchSize,
dictionaryEnabled = dictionaryEnabled)
// Row indexes: [ [10, 11), [40, 52), [100, 112), [200, 201) ]
// Values: [ [0, 1), [1, 13), [13, 25), [25, 26] ]
pageSizes = Seq(1, 12, 12, 1)
firstRowIndexes = Seq(10L, 40, 100, 200)
ranges = Seq((0L, 9L))
testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes,
Seq.empty, batchSize, dictionaryEnabled = dictionaryEnabled)
ranges = Seq((300, 350))
testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes,
Seq.empty, batchSize, dictionaryEnabled = dictionaryEnabled)
ranges = Seq((50, 80))
testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes,
(11 to 12), batchSize, dictionaryEnabled = dictionaryEnabled)
ranges = Seq((0, 150))
testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes,
0 to 24, batchSize, dictionaryEnabled = dictionaryEnabled)
// with nulls
val valuesWithNulls = VALUES.zipWithIndex.map {
case (v, i) => if (i % 2 == 0) null else v
}
ranges = Seq((20, 45)) // select values in [1, 5]
testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes,
Seq("b", null, "d", null, "f"), batchSize, valuesWithNulls,
dictionaryEnabled = dictionaryEnabled)
ranges = Seq((8, 12), (80, 104))
testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes,
Seq(null, "n", null, "p", null, "r"), batchSize, valuesWithNulls,
dictionaryEnabled = dictionaryEnabled)
}
}
}
private def testPrimitiveString(
firstRowIndexesOpt: Option[Seq[Long]],
rangesOpt: Option[Seq[(Long, Long)]],
pageSizes: Seq[Int],
expectedValues: Seq[String],
batchSize: Int,
inputValues: Seq[String] = VALUES,
dictionaryEnabled: Boolean = false): Unit = {
assert(pageSizes.sum == inputValues.length)
firstRowIndexesOpt.foreach(a => assert(pageSizes.length == a.length))
val isRequiredStr = if (!expectedValues.contains(null)) "required" else "optional"
val parquetSchema: MessageType = MessageTypeParser.parseMessageType(
s"""message root {
| $isRequiredStr binary a(UTF8);
|}
|""".stripMargin
)
val maxDef = if (inputValues.contains(null)) 1 else 0
val ty = parquetSchema.asGroupType().getType("a").asPrimitiveType()
val cd = new ColumnDescriptor(Seq("a").toArray, ty, 0, maxDef)
val repetitionLevels = Array.fill[Int](inputValues.length)(0)
val definitionLevels = inputValues.map(v => if (v == null) 0 else 1)
val memPageStore = new MemPageStore(expectedValues.length)
var i = 0
val pageFirstRowIndexes = ArrayBuffer.empty[Long]
pageSizes.foreach { size =>
pageFirstRowIndexes += i
writeDataPage(cd, memPageStore, repetitionLevels.slice(i, i + size),
definitionLevels.slice(i, i + size), inputValues.slice(i, i + size), maxDef,
dictionaryEnabled)
i += size
}
checkAnswer(expectedValues.length, parquetSchema,
TestPageReadStore(memPageStore, firstRowIndexesOpt.getOrElse(pageFirstRowIndexes).toSeq,
rangesOpt), expectedValues.map(i => Row(i)), batchSize)
}
/**
* Write a single data page using repetition levels, definition levels and values provided.
*
* Note that this requires `repetitionLevels`, `definitionLevels` and `values` to have the same
* number of elements. For null values, the corresponding slots in `values` will be skipped.
*/
private def writeDataPage(
columnDesc: ColumnDescriptor,
pageWriteStore: PageWriteStore,
repetitionLevels: Seq[Int],
definitionLevels: Seq[Int],
values: Seq[Any],
maxDefinitionLevel: Int,
dictionaryEnabled: Boolean = false): Unit = {
val columnWriterStore = new ColumnWriteStoreV1(pageWriteStore,
ParquetProperties.builder()
.withPageSize(4096)
.withDictionaryEncoding(dictionaryEnabled)
.build())
val columnWriter = columnWriterStore.getColumnWriter(columnDesc)
repetitionLevels.zip(definitionLevels).zipWithIndex.foreach { case ((rl, dl), i) =>
if (dl < maxDefinitionLevel) {
columnWriter.writeNull(rl, dl)
} else {
columnDesc.getPrimitiveType.getPrimitiveTypeName match {
case PrimitiveTypeName.INT32 =>
columnWriter.write(values(i).asInstanceOf[Int], rl, dl)
case PrimitiveTypeName.INT64 =>
columnWriter.write(values(i).asInstanceOf[Long], rl, dl)
case PrimitiveTypeName.BOOLEAN =>
columnWriter.write(values(i).asInstanceOf[Boolean], rl, dl)
case PrimitiveTypeName.FLOAT =>
columnWriter.write(values(i).asInstanceOf[Float], rl, dl)
case PrimitiveTypeName.DOUBLE =>
columnWriter.write(values(i).asInstanceOf[Double], rl, dl)
case PrimitiveTypeName.BINARY =>
columnWriter.write(Binary.fromString(values(i).asInstanceOf[String]), rl, dl)
case _ =>
throw new IllegalStateException(s"Unexpected type: " +
s"${columnDesc.getPrimitiveType.getPrimitiveTypeName}")
}
}
columnWriterStore.endRecord()
}
columnWriterStore.flush()
}
private def checkAnswer(
totalRowCount: Int,
fileSchema: MessageType,
readStore: PageReadStore,
expected: Seq[Row],
batchSize: Int = NUM_VALUES): Unit = {
import collection.JavaConverters._
val recordReader = new VectorizedParquetRecordReader(
DateTimeUtils.getZoneId("EST"), "CORRECTED", "CORRECTED", true, batchSize)
recordReader.initialize(fileSchema, fileSchema,
TestParquetRowGroupReader(Seq(readStore)), totalRowCount)
// convert both actual and expected rows into collections
val schema = recordReader.sparkSchema
val expectedRowIt = ColumnVectorUtils.toBatch(
schema, MemoryMode.ON_HEAP, expected.iterator.asJava).rowIterator()
val rowOrdering = RowOrdering.createNaturalAscendingOrdering(schema.map(_.dataType))
var i = 0
while (expectedRowIt.hasNext && recordReader.nextKeyValue()) {
val expectedRow = expectedRowIt.next()
val actualRow = recordReader.getCurrentValue.asInstanceOf[InternalRow]
assert(rowOrdering.compare(expectedRow, actualRow) == 0, {
val expectedRowStr = toDebugString(schema, expectedRow)
val actualRowStr = toDebugString(schema, actualRow)
s"at index $i, expected row: $expectedRowStr doesn't match actual row: $actualRowStr"
})
i += 1
}
}
private def toDebugString(schema: StructType, row: InternalRow): String = {
if (row == null) "null"
else {
val fieldStrings = schema.fields.zipWithIndex.map { case (f, i) =>
f.dataType match {
case IntegerType =>
row.getInt(i).toString
case StringType =>
val utf8Str = row.getUTF8String(i)
if (utf8Str == null) "null"
else utf8Str.toString
case ArrayType(_, _) =>
val elements = row.getArray(i)
if (elements == null) "null"
else elements.array.mkString("[", ", ", "]")
case _ =>
throw new IllegalArgumentException(s"Unsupported data type: ${f.dataType}")
}
}
fieldStrings.mkString(", ")
}
}
case class TestParquetRowGroupReader(groups: Seq[PageReadStore]) extends ParquetRowGroupReader {
private var index: Int = 0
override def readNextRowGroup(): PageReadStore = {
if (index == groups.length) {
null
} else {
val res = groups(index)
index += 1
res
}
}
override def close(): Unit = {}
}
private case class TestPageReadStore(
wrapped: PageReadStore,
firstRowIndexes: Seq[Long],
rowIndexRangesOpt: Option[Seq[(Long, Long)]] = None) extends PageReadStore {
override def getPageReader(descriptor: ColumnDescriptor): PageReader = {
val originalReader = wrapped.getPageReader(descriptor)
TestPageReader(originalReader, firstRowIndexes)
}
override def getRowCount: Long = wrapped.getRowCount
override def getRowIndexes: Optional[PrimitiveIterator.OfLong] = {
rowIndexRangesOpt.map { ranges =>
Optional.of(new PrimitiveIterator.OfLong {
private var currentRangeIdx: Int = 0
private var currentRowIdx: Long = -1
override def nextLong(): Long = {
if (!hasNext) throw new NoSuchElementException("No more element")
val res = currentRowIdx
currentRowIdx += 1
res
}
override def hasNext: Boolean = {
while (currentRangeIdx < ranges.length) {
if (currentRowIdx > ranges(currentRangeIdx)._2) {
// we've exhausted the current range - move to the next range
currentRangeIdx += 1
currentRowIdx = -1
} else {
if (currentRowIdx == -1) {
currentRowIdx = ranges(currentRangeIdx)._1
}
return true
}
}
false
}
})
}.getOrElse(Optional.empty())
}
}
private case class TestPageReader(
wrapped: PageReader,
firstRowIndexes: Seq[Long]) extends PageReader {
private var index = 0
override def readDictionaryPage(): DictionaryPage = wrapped.readDictionaryPage()
override def getTotalValueCount: Long = wrapped.getTotalValueCount
override def readPage(): DataPage = {
val wrappedPage = try {
wrapped.readPage()
} catch {
case _: ParquetDecodingException =>
null
}
if (wrappedPage == null) {
wrappedPage
} else {
val res = new TestDataPage(wrappedPage, firstRowIndexes(index))
index += 1
res
}
}
}
}
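
Not part of this PR, but as a hedged sketch of how the scaffolding above could be extended to another primitive type: the hypothetical helper below would live inside `ParquetVectorizedSuite` (so it relies on the suite's existing imports) and reuses its `writeDataPage`, `checkAnswer`, and `TestPageReadStore` helpers; `testPrimitiveInt` itself is not in the actual commit.

```scala
// Hypothetical helper (not in this commit): decode an INT32 column through the
// vectorized reader, reusing the INT32 branch of writeDataPage.
private def testPrimitiveInt(
    pageSizes: Seq[Int],
    inputValues: Seq[Int],
    batchSize: Int,
    dictionaryEnabled: Boolean = false): Unit = {
  assert(pageSizes.sum == inputValues.length)
  val parquetSchema = MessageTypeParser.parseMessageType(
    "message root { required int32 a; }")
  val ty = parquetSchema.asGroupType().getType("a").asPrimitiveType()
  // Required column: max repetition and definition levels are both 0.
  val cd = new ColumnDescriptor(Array("a"), ty, 0, 0)

  val memPageStore = new MemPageStore(inputValues.length)
  var i = 0
  val pageFirstRowIndexes = ArrayBuffer.empty[Long]
  pageSizes.foreach { size =>
    pageFirstRowIndexes += i
    // All values are non-null, so repetition and definition levels are all 0.
    writeDataPage(cd, memPageStore, Seq.fill(size)(0), Seq.fill(size)(0),
      inputValues.slice(i, i + size), 0, dictionaryEnabled)
    i += size
  }

  checkAnswer(inputValues.length, parquetSchema,
    TestPageReadStore(memPageStore, pageFirstRowIndexes.toSeq),
    inputValues.map(v => Row(v)), batchSize)
}
```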