[SPARK-22666][ML][SQL] Spark datasource for image format
## What changes were proposed in this pull request?

Implement an image schema datasource.

This image datasource supports:
- partition discovery (loading partitioned images)
- dropping invalid images (`dropInvalid`, the same behavior as `dropImageFailures` in `ImageSchema.readImages`)
- path wildcard matching (the same behavior as `ImageSchema.readImages`)
- loading recursively from a directory (unlike `ImageSchema.readImages`, this uses a path such as `/path/to/dir/**`)

This datasource does **NOT** support:
- specifying `numPartitions` (it is determined automatically by the datasource)
- sampling (you can use `df.sample` afterwards, but the sampling operator will not be pushed down to the datasource)

## How was this patch tested?

Unit tests.

## Benchmark

I benchmarked and compared the time cost of the old `ImageSchema.readImages` API and this image datasource.

**cluster**: 4 nodes, each with 64 GB memory and an 8-core CPU
**test dataset**: Flickr8k_Dataset (about 8091 images)

**time cost**:
- image datasource (automatically generated 258 partitions): 38.04 s
- `ImageSchema.readImages` (16 partitions): 68.4 s
- `ImageSchema.readImages` (258 partitions): 90.6 s

**time cost when doubling the number of images (cloning Flickr8k_Dataset and loading twice as many images)**:
- image datasource (automatically generated 515 partitions): 95.4 s
- `ImageSchema.readImages` (32 partitions): 109 s
- `ImageSchema.readImages` (515 partitions): 105 s

So the image datasource implementation in this PR brings a performance improvement over the old `ImageSchema.readImages` API.

Closes #22328 from WeichenXu123/image_datasource.

Authored-by: WeichenXu <weichen.xu@databricks.com>
Signed-off-by: Xiangrui Meng <meng@databricks.com>
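For reference, a minimal usage sketch of the datasource described above (paths come from the test data added in this PR; `dropInvalid` is the option this PR implements):

```scala
// Load a partitioned image directory; partition discovery turns
// cls=.../date=... directories into ordinary columns.
val df = spark.read.format("image")
  .option("dropInvalid", true)
  .load("data/mllib/images/partitioned")

// Recursive loading uses a wildcard path, as noted above:
val all = spark.read.format("image").load("/path/to/dir/**")
```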
new file: data/mllib/images/origin/license.txt
@@ -0,0 +1,13 @@

The images in the folder "kittens" are under the creative commons CC0 license, or no rights reserved:
https://creativecommons.org/share-your-work/public-domain/cc0/

The images are taken from:
https://ccsearch.creativecommons.org/image/detail/WZnbJSJ2-dzIDiuUUdto3Q==
https://ccsearch.creativecommons.org/image/detail/_TlKu_rm_QrWlR0zthQTXA==
https://ccsearch.creativecommons.org/image/detail/OPNnHJb6q37rSZ5o_L5JHQ==
https://ccsearch.creativecommons.org/image/detail/B2CVP_j5KjwZm7UAVJ3Hvw==

The chr30.4.184.jpg and grayscale.jpg images are also under the CC0 license, taken from:
https://ccsearch.creativecommons.org/image/detail/8eO_qqotBfEm2UYxirLntw==

The image under "multi-channel" directory is under the CC BY-SA 4.0 license cropped from:
https://en.wikipedia.org/wiki/Alpha_compositing#/media/File:Hue_alpha_falloff.png
new file: data/mllib/images/partitioned/cls=kittens/date=2018-01/not-image.txt
@@ -0,0 +1 @@
+not an image
BIN  data/mllib/images/partitioned/cls=kittens/date=2018-02/54893.jpg (new binary file)
(additional binary image files added under data/mllib/images/partitioned/)
META-INF/services/org.apache.spark.sql.sources.DataSourceRegister:
@@ -1 +1,2 @@
 org.apache.spark.ml.source.libsvm.LibSVMFileFormat
+org.apache.spark.ml.source.image.ImageFileFormat
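Adding `ImageFileFormat` to this `ServiceLoader` manifest is what makes the short name resolvable. As a sketch of standard Spark data source resolution (not new code in this PR), both of these are equivalent once the manifest is on the classpath:

```scala
// Short name, resolved through the DataSourceRegister service manifest:
val viaShortName = spark.read.format("image")
  .load("data/mllib/images/partitioned")

// The fully qualified class name works even without the manifest entry:
val viaClassName = spark.read
  .format("org.apache.spark.ml.source.image.ImageFileFormat")
  .load("data/mllib/images/partitioned")
```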
new file: mllib/src/main/scala/org/apache/spark/ml/source/image/ImageDataSource.scala
@@ -0,0 +1,53 @@

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.source.image

/**
 * `image` package implements Spark SQL data source API for loading image data as `DataFrame`.
 * The loaded `DataFrame` has one `StructType` column: `image`.
 * The schema of the `image` column is:
 *  - origin: String (represents the file path of the image)
 *  - height: Int (height of the image)
 *  - width: Int (width of the image)
 *  - nChannels: Int (number of image channels)
 *  - mode: Int (OpenCV-compatible type)
 *  - data: BinaryType (image bytes in OpenCV-compatible order: row-wise BGR in most cases)
 *
 * To use the image data source, you need to set "image" as the format in `DataFrameReader` and
 * optionally specify the data source options, for example:
 * {{{
 *   // Scala
 *   val df = spark.read.format("image")
 *     .option("dropInvalid", true)
 *     .load("data/mllib/images/partitioned")
 *
 *   // Java
 *   Dataset<Row> df = spark.read().format("image")
 *     .option("dropInvalid", true)
 *     .load("data/mllib/images/partitioned");
 * }}}
 *
 * Image data source supports the following options:
 *  - "dropInvalid": Whether to drop files that are not valid images from the result.
 *
 * @note This image data source does not support saving images to files.
 *
 * @note This class is public for documentation purposes. Please don't use this class directly;
 * rather, use the data source API as illustrated above.
 */
class ImageDataSource private() {}
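To make the documented schema concrete, here is a small sketch (against the test data bundled with this PR) of inspecting the loaded `DataFrame`; the nested field names follow the scaladoc above:

```scala
import org.apache.spark.sql.functions.col

val df = spark.read.format("image")
  .option("dropInvalid", true)
  .load("data/mllib/images/partitioned")

// One struct column named "image", plus partition columns (cls, date)
// discovered from the directory layout.
df.printSchema()

// Nested fields are addressed with dot syntax.
df.select(col("image.origin"), col("image.width"), col("image.height"), col("image.nChannels"))
  .show(truncate = false)
```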
new file: mllib/src/main/scala/org/apache/spark/ml/source/image/ImageFileFormat.scala
@@ -0,0 +1,100 @@

/* Apache License, Version 2.0 header (identical to the one above). */

package org.apache.spark.ml.source.image

import com.google.common.io.{ByteStreams, Closeables}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.mapreduce.Job

import org.apache.spark.ml.image.ImageSchema
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UnsafeRow}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.datasources.{DataSource, FileFormat, OutputWriterFactory, PartitionedFile}
import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration

private[image] class ImageFileFormat extends FileFormat with DataSourceRegister {

  // The schema is fixed: a single struct column named "image".
  override def inferSchema(
      sparkSession: SparkSession,
      options: Map[String, String],
      files: Seq[FileStatus]): Option[StructType] = Some(ImageSchema.imageSchema)

  override def prepareWrite(
      sparkSession: SparkSession,
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory = {
    throw new UnsupportedOperationException("Write is not supported for image data source")
  }

  override def shortName(): String = "image"

  override protected def buildReader(
      sparkSession: SparkSession,
      dataSchema: StructType,
      partitionSchema: StructType,
      requiredSchema: StructType,
      filters: Seq[Filter],
      options: Map[String, String],
      hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
    assert(
      requiredSchema.length <= 1,
      "Image data source only produces a single data column named \"image\".")

    val broadcastedHadoopConf =
      sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))

    val imageSourceOptions = new ImageOptions(options)

    (file: PartitionedFile) => {
      val emptyUnsafeRow = new UnsafeRow(0)
      if (!imageSourceOptions.dropInvalid && requiredSchema.isEmpty) {
        // Fast path: no columns requested and invalid images are kept,
        // so emit one empty row per file without reading it.
        Iterator(emptyUnsafeRow)
      } else {
        val origin = file.filePath
        val path = new Path(origin)
        val fs = path.getFileSystem(broadcastedHadoopConf.value.value)
        val stream = fs.open(path)
        val bytes = try {
          ByteStreams.toByteArray(stream)
        } finally {
          Closeables.close(stream, true)
        }
        val resultOpt = ImageSchema.decode(origin, bytes)
        val filteredResult = if (imageSourceOptions.dropInvalid) {
          resultOpt.toIterator
        } else {
          Iterator(resultOpt.getOrElse(ImageSchema.invalidImageRow(origin)))
        }

        if (requiredSchema.isEmpty) {
          filteredResult.map(_ => emptyUnsafeRow)
        } else {
          val converter = RowEncoder(requiredSchema)
          filteredResult.map(row => converter.toRow(row))
        }
      }
    }
  }
}
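One detail worth calling out in `buildReader`: when `dropInvalid` is false and `requiredSchema` is empty (e.g., a bare `count()`), each file yields a single empty `UnsafeRow` without ever being opened or decoded. A sketch of queries exercising the two paths (counts taken from the test suite below):

```scala
// Default dropInvalid=false: count() prunes all columns, so files are
// counted without being read or decoded — one row per file, non-images included.
spark.read.format("image")
  .load("data/mllib/images/partitioned")
  .count()  // 9

// dropInvalid=true: every file must be decoded to decide whether to keep it.
spark.read.format("image")
  .option("dropInvalid", true)
  .load("data/mllib/images/partitioned")
  .count()  // 8
```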
new file: mllib/src/main/scala/org/apache/spark/ml/source/image/ImageOptions.scala
@@ -0,0 +1,32 @@

/* Apache License, Version 2.0 header (identical to the one above). */

package org.apache.spark.ml.source.image

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

private[image] class ImageOptions(
    @transient private val parameters: CaseInsensitiveMap[String]) extends Serializable {

  def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters))

  /**
   * Whether to drop invalid images. If true, invalid images will be removed; otherwise
   * invalid images will be returned with empty data and all other fields filled with `-1`.
   */
  val dropInvalid = parameters.getOrElse("dropInvalid", "false").toBoolean
}
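Because `ImageOptions` wraps its parameters in `CaseInsensitiveMap`, the option key is matched case-insensitively; a minimal sketch:

```scala
val path = "data/mllib/images/partitioned"

// All three spellings enable the same option:
spark.read.format("image").option("dropInvalid", true).load(path)
spark.read.format("image").option("dropinvalid", "true").load(path)
spark.read.format("image").option("DROPINVALID", "true").load(path)
```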
mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala:
@@ -28,7 +28,7 @@ import org.apache.spark.sql.types._

 class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext {
   // Single column of images named "image"
-  private lazy val imagePath = "../data/mllib/images"
+  private lazy val imagePath = "../data/mllib/images/origin"

   test("Smoke test: create basic ImageSchema dataframe") {
     val origin = "path"
new file: mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala
@@ -0,0 +1,119 @@

/* Apache License, Version 2.0 header (identical to the one above). */

package org.apache.spark.ml.source.image

import java.nio.file.Paths

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.image.ImageSchema._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.{col, substring_index}

class ImageFileFormatSuite extends SparkFunSuite with MLlibTestSparkContext {

  // Single column of images named "image"
  private lazy val imagePath = "../data/mllib/images/partitioned"

  test("image datasource count test") {
    val df1 = spark.read.format("image").load(imagePath)
    assert(df1.count === 9)

    val df2 = spark.read.format("image").option("dropInvalid", true).load(imagePath)
    assert(df2.count === 8)
  }

  test("image datasource test: read jpg image") {
    val df = spark.read.format("image").load(imagePath + "/cls=kittens/date=2018-02/DP153539.jpg")
    assert(df.count() === 1)
  }

  test("image datasource test: read png image") {
    val df = spark.read.format("image").load(imagePath + "/cls=multichannel/date=2018-01/BGRA.png")
    assert(df.count() === 1)
  }

  test("image datasource test: read non image") {
    val filePath = imagePath + "/cls=kittens/date=2018-01/not-image.txt"
    val df = spark.read.format("image").option("dropInvalid", true)
      .load(filePath)
    assert(df.count() === 0)

    val df2 = spark.read.format("image").option("dropInvalid", false)
      .load(filePath)
    assert(df2.count() === 1)
    val result = df2.head()
    assert(result === invalidImageRow(
      Paths.get(filePath).toAbsolutePath().normalize().toUri().toString))
  }

  test("image datasource partition test") {
    val result = spark.read.format("image")
      .option("dropInvalid", true).load(imagePath)
      .select(substring_index(col("image.origin"), "/", -1).as("origin"), col("cls"), col("date"))
      .collect()

    assert(Set(result: _*) === Set(
      Row("29.5.a_b_EGDP022204.jpg", "kittens", "2018-01"),
      Row("54893.jpg", "kittens", "2018-02"),
      Row("DP153539.jpg", "kittens", "2018-02"),
      Row("DP802813.jpg", "kittens", "2018-02"),
      Row("BGRA.png", "multichannel", "2018-01"),
      Row("BGRA_alpha_60.png", "multichannel", "2018-01"),
      Row("chr30.4.184.jpg", "multichannel", "2018-02"),
      Row("grayscale.jpg", "multichannel", "2018-02")
    ))
  }

  // Images with different numbers of channels
  test("readImages pixel values test") {
    val images = spark.read.format("image").option("dropInvalid", true)
      .load(imagePath + "/cls=multichannel/").collect()

    val firstBytes20Set = images.map { rrow =>
      val row = rrow.getAs[Row]("image")
      val filename = Paths.get(getOrigin(row)).getFileName().toString()
      val mode = getMode(row)
      val bytes20 = getData(row).slice(0, 20).toList
      // Cannot remove `Tuple2`; otherwise the `->` operator would match two arguments.
      filename -> Tuple2(mode, bytes20)
    }.toSet

    assert(firstBytes20Set === expectedFirstBytes20Set)
  }

  // Number of channels and first 20 bytes of the OpenCV representation:
  //   - default representation for 3-channel RGB images is BGR row-wise:
  //     (B00, G00, R00, B10, G10, R10, ...)
  //   - default representation for 4-channel RGB images is BGRA row-wise:
  //     (B00, G00, R00, A00, B10, G10, R10, A10, ...)
  private val expectedFirstBytes20Set = Set(
    "grayscale.jpg" ->
      ((0, List[Byte](-2, -33, -61, -60, -59, -59, -64, -59, -66, -67, -73, -73, -62,
        -57, -60, -63, -53, -49, -55, -69))),
    "chr30.4.184.jpg" -> ((16,
      List[Byte](-9, -3, -1, -43, -32, -28, -75, -60, -57, -78, -59, -56, -74, -59, -57,
        -71, -58, -56, -73, -64))),
    "BGRA.png" -> ((24,
      List[Byte](-128, -128, -8, -1, -128, -128, -8, -1, -128,
        -128, -8, -1, 127, 127, -9, -1, 127, 127, -9, -1))),
    "BGRA_alpha_60.png" -> ((24,
      List[Byte](-128, -128, -8, 60, -128, -128, -8, 60, -128,
        -128, -8, 60, 127, 127, -9, 60, 127, 127, -9, 60)))
  )
}
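The negative values in `expectedFirstBytes20Set` are an artifact of JVM bytes being signed; masked to unsigned they are ordinary 0–255 OpenCV channel intensities. A small sketch of the conversion, using the first pixel of `BGRA.png` from the expected set above:

```scala
// JVM bytes are signed; mask with 0xFF to recover the unsigned channel value.
val firstPixelSigned: List[Byte] = List(-128, -128, -8, -1)    // B, G, R, A
val firstPixelUnsigned: List[Int] = firstPixelSigned.map(_ & 0xFF)
// => List(128, 128, 248, 255)
```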
python/pyspark/ml/image.py:
@@ -216,7 +216,7 @@ class _ImageSchema(object):
         :return: a :class:`DataFrame` with a single column of "images",
                  see ImageSchema for details.

-        >>> df = ImageSchema.readImages('data/mllib/images/kittens', recursive=True)
+        >>> df = ImageSchema.readImages('data/mllib/images/origin/kittens', recursive=True)
         >>> df.count()
         5
python/pyspark/ml/tests.py:
@@ -2186,7 +2186,7 @@ class FPGrowthTests(SparkSessionTestCase):
 class ImageReaderTest(SparkSessionTestCase):

     def test_read_images(self):
-        data_path = 'data/mllib/images/kittens'
+        data_path = 'data/mllib/images/origin/kittens'
         df = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True)
         self.assertEqual(df.count(), 4)
         first_row = df.take(1)[0][0]
python/pyspark/ml/tests.py:
@@ -2253,7 +2253,7 @@ class ImageReaderTest2(PySparkTestCase):
     def test_read_images_multiple_times(self):
         # This test case is to check if `ImageSchema.readImages` tries to
         # initiate Hive client multiple times. See SPARK-22651.
-        data_path = 'data/mllib/images/kittens'
+        data_path = 'data/mllib/images/origin/kittens'
         ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True)
         ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True)