[SPARK-15593][SQL] Add DataFrameWriter.foreach to allow the user consuming data in ContinuousQuery
## What changes were proposed in this pull request?

* Add `DataFrameWriter.foreach` to allow the user to consume data in a ContinuousQuery.
* `ForeachWriter` is the interface through which the user consumes partitions of data.
* Add a type parameter `T` to `DataFrameWriter`.

Usage:

```Scala
val ds = spark.read....stream().as[String]
ds.....write
  .queryName(...)
  .option("checkpointLocation", ...)
  .foreach(new ForeachWriter[Int] {

    def open(partitionId: Long, version: Long): Boolean = {
      // Prepare some resources for a partition.
      // Check `version` if possible and return `false` if this is duplicated data,
      // to skip the data processing.
    }

    override def process(value: Int): Unit = {
      // Process data.
    }

    def close(errorOrNull: Throwable): Unit = {
      // Release resources for a partition.
      // Check `errorOrNull` and handle the error if necessary.
    }
  })
```

## How was this patch tested?

New unit tests.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #13342 from zsxwing/foreach.
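For reference, a more complete sketch of the same flow is shown below. It is a hedged example: the `"socket"` source, host, port, query name, and checkpoint path are illustrative assumptions, and a pre-existing `SparkSession` named `spark` is assumed; only the `write.foreach(...)` call and the `ForeachWriter` contract are the API added by this patch.

```Scala
// Hedged sketch: source format, host, port, and paths are placeholder assumptions.
// The foreach(...) call and ForeachWriter implementation are what this change introduces.
val lines = spark.read
  .format("socket")                 // hypothetical streaming source for the example
  .option("host", "localhost")
  .option("port", "9999")
  .stream()
  .as[String]

val query = lines.write
  .queryName("console-writer")
  .option("checkpointLocation", "/tmp/checkpoints/console-writer")
  .foreach(new ForeachWriter[String] {
    // No per-partition resources are needed here, so always process the partition.
    def open(partitionId: Long, version: Long): Boolean = true

    // Consume each record of the partition.
    def process(value: String): Unit = println(value)

    // Nothing to release; errorOrNull is non-null if processing failed.
    def close(errorOrNull: Throwable): Unit = ()
  })

// foreach(...) returns a started ContinuousQuery; stop it when finished.
query.stop()
```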
parent 5a3533e779
commit 00c310133d
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project}
 import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, DataSource, HadoopFsRelation}
 import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
-import org.apache.spark.sql.execution.streaming.{MemoryPlan, MemorySink, StreamExecution}
+import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{ContinuousQuery, OutputMode, ProcessingTime, Trigger}
 import org.apache.spark.util.Utils
@@ -40,7 +40,9 @@ import org.apache.spark.util.Utils
  *
  * @since 1.4.0
  */
-final class DataFrameWriter private[sql](df: DataFrame) {
+final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
+
+  private val df = ds.toDF()

   /**
    * Specifies the behavior when data or table already exists. Options include:
@@ -51,7 +53,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.4.0
    */
-  def mode(saveMode: SaveMode): DataFrameWriter = {
+  def mode(saveMode: SaveMode): DataFrameWriter[T] = {
     // mode() is used for non-continuous queries
     // outputMode() is used for continuous queries
     assertNotStreaming("mode() can only be called on non-continuous queries")
@@ -68,7 +70,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.4.0
    */
-  def mode(saveMode: String): DataFrameWriter = {
+  def mode(saveMode: String): DataFrameWriter[T] = {
     // mode() is used for non-continuous queries
     // outputMode() is used for continuous queries
     assertNotStreaming("mode() can only be called on non-continuous queries")
@@ -93,7 +95,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   @Experimental
-  def outputMode(outputMode: OutputMode): DataFrameWriter = {
+  def outputMode(outputMode: OutputMode): DataFrameWriter[T] = {
     assertStreaming("outputMode() can only be called on continuous queries")
     this.outputMode = outputMode
     this
@@ -109,7 +111,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   @Experimental
-  def outputMode(outputMode: String): DataFrameWriter = {
+  def outputMode(outputMode: String): DataFrameWriter[T] = {
     assertStreaming("outputMode() can only be called on continuous queries")
     this.outputMode = outputMode.toLowerCase match {
       case "append" =>
@@ -147,7 +149,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   @Experimental
-  def trigger(trigger: Trigger): DataFrameWriter = {
+  def trigger(trigger: Trigger): DataFrameWriter[T] = {
     assertStreaming("trigger() can only be called on continuous queries")
     this.trigger = trigger
     this
@@ -158,7 +160,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.4.0
    */
-  def format(source: String): DataFrameWriter = {
+  def format(source: String): DataFrameWriter[T] = {
     this.source = source
     this
   }
@@ -168,7 +170,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.4.0
    */
-  def option(key: String, value: String): DataFrameWriter = {
+  def option(key: String, value: String): DataFrameWriter[T] = {
     this.extraOptions += (key -> value)
     this
   }
@@ -178,28 +180,28 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 2.0.0
    */
-  def option(key: String, value: Boolean): DataFrameWriter = option(key, value.toString)
+  def option(key: String, value: Boolean): DataFrameWriter[T] = option(key, value.toString)

   /**
    * Adds an output option for the underlying data source.
    *
    * @since 2.0.0
    */
-  def option(key: String, value: Long): DataFrameWriter = option(key, value.toString)
+  def option(key: String, value: Long): DataFrameWriter[T] = option(key, value.toString)

   /**
    * Adds an output option for the underlying data source.
    *
    * @since 2.0.0
    */
-  def option(key: String, value: Double): DataFrameWriter = option(key, value.toString)
+  def option(key: String, value: Double): DataFrameWriter[T] = option(key, value.toString)

   /**
    * (Scala-specific) Adds output options for the underlying data source.
    *
    * @since 1.4.0
    */
-  def options(options: scala.collection.Map[String, String]): DataFrameWriter = {
+  def options(options: scala.collection.Map[String, String]): DataFrameWriter[T] = {
     this.extraOptions ++= options
     this
   }
@@ -209,7 +211,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.4.0
    */
-  def options(options: java.util.Map[String, String]): DataFrameWriter = {
+  def options(options: java.util.Map[String, String]): DataFrameWriter[T] = {
     this.options(options.asScala)
     this
   }
@@ -232,7 +234,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 1.4.0
    */
   @scala.annotation.varargs
-  def partitionBy(colNames: String*): DataFrameWriter = {
+  def partitionBy(colNames: String*): DataFrameWriter[T] = {
     this.partitioningColumns = Option(colNames)
     this
   }
@@ -246,7 +248,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0
    */
   @scala.annotation.varargs
-  def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter = {
+  def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter[T] = {
     this.numBuckets = Option(numBuckets)
     this.bucketColumnNames = Option(colName +: colNames)
     this
@@ -260,7 +262,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0
    */
   @scala.annotation.varargs
-  def sortBy(colName: String, colNames: String*): DataFrameWriter = {
+  def sortBy(colName: String, colNames: String*): DataFrameWriter[T] = {
     this.sortColumnNames = Option(colName +: colNames)
     this
   }
@@ -301,7 +303,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   @Experimental
-  def queryName(queryName: String): DataFrameWriter = {
+  def queryName(queryName: String): DataFrameWriter[T] = {
     assertStreaming("queryName() can only be called on continuous queries")
     this.extraOptions += ("queryName" -> queryName)
     this
@@ -337,16 +339,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
       val queryName =
         extraOptions.getOrElse(
           "queryName", throw new AnalysisException("queryName must be specified for memory sink"))
-      val checkpointLocation = extraOptions.get("checkpointLocation").map { userSpecified =>
-        new Path(userSpecified).toUri.toString
-      }.orElse {
-        val checkpointConfig: Option[String] =
-          df.sparkSession.conf.get(SQLConf.CHECKPOINT_LOCATION)
-
-        checkpointConfig.map { location =>
-          new Path(location, queryName).toUri.toString
-        }
-      }.getOrElse {
+      val checkpointLocation = getCheckpointLocation(queryName, failIfNotSet = false).getOrElse {
         Utils.createTempDir(namePrefix = "memory.stream").getCanonicalPath
       }
@@ -378,21 +371,10 @@ final class DataFrameWriter private[sql](df: DataFrame) {
         className = source,
         options = extraOptions.toMap,
         partitionColumns = normalizedParCols.getOrElse(Nil))

     val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName)
-    val checkpointLocation = extraOptions.get("checkpointLocation")
-      .orElse {
-        df.sparkSession.sessionState.conf.checkpointLocation.map { l =>
-          new Path(l, queryName).toUri.toString
-        }
-      }.getOrElse {
-        throw new AnalysisException("checkpointLocation must be specified either " +
-          "through option() or SQLConf")
-      }
-
     df.sparkSession.sessionState.continuousQueryManager.startQuery(
       queryName,
-      checkpointLocation,
+      getCheckpointLocation(queryName, failIfNotSet = true).get,
       df,
       dataSource.createSink(outputMode),
       outputMode,
@@ -400,6 +382,94 @@ final class DataFrameWriter private[sql](df: DataFrame) {
     }
   }

+  /**
+   * :: Experimental ::
+   * Starts the execution of the streaming query, which will continually send results to the given
+   * [[ForeachWriter]] as new data arrives. The [[ForeachWriter]] can be used to send the data
+   * generated by the [[DataFrame]]/[[Dataset]] to an external system. The returned
+   * [[ContinuousQuery]] object can be used to interact with the stream.
+   *
+   * Scala example:
+   * {{{
+   *   datasetOfString.write.foreach(new ForeachWriter[String] {
+   *
+   *     def open(partitionId: Long, version: Long): Boolean = {
+   *       // open connection
+   *     }
+   *
+   *     def process(record: String) = {
+   *       // write string to connection
+   *     }
+   *
+   *     def close(errorOrNull: Throwable): Unit = {
+   *       // close the connection
+   *     }
+   *   })
+   * }}}
+   *
+   * Java example:
+   * {{{
+   *   datasetOfString.write().foreach(new ForeachWriter<String>() {
+   *
+   *     @Override
+   *     public boolean open(long partitionId, long version) {
+   *       // open connection
+   *     }
+   *
+   *     @Override
+   *     public void process(String value) {
+   *       // write string to connection
+   *     }
+   *
+   *     @Override
+   *     public void close(Throwable errorOrNull) {
+   *       // close the connection
+   *     }
+   *   });
+   * }}}
+   *
+   * @since 2.0.0
+   */
+  @Experimental
+  def foreach(writer: ForeachWriter[T]): ContinuousQuery = {
+    assertNotBucketed("foreach")
+    assertStreaming(
+      "foreach() can only be called on streaming Datasets/DataFrames.")
+
+    val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName)
+    val sink = new ForeachSink[T](ds.sparkSession.sparkContext.clean(writer))(ds.exprEnc)
+    df.sparkSession.sessionState.continuousQueryManager.startQuery(
+      queryName,
+      getCheckpointLocation(queryName, failIfNotSet = false).getOrElse {
+        Utils.createTempDir(namePrefix = "foreach.stream").getCanonicalPath
+      },
+      df,
+      sink,
+      outputMode,
+      trigger)
+  }
+
+  /**
+   * Returns the checkpointLocation for a query. If `failIfNotSet` is `true` but the checkpoint
+   * location is not set, [[AnalysisException]] will be thrown. If `failIfNotSet` is `false`, `None`
+   * will be returned if the checkpoint location is not set.
+   */
+  private def getCheckpointLocation(queryName: String, failIfNotSet: Boolean): Option[String] = {
+    val checkpointLocation = extraOptions.get("checkpointLocation").map { userSpecified =>
+      new Path(userSpecified).toUri.toString
+    }.orElse {
+      df.sparkSession.conf.get(SQLConf.CHECKPOINT_LOCATION).map { location =>
+        new Path(location, queryName).toUri.toString
+      }
+    }
+    if (failIfNotSet && checkpointLocation.isEmpty) {
+      throw new AnalysisException("checkpointLocation must be specified either " +
+        """through option("checkpointLocation", ...) or """ +
+        s"""SparkSession.conf.set("${SQLConf.CHECKPOINT_LOCATION.key}", ...)""")
+    }
+    checkpointLocation
+  }
+
   /**
    * Inserts the content of the [[DataFrame]] to the specified table. It requires that
    * the schema of the [[DataFrame]] is the same as the schema of the table.
@@ -2400,7 +2400,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   @Experimental
-  def write: DataFrameWriter = new DataFrameWriter(toDF())
+  def write: DataFrameWriter[T] = new DataFrameWriter[T](this)

   /**
    * Returns the content of the Dataset as a Dataset of JSON strings.
sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala (new file, 105 lines)

@@ -0,0 +1,105 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import org.apache.spark.annotation.Experimental
|
||||
import org.apache.spark.sql.streaming.ContinuousQuery
|
||||
|
||||
/**
|
||||
* :: Experimental ::
|
||||
* A class to consume data generated by a [[ContinuousQuery]]. Typically this is used to send the
|
||||
* generated data to external systems. Each partition will use a new deserialized instance, so you
|
||||
* usually should do all the initialization (e.g. opening a connection or initiating a transaction)
|
||||
* in the `open` method.
|
||||
*
|
||||
* Scala example:
|
||||
* {{{
|
||||
* datasetOfString.write.foreach(new ForeachWriter[String] {
|
||||
*
|
||||
* def open(partitionId: Long, version: Long): Boolean = {
|
||||
* // open connection
|
||||
* }
|
||||
*
|
||||
* def process(record: String) = {
|
||||
* // write string to connection
|
||||
* }
|
||||
*
|
||||
* def close(errorOrNull: Throwable): Unit = {
|
||||
* // close the connection
|
||||
* }
|
||||
* })
|
||||
* }}}
|
||||
*
|
||||
* Java example:
|
||||
* {{{
|
||||
* datasetOfString.write().foreach(new ForeachWriter<String>() {
|
||||
*
|
||||
* @Override
|
||||
* public boolean open(long partitionId, long version) {
|
||||
* // open connection
|
||||
* }
|
||||
*
|
||||
* @Override
|
||||
* public void process(String value) {
|
||||
* // write string to connection
|
||||
* }
|
||||
*
|
||||
* @Override
|
||||
* public void close(Throwable errorOrNull) {
|
||||
* // close the connection
|
||||
* }
|
||||
* });
|
||||
* }}}
|
||||
* @since 2.0.0
|
||||
*/
|
||||
@Experimental
|
||||
abstract class ForeachWriter[T] extends Serializable {
|
||||
|
||||
/**
|
||||
* Called when starting to process one partition of new data in the executor. The `version` is
|
||||
* for data deduplication when there are failures. When recovering from a failure, some data may
|
||||
* be generated multiple times but they will always have the same version.
|
||||
*
|
||||
* If this method finds using the `partitionId` and `version` that this partition has already been
|
||||
* processed, it can return `false` to skip the further data processing. However, `close` still
|
||||
* will be called for cleaning up resources.
|
||||
*
|
||||
* @param partitionId the partition id.
|
||||
* @param version a unique id for data deduplication.
|
||||
* @return `true` if the corresponding partition and version id should be processed. `false`
|
||||
* indicates the partition should be skipped.
|
||||
*/
|
||||
def open(partitionId: Long, version: Long): Boolean
|
||||
|
||||
/**
|
||||
* Called to process the data in the executor side. This method will be called only when `open`
|
||||
* returns `true`.
|
||||
*/
|
||||
def process(value: T): Unit
|
||||
|
||||
/**
|
||||
* Called when stopping to process one partition of new data in the executor side. This is
|
||||
* guaranteed to be called either `open` returns `true` or `false`. However,
|
||||
* `close` won't be called in the following cases:
|
||||
* - JVM crashes without throwing a `Throwable`
|
||||
* - `open` throws a `Throwable`.
|
||||
*
|
||||
* @param errorOrNull the error thrown during processing data or null if there was no error.
|
||||
*/
|
||||
def close(errorOrNull: Throwable): Unit
|
||||
}
|
|
@@ -0,0 +1,53 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.streaming

import org.apache.spark.TaskContext
import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}

/**
 * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by
 * [[ForeachWriter]].
 *
 * @param writer The [[ForeachWriter]] to process all data.
 * @tparam T The expected type of the sink.
 */
class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable {

  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    data.as[T].foreachPartition { iter =>
      if (writer.open(TaskContext.getPartitionId(), batchId)) {
        var isFailed = false
        try {
          while (iter.hasNext) {
            writer.process(iter.next())
          }
        } catch {
          case e: Throwable =>
            isFailed = true
            writer.close(e)
        }
        if (!isFailed) {
          writer.close(null)
        }
      } else {
        writer.close(null)
      }
    }
  }
}
@@ -0,0 +1,141 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.streaming

import java.util.concurrent.ConcurrentLinkedQueue

import scala.collection.mutable

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.ForeachWriter
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.test.SharedSQLContext

class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {

  import testImplicits._

  after {
    sqlContext.streams.active.foreach(_.stop())
  }

  test("foreach") {
    withTempDir { checkpointDir =>
      val input = MemoryStream[Int]
      val query = input.toDS().repartition(2).write
        .option("checkpointLocation", checkpointDir.getCanonicalPath)
        .foreach(new TestForeachWriter())
      input.addData(1, 2, 3, 4)
      query.processAllAvailable()

      val expectedEventsForPartition0 = Seq(
        ForeachSinkSuite.Open(partition = 0, version = 0),
        ForeachSinkSuite.Process(value = 1),
        ForeachSinkSuite.Process(value = 3),
        ForeachSinkSuite.Close(None)
      )
      val expectedEventsForPartition1 = Seq(
        ForeachSinkSuite.Open(partition = 1, version = 0),
        ForeachSinkSuite.Process(value = 2),
        ForeachSinkSuite.Process(value = 4),
        ForeachSinkSuite.Close(None)
      )

      val allEvents = ForeachSinkSuite.allEvents()
      assert(allEvents.size === 2)
      assert {
        allEvents === Seq(expectedEventsForPartition0, expectedEventsForPartition1) ||
          allEvents === Seq(expectedEventsForPartition1, expectedEventsForPartition0)
      }
      query.stop()
    }
  }

  test("foreach with error") {
    withTempDir { checkpointDir =>
      val input = MemoryStream[Int]
      val query = input.toDS().repartition(1).write
        .option("checkpointLocation", checkpointDir.getCanonicalPath)
        .foreach(new TestForeachWriter() {
          override def process(value: Int): Unit = {
            super.process(value)
            throw new RuntimeException("error")
          }
        })
      input.addData(1, 2, 3, 4)
      query.processAllAvailable()

      val allEvents = ForeachSinkSuite.allEvents()
      assert(allEvents.size === 1)
      assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0))
      assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1))
      val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close]
      assert(errorEvent.error.get.isInstanceOf[RuntimeException])
      assert(errorEvent.error.get.getMessage === "error")
      query.stop()
    }
  }
}

/** A global object to collect events in the executor */
object ForeachSinkSuite {

  trait Event

  case class Open(partition: Long, version: Long) extends Event

  case class Process[T](value: T) extends Event

  case class Close(error: Option[Throwable]) extends Event

  private val _allEvents = new ConcurrentLinkedQueue[Seq[Event]]()

  def addEvents(events: Seq[Event]): Unit = {
    _allEvents.add(events)
  }

  def allEvents(): Seq[Seq[Event]] = {
    _allEvents.toArray(new Array[Seq[Event]](_allEvents.size()))
  }

  def clear(): Unit = {
    _allEvents.clear()
  }
}

/** A [[ForeachWriter]] that writes collected events to ForeachSinkSuite */
class TestForeachWriter extends ForeachWriter[Int] {
  ForeachSinkSuite.clear()

  private val events = mutable.ArrayBuffer[ForeachSinkSuite.Event]()

  override def open(partitionId: Long, version: Long): Boolean = {
    events += ForeachSinkSuite.Open(partition = partitionId, version = version)
    true
  }

  override def process(value: Int): Unit = {
    events += ForeachSinkSuite.Process(value)
  }

  override def close(errorOrNull: Throwable): Unit = {
    events += ForeachSinkSuite.Close(error = Option(errorOrNull))
    ForeachSinkSuite.addEvents(events)
  }
}
@@ -238,7 +238,9 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
       shuffleLeft: Boolean,
       shuffleRight: Boolean): Unit = {
     withTable("bucketed_table1", "bucketed_table2") {
-      def withBucket(writer: DataFrameWriter, bucketSpec: Option[BucketSpec]): DataFrameWriter = {
+      def withBucket(
+          writer: DataFrameWriter[Row],
+          bucketSpec: Option[BucketSpec]): DataFrameWriter[Row] = {
         bucketSpec.map { spec =>
           writer.bucketBy(
             spec.numBuckets,