[SPARK-15593][SQL] Add DataFrameWriter.foreach to allow the user to consume data in ContinuousQuery

## What changes were proposed in this pull request?

* Add DataFrameWriter.foreach to allow the user to consume data in a ContinuousQuery
  * ForeachWriter is the interface for the user to consume partitions of data
* Add a type parameter T to DataFrameWriter

Usage
```Scala
val ds = spark.read....stream().as[String]
ds.....write
  .queryName(...)
  .option("checkpointLocation", ...)
  .foreach(new ForeachWriter[String] {
    def open(partitionId: Long, version: Long): Boolean = {
      // prepare some resources for a partition
      // check `version` if possible and return `false` to skip processing duplicated data
      true
    }

    def process(value: String): Unit = {
      // process data
    }

    def close(errorOrNull: Throwable): Unit = {
      // release resources for a partition
      // check `errorOrNull` and handle the error if necessary
    }
  })
```
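
The call returns a ContinuousQuery handle for managing the stream. A minimal sketch of driving it, based on the `processAllAvailable()` and `stop()` calls exercised by the new tests (`query` stands for the value returned by the `foreach` call above):
```Scala
// The foreach(...) call above returns a ContinuousQuery:
val query = ds.....write.foreach(...) // as in the example above

query.processAllAvailable() // block until all currently available input has been processed
query.stop()                // stop the continuous query and release its resources
```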

## How was this patch tested?

New unit tests.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #13342 from zsxwing/foreach.
Commit 00c310133d (parent 5a3533e779)
Authored by Shixiong Zhu on 2016-06-10 00:11:46 -07:00; committed by Tathagata Das
6 changed files with 413 additions and 42 deletions

DataFrameWriter.scala

@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project}
import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, DataSource, HadoopFsRelation}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.execution.streaming.{MemoryPlan, MemorySink, StreamExecution}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.{ContinuousQuery, OutputMode, ProcessingTime, Trigger}
import org.apache.spark.util.Utils
@@ -40,7 +40,9 @@ import org.apache.spark.util.Utils
*
* @since 1.4.0
*/
final class DataFrameWriter private[sql](df: DataFrame) {
final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
private val df = ds.toDF()
/**
* Specifies the behavior when data or table already exists. Options include:
@@ -51,7 +53,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* @since 1.4.0
*/
def mode(saveMode: SaveMode): DataFrameWriter = {
def mode(saveMode: SaveMode): DataFrameWriter[T] = {
// mode() is used for non-continuous queries
// outputMode() is used for continuous queries
assertNotStreaming("mode() can only be called on non-continuous queries")
@@ -68,7 +70,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* @since 1.4.0
*/
def mode(saveMode: String): DataFrameWriter = {
def mode(saveMode: String): DataFrameWriter[T] = {
// mode() is used for non-continuous queries
// outputMode() is used for continuous queries
assertNotStreaming("mode() can only be called on non-continuous queries")
@@ -93,7 +95,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0.0
*/
@Experimental
def outputMode(outputMode: OutputMode): DataFrameWriter = {
def outputMode(outputMode: OutputMode): DataFrameWriter[T] = {
assertStreaming("outputMode() can only be called on continuous queries")
this.outputMode = outputMode
this
@@ -109,7 +111,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0.0
*/
@Experimental
def outputMode(outputMode: String): DataFrameWriter = {
def outputMode(outputMode: String): DataFrameWriter[T] = {
assertStreaming("outputMode() can only be called on continuous queries")
this.outputMode = outputMode.toLowerCase match {
case "append" =>
@@ -147,7 +149,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0.0
*/
@Experimental
def trigger(trigger: Trigger): DataFrameWriter = {
def trigger(trigger: Trigger): DataFrameWriter[T] = {
assertStreaming("trigger() can only be called on continuous queries")
this.trigger = trigger
this
@@ -158,7 +160,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* @since 1.4.0
*/
def format(source: String): DataFrameWriter = {
def format(source: String): DataFrameWriter[T] = {
this.source = source
this
}
@@ -168,7 +170,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* @since 1.4.0
*/
def option(key: String, value: String): DataFrameWriter = {
def option(key: String, value: String): DataFrameWriter[T] = {
this.extraOptions += (key -> value)
this
}
@@ -178,28 +180,28 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* @since 2.0.0
*/
def option(key: String, value: Boolean): DataFrameWriter = option(key, value.toString)
def option(key: String, value: Boolean): DataFrameWriter[T] = option(key, value.toString)
/**
* Adds an output option for the underlying data source.
*
* @since 2.0.0
*/
def option(key: String, value: Long): DataFrameWriter = option(key, value.toString)
def option(key: String, value: Long): DataFrameWriter[T] = option(key, value.toString)
/**
* Adds an output option for the underlying data source.
*
* @since 2.0.0
*/
def option(key: String, value: Double): DataFrameWriter = option(key, value.toString)
def option(key: String, value: Double): DataFrameWriter[T] = option(key, value.toString)
/**
* (Scala-specific) Adds output options for the underlying data source.
*
* @since 1.4.0
*/
def options(options: scala.collection.Map[String, String]): DataFrameWriter = {
def options(options: scala.collection.Map[String, String]): DataFrameWriter[T] = {
this.extraOptions ++= options
this
}
@@ -209,7 +211,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* @since 1.4.0
*/
def options(options: java.util.Map[String, String]): DataFrameWriter = {
def options(options: java.util.Map[String, String]): DataFrameWriter[T] = {
this.options(options.asScala)
this
}
@@ -232,7 +234,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 1.4.0
*/
@scala.annotation.varargs
def partitionBy(colNames: String*): DataFrameWriter = {
def partitionBy(colNames: String*): DataFrameWriter[T] = {
this.partitioningColumns = Option(colNames)
this
}
@@ -246,7 +248,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0
*/
@scala.annotation.varargs
def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter = {
def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter[T] = {
this.numBuckets = Option(numBuckets)
this.bucketColumnNames = Option(colName +: colNames)
this
@@ -260,7 +262,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0
*/
@scala.annotation.varargs
def sortBy(colName: String, colNames: String*): DataFrameWriter = {
def sortBy(colName: String, colNames: String*): DataFrameWriter[T] = {
this.sortColumnNames = Option(colName +: colNames)
this
}
@@ -301,7 +303,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0.0
*/
@Experimental
def queryName(queryName: String): DataFrameWriter = {
def queryName(queryName: String): DataFrameWriter[T] = {
assertStreaming("queryName() can only be called on continuous queries")
this.extraOptions += ("queryName" -> queryName)
this
@@ -337,16 +339,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
val queryName =
extraOptions.getOrElse(
"queryName", throw new AnalysisException("queryName must be specified for memory sink"))
val checkpointLocation = extraOptions.get("checkpointLocation").map { userSpecified =>
new Path(userSpecified).toUri.toString
}.orElse {
val checkpointConfig: Option[String] =
df.sparkSession.conf.get(SQLConf.CHECKPOINT_LOCATION)
checkpointConfig.map { location =>
new Path(location, queryName).toUri.toString
}
}.getOrElse {
val checkpointLocation = getCheckpointLocation(queryName, failIfNotSet = false).getOrElse {
Utils.createTempDir(namePrefix = "memory.stream").getCanonicalPath
}
@@ -378,21 +371,10 @@ final class DataFrameWriter private[sql](df: DataFrame) {
className = source,
options = extraOptions.toMap,
partitionColumns = normalizedParCols.getOrElse(Nil))
val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName)
val checkpointLocation = extraOptions.get("checkpointLocation")
.orElse {
df.sparkSession.sessionState.conf.checkpointLocation.map { l =>
new Path(l, queryName).toUri.toString
}
}.getOrElse {
throw new AnalysisException("checkpointLocation must be specified either " +
"through option() or SQLConf")
}
df.sparkSession.sessionState.continuousQueryManager.startQuery(
queryName,
checkpointLocation,
getCheckpointLocation(queryName, failIfNotSet = true).get,
df,
dataSource.createSink(outputMode),
outputMode,
@@ -400,6 +382,94 @@ final class DataFrameWriter private[sql](df: DataFrame) {
}
}
/**
* :: Experimental ::
* Starts the execution of the streaming query, which will continually send results to the given
* [[ForeachWriter]] as new data arrives. The [[ForeachWriter]] can be used to send the data
* generated by the [[DataFrame]]/[[Dataset]] to an external system. The returned
* [[ContinuousQuery]] object can be used to interact with the stream.
*
* Scala example:
* {{{
* datasetOfString.write.foreach(new ForeachWriter[String] {
*
* def open(partitionId: Long, version: Long): Boolean = {
* // open connection
* }
*
* def process(record: String) = {
* // write string to connection
* }
*
* def close(errorOrNull: Throwable): Unit = {
* // close the connection
* }
* })
* }}}
*
* Java example:
* {{{
* datasetOfString.write().foreach(new ForeachWriter<String>() {
*
* @Override
* public boolean open(long partitionId, long version) {
* // open connection
* }
*
* @Override
* public void process(String value) {
* // write string to connection
* }
*
* @Override
* public void close(Throwable errorOrNull) {
* // close the connection
* }
* });
* }}}
*
* @since 2.0.0
*/
@Experimental
def foreach(writer: ForeachWriter[T]): ContinuousQuery = {
assertNotBucketed("foreach")
assertStreaming(
"foreach() can only be called on streaming Datasets/DataFrames.")
val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName)
val sink = new ForeachSink[T](ds.sparkSession.sparkContext.clean(writer))(ds.exprEnc)
df.sparkSession.sessionState.continuousQueryManager.startQuery(
queryName,
getCheckpointLocation(queryName, failIfNotSet = false).getOrElse {
Utils.createTempDir(namePrefix = "foreach.stream").getCanonicalPath
},
df,
sink,
outputMode,
trigger)
}
/**
* Returns the checkpointLocation for a query. If `failIfNotSet` is `true` but the checkpoint
* location is not set, [[AnalysisException]] will be thrown. If `failIfNotSet` is `false`, `None`
* will be returned if the checkpoint location is not set.
*/
private def getCheckpointLocation(queryName: String, failIfNotSet: Boolean): Option[String] = {
val checkpointLocation = extraOptions.get("checkpointLocation").map { userSpecified =>
new Path(userSpecified).toUri.toString
}.orElse {
df.sparkSession.conf.get(SQLConf.CHECKPOINT_LOCATION).map { location =>
new Path(location, queryName).toUri.toString
}
}
if (failIfNotSet && checkpointLocation.isEmpty) {
throw new AnalysisException("checkpointLocation must be specified either " +
"""through option("checkpointLocation", ...) or """ +
s"""SparkSession.conf.set("${SQLConf.CHECKPOINT_LOCATION.key}", ...)""")
}
checkpointLocation
}
/**
* Inserts the content of the [[DataFrame]] to the specified table. It requires that
* the schema of the [[DataFrame]] is the same as the schema of the table.
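
For context on the new `getCheckpointLocation` helper above, a usage-level sketch of how a checkpoint location can be supplied. The query name `wordCounts`, the paths, and the `writer` value are hypothetical, and the conf key is assumed to be `SQLConf.CHECKPOINT_LOCATION.key`, i.e. `spark.sql.streaming.checkpointLocation`:
```Scala
// Per-query: the user-specified path is used as-is (query name not appended).
ds.write
  .queryName("wordCounts")
  .option("checkpointLocation", "/tmp/checkpoints/wordCounts")
  .foreach(writer)

// Session-wide default: the query name is appended as a subdirectory,
// resolving to /tmp/checkpoints/wordCounts here.
spark.conf.set("spark.sql.streaming.checkpointLocation", "/tmp/checkpoints")
ds.write.queryName("wordCounts").foreach(writer)

// If neither is set, foreach() (like the memory sink) falls back to a temporary
// directory, while other streaming sinks fail with an AnalysisException.
```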

Dataset.scala

@@ -2400,7 +2400,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
def write: DataFrameWriter = new DataFrameWriter(toDF())
def write: DataFrameWriter[T] = new DataFrameWriter[T](this)
/**
* Returns the content of the Dataset as a Dataset of JSON strings.

ForeachWriter.scala

@@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.streaming.ContinuousQuery
/**
* :: Experimental ::
* A class to consume data generated by a [[ContinuousQuery]]. Typically this is used to send the
* generated data to external systems. Each partition will use a new deserialized instance, so you
* usually should do all the initialization (e.g. opening a connection or initiating a transaction)
* in the `open` method.
*
* Scala example:
* {{{
* datasetOfString.write.foreach(new ForeachWriter[String] {
*
* def open(partitionId: Long, version: Long): Boolean = {
* // open connection
* }
*
* def process(record: String) = {
* // write string to connection
* }
*
* def close(errorOrNull: Throwable): Unit = {
* // close the connection
* }
* })
* }}}
*
* Java example:
* {{{
* datasetOfString.write().foreach(new ForeachWriter<String>() {
*
* @Override
* public boolean open(long partitionId, long version) {
* // open connection
* }
*
* @Override
* public void process(String value) {
* // write string to connection
* }
*
* @Override
* public void close(Throwable errorOrNull) {
* // close the connection
* }
* });
* }}}
* @since 2.0.0
*/
@Experimental
abstract class ForeachWriter[T] extends Serializable {
/**
* Called when starting to process one partition of new data in the executor. The `version` is
* for data deduplication when there are failures. When recovering from a failure, some data may
* be generated multiple times but they will always have the same version.
*
* If this method finds, using the `partitionId` and `version`, that this partition has already been
* processed, it can return `false` to skip further data processing. However, `close` will
* still be called for cleaning up resources.
*
* @param partitionId the partition id.
* @param version a unique id for data deduplication.
* @return `true` if the corresponding partition and version id should be processed. `false`
* indicates the partition should be skipped.
*/
def open(partitionId: Long, version: Long): Boolean
/**
* Called to process the data on the executor side. This method will be called only when `open`
* returns `true`.
*/
def process(value: T): Unit
/**
* Called when the executor stops processing one partition of new data. This is
* guaranteed to be called whether `open` returns `true` or `false`. However,
* `close` won't be called in the following cases:
* - JVM crashes without throwing a `Throwable`
* - `open` throws a `Throwable`.
*
* @param errorOrNull the error thrown during processing data or null if there was no error.
*/
def close(errorOrNull: Throwable): Unit
}
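
To illustrate the `open(partitionId, version)` deduplication contract documented above, here is a minimal sketch of a writer that skips (partition, version) pairs it has already committed. The `DedupingWriter` class and its in-memory `committed` set are hypothetical illustrations, not part of this patch; a real sink would keep this bookkeeping in durable external storage rather than per-executor memory.
```Scala
import scala.collection.mutable

import org.apache.spark.sql.ForeachWriter

// Hypothetical example: buffer a partition's data and skip versions that were
// already written. Each partition gets a fresh deserialized writer instance, so
// cross-invocation state must live outside the instance (here, a companion object).
class DedupingWriter extends ForeachWriter[String] {
  private val buffer = mutable.ArrayBuffer[String]()
  private var key: (Long, Long) = _
  private var opened = false

  def open(partitionId: Long, version: Long): Boolean = {
    key = (partitionId, version)
    opened = !DedupingWriter.committed.contains(key)
    opened // returning false skips process(); close(null) is still called
  }

  def process(value: String): Unit = {
    buffer += value
  }

  def close(errorOrNull: Throwable): Unit = {
    if (opened && errorOrNull == null) {
      // flush `buffer` to the external system here, then record the version
      DedupingWriter.committed.add(key)
    }
    buffer.clear()
  }
}

object DedupingWriter {
  // In-memory stand-in for durable deduplication state (per executor JVM).
  val committed = java.util.concurrent.ConcurrentHashMap.newKeySet[(Long, Long)]()
}
```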

ForeachSink.scala

@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming
import org.apache.spark.TaskContext
import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}
/**
* A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by
* [[ForeachWriter]].
*
* @param writer The [[ForeachWriter]] to process all data.
* @tparam T The expected type of the sink.
*/
class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable {
override def addBatch(batchId: Long, data: DataFrame): Unit = {
data.as[T].foreachPartition { iter =>
if (writer.open(TaskContext.getPartitionId(), batchId)) {
var isFailed = false
try {
while (iter.hasNext) {
writer.process(iter.next())
}
} catch {
case e: Throwable =>
isFailed = true
writer.close(e)
}
if (!isFailed) {
writer.close(null)
}
} else {
writer.close(null)
}
}
}
}

ForeachSinkSuite.scala

@@ -0,0 +1,141 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming
import java.util.concurrent.ConcurrentLinkedQueue
import scala.collection.mutable
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.ForeachWriter
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.test.SharedSQLContext
class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
import testImplicits._
after {
sqlContext.streams.active.foreach(_.stop())
}
test("foreach") {
withTempDir { checkpointDir =>
val input = MemoryStream[Int]
val query = input.toDS().repartition(2).write
.option("checkpointLocation", checkpointDir.getCanonicalPath)
.foreach(new TestForeachWriter())
input.addData(1, 2, 3, 4)
query.processAllAvailable()
val expectedEventsForPartition0 = Seq(
ForeachSinkSuite.Open(partition = 0, version = 0),
ForeachSinkSuite.Process(value = 1),
ForeachSinkSuite.Process(value = 3),
ForeachSinkSuite.Close(None)
)
val expectedEventsForPartition1 = Seq(
ForeachSinkSuite.Open(partition = 1, version = 0),
ForeachSinkSuite.Process(value = 2),
ForeachSinkSuite.Process(value = 4),
ForeachSinkSuite.Close(None)
)
val allEvents = ForeachSinkSuite.allEvents()
assert(allEvents.size === 2)
assert {
allEvents === Seq(expectedEventsForPartition0, expectedEventsForPartition1) ||
allEvents === Seq(expectedEventsForPartition1, expectedEventsForPartition0)
}
query.stop()
}
}
test("foreach with error") {
withTempDir { checkpointDir =>
val input = MemoryStream[Int]
val query = input.toDS().repartition(1).write
.option("checkpointLocation", checkpointDir.getCanonicalPath)
.foreach(new TestForeachWriter() {
override def process(value: Int): Unit = {
super.process(value)
throw new RuntimeException("error")
}
})
input.addData(1, 2, 3, 4)
query.processAllAvailable()
val allEvents = ForeachSinkSuite.allEvents()
assert(allEvents.size === 1)
assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0))
assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1))
val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close]
assert(errorEvent.error.get.isInstanceOf[RuntimeException])
assert(errorEvent.error.get.getMessage === "error")
query.stop()
}
}
}
/** A global object to collect events in the executor */
object ForeachSinkSuite {
trait Event
case class Open(partition: Long, version: Long) extends Event
case class Process[T](value: T) extends Event
case class Close(error: Option[Throwable]) extends Event
private val _allEvents = new ConcurrentLinkedQueue[Seq[Event]]()
def addEvents(events: Seq[Event]): Unit = {
_allEvents.add(events)
}
def allEvents(): Seq[Seq[Event]] = {
_allEvents.toArray(new Array[Seq[Event]](_allEvents.size()))
}
def clear(): Unit = {
_allEvents.clear()
}
}
/** A [[ForeachWriter]] that writes collected events to ForeachSinkSuite */
class TestForeachWriter extends ForeachWriter[Int] {
ForeachSinkSuite.clear()
private val events = mutable.ArrayBuffer[ForeachSinkSuite.Event]()
override def open(partitionId: Long, version: Long): Boolean = {
events += ForeachSinkSuite.Open(partition = partitionId, version = version)
true
}
override def process(value: Int): Unit = {
events += ForeachSinkSuite.Process(value)
}
override def close(errorOrNull: Throwable): Unit = {
events += ForeachSinkSuite.Close(error = Option(errorOrNull))
ForeachSinkSuite.addEvents(events)
}
}

BucketedReadSuite.scala

@@ -238,7 +238,9 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
shuffleLeft: Boolean,
shuffleRight: Boolean): Unit = {
withTable("bucketed_table1", "bucketed_table2") {
def withBucket(writer: DataFrameWriter, bucketSpec: Option[BucketSpec]): DataFrameWriter = {
def withBucket(
writer: DataFrameWriter[Row],
bucketSpec: Option[BucketSpec]): DataFrameWriter[Row] = {
bucketSpec.map { spec =>
writer.bucketBy(
spec.numBuckets,