[SPARK-18120][SPARK-19557][SQL] Call QueryExecutionListener callback methods for DataFrameWriter methods

## What changes were proposed in this pull request?

We currently notify `QueryExecutionListener` only for a handful of `Dataset` actions, e.g. `collect` and `take`. We should also send these notifications for `DataFrameWriter` operations (`save`, `insertInto`, `saveAsTable`).
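
For illustration, a minimal sketch of the listener-side behavior this enables. The listener API and the `listenerManager.register` call are the existing public APIs; the app name and output path are placeholders, and the write mirrors the new regression test below.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

object ListenerDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("listener-demo").getOrCreate()

    // A trivial listener that just prints what it is told about each action.
    val listener = new QueryExecutionListener {
      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
        println(s"$funcName succeeded in ${durationNs / 1e6} ms, plan: ${qe.logical.getClass.getSimpleName}")
      }
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
        println(s"$funcName failed: ${exception.getMessage}")
      }
    }
    spark.listenerManager.register(listener)

    // Previously only Dataset actions (collect, take, ...) reached the listener.
    // With this change, the write below also triggers onSuccess with funcName = "save".
    spark.range(10).write.format("json").save("/tmp/listener-demo")

    spark.stop()
  }
}
```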

## How was this patch tested?

A new regression test in `DataFrameCallbackSuite`.

Closes https://github.com/apache/spark/pull/16664

Author: Wenchen Fan <wenchen@databricks.com>

Closes #16962 from cloud-fan/insert.
Wenchen Fan 2017-02-16 21:09:14 -08:00
parent 21fde57f15
commit 54d23599df
3 changed files with 142 additions and 16 deletions

DataFrameWriter.scala

@@ -25,9 +25,9 @@ import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation}
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogRelation, CatalogTable, CatalogTableType}
-import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable
+import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan}
import org.apache.spark.sql.execution.command.DDLUtils
-import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation, SaveIntoDataSourceCommand}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType
@@ -211,13 +211,15 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
}
assertNotBucketed("save")
-val dataSource = DataSource(
-df.sparkSession,
-className = source,
-partitionColumns = partitioningColumns.getOrElse(Nil),
-options = extraOptions.toMap)
-dataSource.write(mode, df)
+runCommand(df.sparkSession, "save") {
+SaveIntoDataSourceCommand(
+query = df.logicalPlan,
+provider = source,
+partitionColumns = partitioningColumns.getOrElse(Nil),
+options = extraOptions.toMap,
+mode = mode)
+}
}
/**
@@ -260,13 +262,14 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
)
}
-df.sparkSession.sessionState.executePlan(
+runCommand(df.sparkSession, "insertInto") {
InsertIntoTable(
table = UnresolvedRelation(tableIdent),
partition = Map.empty[String, Option[String]],
query = df.logicalPlan,
overwrite = mode == SaveMode.Overwrite,
-ifNotExists = false)).toRdd
+ifNotExists = false)
+}
}
private def getBucketSpec: Option[BucketSpec] = {
@@ -389,10 +392,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
schema = new StructType,
provider = Some(source),
partitionColumnNames = partitioningColumns.getOrElse(Nil),
-bucketSpec = getBucketSpec
-)
-df.sparkSession.sessionState.executePlan(
-CreateTable(tableDesc, mode, Some(df.logicalPlan))).toRdd
+bucketSpec = getBucketSpec)
+runCommand(df.sparkSession, "saveAsTable")(CreateTable(tableDesc, mode, Some(df.logicalPlan)))
}
/**
@@ -573,6 +575,25 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
format("csv").save(path)
}
/**
* Wrap a DataFrameWriter action to track the QueryExecution and time cost, then report to the
* user-registered callback functions.
*/
private def runCommand(session: SparkSession, name: String)(command: LogicalPlan): Unit = {
val qe = session.sessionState.executePlan(command)
try {
val start = System.nanoTime()
// call `QueryExecution.toRdd` to trigger the execution of commands.
qe.toRdd
val end = System.nanoTime()
session.listenerManager.onSuccess(name, qe, end - start)
} catch {
case e: Exception =>
session.listenerManager.onFailure(name, qe, e)
throw e
}
}
///////////////////////////////////////////////////////////////////////////////////////
// Builder pattern config options
///////////////////////////////////////////////////////////////////////////////////////

SaveIntoDataSourceCommand.scala (new file)

@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources
import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.command.RunnableCommand
/**
* Saves the results of `query` into a data source.
*
* Note that this command is different from [[InsertIntoDataSourceCommand]]. This command will call
* `CreatableRelationProvider.createRelation` to write out the data, while
* [[InsertIntoDataSourceCommand]] calls `InsertableRelation.insert`. Ideally these 2 data source
* interfaces should do the same thing, but as we've already published these 2 interfaces and the
* implementations may have different logic, we have to keep these 2 different commands.
*/
case class SaveIntoDataSourceCommand(
query: LogicalPlan,
provider: String,
partitionColumns: Seq[String],
options: Map[String, String],
mode: SaveMode) extends RunnableCommand {
override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query)
override def run(sparkSession: SparkSession): Seq[Row] = {
DataSource(
sparkSession,
className = provider,
partitionColumns = partitionColumns,
options = options).write(mode, Dataset.ofRows(sparkSession, query))
Seq.empty[Row]
}
}
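
To make the contrast in the Scaladoc above concrete, here is a minimal sketch of a hypothetical data source implementing both interfaces. `DemoSourceProvider` and `DemoRelation` are made-up names; the two interfaces and their signatures are the existing ones in `org.apache.spark.sql.sources`. With such a provider, `DataFrameWriter.save` goes through `SaveIntoDataSourceCommand` and hence `createRelation`, while inserting into an existing relation of this source goes through `InsertIntoDataSourceCommand` and hence `insert`.

```scala
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, InsertableRelation}
import org.apache.spark.sql.types.StructType

// Hypothetical provider: SaveIntoDataSourceCommand ends up calling createRelation.
class DemoSourceProvider extends CreatableRelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation = {
    // Write `data` out according to `mode` and `parameters` (omitted here),
    // then return a relation describing what was written.
    new DemoRelation(sqlContext, data.schema)
  }
}

// Hypothetical relation: InsertIntoDataSourceCommand calls InsertableRelation.insert.
class DemoRelation(override val sqlContext: SQLContext, override val schema: StructType)
  extends BaseRelation with InsertableRelation {

  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    // Append to, or overwrite, the previously written data (omitted here).
  }
}
```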

DataFrameCallbackSuite.scala

@@ -20,9 +20,11 @@ package org.apache.spark.sql.util
import scala.collection.mutable.ArrayBuffer
import org.apache.spark._
-import org.apache.spark.sql.{functions, QueryTest}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
+import org.apache.spark.sql.{functions, AnalysisException, QueryTest}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, InsertIntoTable, LogicalPlan, Project}
import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.datasources.{CreateTable, SaveIntoDataSourceCommand}
import org.apache.spark.sql.test.SharedSQLContext
class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
@@ -159,4 +161,55 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
spark.listenerManager.unregister(listener)
}
test("execute callback functions for DataFrameWriter") {
val commands = ArrayBuffer.empty[(String, LogicalPlan)]
val exceptions = ArrayBuffer.empty[(String, Exception)]
val listener = new QueryExecutionListener {
override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
exceptions += funcName -> exception
}
override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
commands += funcName -> qe.logical
}
}
spark.listenerManager.register(listener)
withTempPath { path =>
spark.range(10).write.format("json").save(path.getCanonicalPath)
assert(commands.length == 1)
assert(commands.head._1 == "save")
assert(commands.head._2.isInstanceOf[SaveIntoDataSourceCommand])
assert(commands.head._2.asInstanceOf[SaveIntoDataSourceCommand].provider == "json")
}
withTable("tab") {
sql("CREATE TABLE tab(i long) using parquet")
spark.range(10).write.insertInto("tab")
assert(commands.length == 2)
assert(commands(1)._1 == "insertInto")
assert(commands(1)._2.isInstanceOf[InsertIntoTable])
assert(commands(1)._2.asInstanceOf[InsertIntoTable].table
.asInstanceOf[UnresolvedRelation].tableIdentifier.table == "tab")
}
withTable("tab") {
spark.range(10).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("tab")
assert(commands.length == 3)
assert(commands(2)._1 == "saveAsTable")
assert(commands(2)._2.isInstanceOf[CreateTable])
assert(commands(2)._2.asInstanceOf[CreateTable].tableDesc.partitionColumnNames == Seq("p"))
}
withTable("tab") {
sql("CREATE TABLE tab(i long) using parquet")
val e = intercept[AnalysisException] {
spark.range(10).select($"id", $"id").write.insertInto("tab")
}
assert(exceptions.length == 1)
assert(exceptions.head._1 == "insertInto")
assert(exceptions.head._2 == e)
}
}
}