[SPARK-27534][SQL] Do not load content column in binary data source if it is not selected

## What changes were proposed in this pull request?

A follow-up task from SPARK-25348. To save I/O cost, Spark shouldn't read a file's bytes at all when the user didn't select the `content` column. For example, the following query only needs file metadata and should not open any of the files:
```
spark.read.format("binaryFile").load(path).filter($"length" < 1000000).count()
```
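
The patch achieves this by building the output row with an `UnsafeRowWriter` over just the fields in `requiredSchema`, and opening the file stream only when `content` is among them (see the diff below). A condensed sketch of that read path; the object name `PrunedBinaryFileRead`, the helper `readPrunedRow`, and the string-literal field names are illustrative, not the patch's exact code:
```scala
import com.google.common.io.{ByteStreams, Closeables}
import org.apache.hadoop.fs.{FileStatus, FileSystem}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String

object PrunedBinaryFileRead {
  // Writes one output row containing only the requested fields.
  // The file is opened only when the "content" column is requested.
  def readPrunedRow(fs: FileSystem, status: FileStatus, requiredSchema: StructType): InternalRow = {
    val writer = new UnsafeRowWriter(requiredSchema.length)
    writer.resetRowWriter()
    requiredSchema.fieldNames.zipWithIndex.foreach {
      case ("path", i) => writer.write(i, UTF8String.fromString(status.getPath.toString))
      case ("length", i) => writer.write(i, status.getLen)
      case ("modificationTime", i) =>
        writer.write(i, DateTimeUtils.fromMillis(status.getModificationTime))
      case ("content", i) =>
        // The only branch that touches the file's bytes.
        val stream = fs.open(status.getPath)
        try {
          writer.write(i, ByteStreams.toByteArray(stream))
        } finally {
          Closeables.close(stream, true)
        }
      case (other, _) =>
        throw new RuntimeException(s"Unsupported field name: $other")
    }
    writer.getRow
  }
}
```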

## How was this patch tested?

Unit test added: a `"column pruning"` test in `BinaryFileFormatSuite` builds the reader with pruned required schemas and verifies that a length-only read still succeeds after the file is made unreadable (i.e., `content` is only loaded when selected), that unsupported field names fail, and that pruning is case-insensitive.
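
For an end-to-end check from the public API, any query that doesn't select `content` should now complete without reading file bytes. A small illustrative snippet, assuming a running `SparkSession` named `spark` and a placeholder input directory:
```scala
import org.apache.spark.sql.functions.col

// Placeholder path; any directory of files works.
val df = spark.read.format("binaryFile").load("/tmp/binary-file-test")

// Neither query selects `content`, so the files' bytes are never read.
df.select("path", "length").show()
df.filter(col("length") < 1000000).count()
```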


Closes #24473 from WeichenXu123/SPARK-27534.

Lead-authored-by: Xiangrui Meng <meng@databricks.com>
Co-authored-by: WeichenXu <weichen.xu@databricks.com>
Signed-off-by: Xiangrui Meng <meng@databricks.com>
Commit 20a3ef7259 (parent d8db7db50b)
Xiangrui Meng, 2019-04-28 07:57:03 -07:00
2 changed files with 89 additions and 48 deletions

BinaryFileFormat.scala

@@ -26,12 +26,10 @@ import org.apache.hadoop.mapreduce.Job
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
import org.apache.spark.sql.sources.{And, DataSourceRegister, EqualTo, Filter, GreaterThan,
GreaterThanOrEqual, LessThan, LessThanOrEqual, Not, Or}
import org.apache.spark.sql.sources.{And, DataSourceRegister, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Not, Or}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.SerializableConfiguration
@@ -80,7 +78,7 @@ class BinaryFileFormat extends FileFormat with DataSourceRegister {
false
}
override def shortName(): String = "binaryFile"
override def shortName(): String = BINARY_FILE
override protected def buildReader(
sparkSession: SparkSession,
@@ -90,54 +88,43 @@ class BinaryFileFormat extends FileFormat with DataSourceRegister {
filters: Seq[Filter],
options: Map[String, String],
hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
require(dataSchema.sameType(schema),
s"""
|Binary file data source expects dataSchema: $schema,
|but got: $dataSchema.
""".stripMargin)
val broadcastedHadoopConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
val binaryFileSourceOptions = new BinaryFileSourceOptions(options)
val pathGlobPattern = binaryFileSourceOptions.pathGlobFilter
val filterFuncs = filters.map(filter => createFilterFunction(filter))
file: PartitionedFile => {
val path = file.filePath
val fsPath = new Path(path)
val path = new Path(file.filePath)
// TODO: Improve performance here: each file will recompile the glob pattern here.
if (pathGlobPattern.forall(new GlobFilter(_).accept(fsPath))) {
val fs = fsPath.getFileSystem(broadcastedHadoopConf.value.value)
val fileStatus = fs.getFileStatus(fsPath)
val length = fileStatus.getLen
val modificationTime = fileStatus.getModificationTime
if (filterFuncs.forall(_.apply(fileStatus))) {
val stream = fs.open(fsPath)
val content = try {
ByteStreams.toByteArray(stream)
if (pathGlobPattern.forall(new GlobFilter(_).accept(path))) {
val fs = path.getFileSystem(broadcastedHadoopConf.value.value)
val status = fs.getFileStatus(path)
if (filterFuncs.forall(_.apply(status))) {
val writer = new UnsafeRowWriter(requiredSchema.length)
writer.resetRowWriter()
requiredSchema.fieldNames.zipWithIndex.foreach {
case (PATH, i) => writer.write(i, UTF8String.fromString(status.getPath.toString))
case (LENGTH, i) => writer.write(i, status.getLen)
case (MODIFICATION_TIME, i) =>
writer.write(i, DateTimeUtils.fromMillis(status.getModificationTime))
case (CONTENT, i) =>
val stream = fs.open(status.getPath)
try {
writer.write(i, ByteStreams.toByteArray(stream))
} finally {
Closeables.close(stream, true)
}
val fullOutput = dataSchema.map { f =>
AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
case (other, _) =>
throw new RuntimeException(s"Unsupported field name: ${other}")
}
val requiredOutput = fullOutput.filter { a =>
requiredSchema.fieldNames.contains(a.name)
}
// TODO: Add column pruning
// currently it still read the file content even if content column is not required.
val requiredColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput)
val internalRow = InternalRow(
UTF8String.fromString(path),
DateTimeUtils.fromMillis(modificationTime),
length,
content
)
Iterator(requiredColumns(internalRow))
Iterator.single(writer.getRow)
} else {
Iterator.empty
}
@@ -154,6 +141,7 @@ object BinaryFileFormat {
private[binaryfile] val MODIFICATION_TIME = "modificationTime"
private[binaryfile] val LENGTH = "length"
private[binaryfile] val CONTENT = "content"
private[binaryfile] val BINARY_FILE = "binaryFile"
/**
* Schema for the binary file data source.

BinaryFileFormatSuite.scala

@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.datasources.binaryfile
import java.io.File
import java.io.{File, IOException}
import java.nio.file.{Files, StandardOpenOption}
import java.sql.Timestamp
@@ -28,6 +28,7 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, GlobFilter, Path}
import org.mockito.Mockito.{mock, when}
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.sources._
@@ -101,7 +102,7 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
}
def testBinaryFileDataSource(pathGlobFilter: String): Unit = {
val dfReader = spark.read.format("binaryFile")
val dfReader = spark.read.format(BINARY_FILE)
if (pathGlobFilter != null) {
dfReader.option("pathGlobFilter", pathGlobFilter)
}
@@ -124,7 +125,7 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
for (fileStatus <- fs.listStatus(dirPath)) {
if (globFilter == null || globFilter.accept(fileStatus.getPath)) {
val fpath = fileStatus.getPath.toString.replace("file:/", "file:///")
val fpath = fileStatus.getPath.toString
val flen = fileStatus.getLen
val modificationTime = new Timestamp(fileStatus.getModificationTime)
@@ -157,11 +158,11 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
}
test("binary file data source do not support write operation") {
val df = spark.read.format("binaryFile").load(testDir)
val df = spark.read.format(BINARY_FILE).load(testDir)
withTempDir { tmpDir =>
val thrown = intercept[UnsupportedOperationException] {
df.write
.format("binaryFile")
.format(BINARY_FILE)
.save(tmpDir + "/test_save")
}
assert(thrown.getMessage.contains("Write is not supported for binary file data source"))
@@ -286,4 +287,56 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
EqualTo(MODIFICATION_TIME, file1Status.getModificationTime)
), true)
}
test("column pruning") {
def getRequiredSchema(fieldNames: String*): StructType = {
StructType(fieldNames.map {
case f if schema.fieldNames.contains(f) => schema(f)
case other => StructField(other, NullType)
})
}
def read(file: File, requiredSchema: StructType): Row = {
val format = new BinaryFileFormat
val reader = format.buildReaderWithPartitionValues(
sparkSession = spark,
dataSchema = schema,
partitionSchema = StructType(Nil),
requiredSchema = requiredSchema,
filters = Seq.empty,
options = Map.empty,
hadoopConf = spark.sessionState.newHadoopConf()
)
val partitionedFile = mock(classOf[PartitionedFile])
when(partitionedFile.filePath).thenReturn(file.getPath)
val encoder = RowEncoder(requiredSchema).resolveAndBind()
encoder.fromRow(reader(partitionedFile).next())
}
val file = new File(Utils.createTempDir(), "data")
val content = "123".getBytes
Files.write(file.toPath, content, StandardOpenOption.CREATE, StandardOpenOption.WRITE)
read(file, getRequiredSchema(MODIFICATION_TIME, CONTENT, LENGTH, PATH)) match {
case Row(t, c, len, p) =>
assert(t === new Timestamp(file.lastModified()))
assert(c === content)
assert(len === content.length)
assert(p.asInstanceOf[String].endsWith(file.getAbsolutePath))
}
file.setReadable(false)
withClue("cannot read content") {
intercept[IOException] {
read(file, getRequiredSchema(CONTENT))
}
}
assert(read(file, getRequiredSchema(LENGTH)) === Row(content.length),
"Get length should not read content.")
intercept[RuntimeException] {
read(file, getRequiredSchema(LENGTH, "other"))
}
val df = spark.read.format(BINARY_FILE).load(file.getPath)
assert(df.count() === 1, "Count should not read content.")
assert(df.select("LENGTH").first().getLong(0) === content.length,
"column pruning should be case insensitive")
}
}