[SPARK-18120][SPARK-19557][SQL] Call QueryExecutionListener callback methods for DataFrameWriter methods
## What changes were proposed in this pull request?

We currently notify `QueryExecutionListener` only for a few `Dataset` operations, e.g. `collect` and `take`. We should also send notifications for `DataFrameWriter` operations.

## How was this patch tested?

New regression test.

Closes https://github.com/apache/spark/pull/16664.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #16962 from cloud-fan/insert.
Commit 54d23599df (parent 21fde57f15)
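For context, `QueryExecutionListener` is the user-facing hook this patch extends to writer commands. A minimal sketch of registering one — assuming a running `SparkSession` named `spark`; the `println` bodies and the output path are illustrative:

```scala
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

// Before this patch the listener fired only for Dataset actions such as
// collect/take; with it, DataFrameWriter's save/insertInto/saveAsTable
// report here as well, under those function names.
val listener = new QueryExecutionListener {
  override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit =
    println(s"$funcName succeeded in ${duration / 1e6} ms")

  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
    println(s"$funcName failed: ${exception.getMessage}")
}

spark.listenerManager.register(listener)
spark.range(10).write.format("json").save("/tmp/demo")  // now reports funcName "save"
```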
**DataFrameWriter.scala** — new imports for `LogicalPlan` and `SaveIntoDataSourceCommand`:

```diff
@@ -25,9 +25,9 @@ import org.apache.spark.annotation.InterfaceStability
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogRelation, CatalogTable, CatalogTableType}
-import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable
+import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan}
 import org.apache.spark.sql.execution.command.DDLUtils
-import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation, SaveIntoDataSourceCommand}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.StructType
 
```
`save()` previously invoked `DataSource.write` eagerly; it now wraps the write in a `SaveIntoDataSourceCommand` and runs it through the new `runCommand` helper:

```diff
@@ -211,13 +211,15 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
     }
 
     assertNotBucketed("save")
-    val dataSource = DataSource(
-      df.sparkSession,
-      className = source,
-      partitionColumns = partitioningColumns.getOrElse(Nil),
-      options = extraOptions.toMap)
-
-    dataSource.write(mode, df)
+
+    runCommand(df.sparkSession, "save") {
+      SaveIntoDataSourceCommand(
+        query = df.logicalPlan,
+        provider = source,
+        partitionColumns = partitioningColumns.getOrElse(Nil),
+        options = extraOptions.toMap,
+        mode = mode)
+    }
   }
 
   /**
```
`insertInto()` likewise routes its `InsertIntoTable` plan through `runCommand` instead of calling `executePlan(...).toRdd` directly:

```diff
@@ -260,13 +262,14 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       )
     }
 
-    df.sparkSession.sessionState.executePlan(
+    runCommand(df.sparkSession, "insertInto") {
       InsertIntoTable(
         table = UnresolvedRelation(tableIdent),
         partition = Map.empty[String, Option[String]],
         query = df.logicalPlan,
         overwrite = mode == SaveMode.Overwrite,
-        ifNotExists = false)).toRdd
+        ifNotExists = false)
+    }
   }
 
   private def getBucketSpec: Option[BucketSpec] = {
```
`saveAsTable()` gets the same treatment for its `CreateTable` command:

```diff
@@ -389,10 +392,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       schema = new StructType,
       provider = Some(source),
       partitionColumnNames = partitioningColumns.getOrElse(Nil),
-      bucketSpec = getBucketSpec
-    )
-    df.sparkSession.sessionState.executePlan(
-      CreateTable(tableDesc, mode, Some(df.logicalPlan))).toRdd
+      bucketSpec = getBucketSpec)
+
+    runCommand(df.sparkSession, "saveAsTable")(CreateTable(tableDesc, mode, Some(df.logicalPlan)))
   }
 
   /**
```
The new `runCommand` helper wraps each writer action, times it, and notifies the session's listeners:

```diff
@@ -573,6 +575,25 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
     format("csv").save(path)
   }
 
+  /**
+   * Wrap a DataFrameWriter action to track the QueryExecution and time cost, then report to the
+   * user-registered callback functions.
+   */
+  private def runCommand(session: SparkSession, name: String)(command: LogicalPlan): Unit = {
+    val qe = session.sessionState.executePlan(command)
+    try {
+      val start = System.nanoTime()
+      // call `QueryExecution.toRdd` to trigger the execution of commands.
+      qe.toRdd
+      val end = System.nanoTime()
+      session.listenerManager.onSuccess(name, qe, end - start)
+    } catch {
+      case e: Exception =>
+        session.listenerManager.onFailure(name, qe, e)
+        throw e
+    }
+  }
+
   ///////////////////////////////////////////////////////////////////////////////////////
   // Builder pattern config options
   ///////////////////////////////////////////////////////////////////////////////////////
```
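Two details worth noting: `runCommand` forces execution with `qe.toRdd`, so the listener receives the full `QueryExecution` of the command rather than just a name, and on failure the exception is both reported to `onFailure` and rethrown, leaving caller-visible behavior unchanged. A hedged sketch of a listener that exploits the richer plan (the pattern mirrors the regression test below; a `spark` session is assumed):

```scala
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.datasources.SaveIntoDataSourceCommand
import org.apache.spark.sql.util.QueryExecutionListener

spark.listenerManager.register(new QueryExecutionListener {
  override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit =
    qe.logical match {
      // For df.write.save(...), the logical plan handed to the listener is the
      // SaveIntoDataSourceCommand that runCommand executed.
      case cmd: SaveIntoDataSourceCommand =>
        println(s"$funcName wrote via '${cmd.provider}' in ${duration / 1e6} ms")
      case other =>
        println(s"$funcName ran ${other.getClass.getSimpleName}")
    }

  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = ()
})
```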
**SaveIntoDataSourceCommand.scala** (new file):

```diff
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.command.RunnableCommand
+
+/**
+ * Saves the results of `query` into a data source.
+ *
+ * Note that this command is different from [[InsertIntoDataSourceCommand]]. This command will call
+ * `CreatableRelationProvider.createRelation` to write out the data, while
+ * [[InsertIntoDataSourceCommand]] calls `InsertableRelation.insert`. Ideally these 2 data source
+ * interfaces should do the same thing, but as we've already published these 2 interfaces and the
+ * implementations may have different logic, we have to keep these 2 different commands.
+ */
+case class SaveIntoDataSourceCommand(
+    query: LogicalPlan,
+    provider: String,
+    partitionColumns: Seq[String],
+    options: Map[String, String],
+    mode: SaveMode) extends RunnableCommand {
+
+  override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query)
+
+  override def run(sparkSession: SparkSession): Seq[Row] = {
+    DataSource(
+      sparkSession,
+      className = provider,
+      partitionColumns = partitionColumns,
+      options = options).write(mode, Dataset.ofRows(sparkSession, query))
+
+    Seq.empty[Row]
+  }
+}
```
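To make the scaladoc's contrast concrete, these are the two published interfaces it refers to, abridged from `org.apache.spark.sql.sources` (the comments are added here):

```scala
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.sources.BaseRelation

trait CreatableRelationProvider {
  // Called by SaveIntoDataSourceCommand: the provider decides how to create,
  // overwrite, or append to the relation according to `mode`.
  def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation
}

trait InsertableRelation {
  // Called by InsertIntoDataSourceCommand: rows go into an already-resolved
  // relation; only overwrite vs. append is conveyed.
  def insert(data: DataFrame, overwrite: Boolean): Unit
}
```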
**DataFrameCallbackSuite.scala** — import changes:

```diff
@@ -20,9 +20,11 @@ package org.apache.spark.sql.util
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark._
-import org.apache.spark.sql.{functions, QueryTest}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
+import org.apache.spark.sql.{functions, AnalysisException, QueryTest}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, InsertIntoTable, LogicalPlan, Project}
 import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.datasources.{CreateTable, SaveIntoDataSourceCommand}
 import org.apache.spark.sql.test.SharedSQLContext
 
 class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
```
The new regression test exercises `save`, `insertInto`, and `saveAsTable`, plus the failure path:

```diff
@@ -159,4 +161,55 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
 
     spark.listenerManager.unregister(listener)
   }
+
+  test("execute callback functions for DataFrameWriter") {
+    val commands = ArrayBuffer.empty[(String, LogicalPlan)]
+    val exceptions = ArrayBuffer.empty[(String, Exception)]
+    val listener = new QueryExecutionListener {
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+        exceptions += funcName -> exception
+      }
+
+      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+        commands += funcName -> qe.logical
+      }
+    }
+    spark.listenerManager.register(listener)
+
+    withTempPath { path =>
+      spark.range(10).write.format("json").save(path.getCanonicalPath)
+      assert(commands.length == 1)
+      assert(commands.head._1 == "save")
+      assert(commands.head._2.isInstanceOf[SaveIntoDataSourceCommand])
+      assert(commands.head._2.asInstanceOf[SaveIntoDataSourceCommand].provider == "json")
+    }
+
+    withTable("tab") {
+      sql("CREATE TABLE tab(i long) using parquet")
+      spark.range(10).write.insertInto("tab")
+      assert(commands.length == 2)
+      assert(commands(1)._1 == "insertInto")
+      assert(commands(1)._2.isInstanceOf[InsertIntoTable])
+      assert(commands(1)._2.asInstanceOf[InsertIntoTable].table
+        .asInstanceOf[UnresolvedRelation].tableIdentifier.table == "tab")
+    }
+
+    withTable("tab") {
+      spark.range(10).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("tab")
+      assert(commands.length == 3)
+      assert(commands(2)._1 == "saveAsTable")
+      assert(commands(2)._2.isInstanceOf[CreateTable])
+      assert(commands(2)._2.asInstanceOf[CreateTable].tableDesc.partitionColumnNames == Seq("p"))
+    }
+
+    withTable("tab") {
+      sql("CREATE TABLE tab(i long) using parquet")
+      val e = intercept[AnalysisException] {
+        spark.range(10).select($"id", $"id").write.insertInto("tab")
+      }
+      assert(exceptions.length == 1)
+      assert(exceptions.head._1 == "insertInto")
+      assert(exceptions.head._2 == e)
+    }
+  }
 }
```