diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 0fd10c157e..6eb452a1bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import java.io.{ByteArrayOutputStream, CharArrayWriter, DataOutputStream}
 
+import scala.annotation.varargs
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, HashSet}
 import scala.reflect.runtime.universe.TypeTag
@@ -1947,6 +1948,32 @@ class Dataset[T] private[sql](
     CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan)
   }
 
+  /**
+   * Observe (named) metrics through an `org.apache.spark.sql.Observation` instance.
+   * This is equivalent to calling `observe(String, Column, Column*)` but does not require
+   * adding an `org.apache.spark.sql.util.QueryExecutionListener` to the Spark session.
+   * This method does not support streaming datasets.
+   *
+   * A user can retrieve the metrics by accessing `org.apache.spark.sql.Observation.get`.
+   *
+   * {{{
+   * // Observe row count (rows) and highest id (maxid) in the Dataset while writing it
+   * val observation = Observation("my_metrics")
+   * val observed_ds = ds.observe(observation, count(lit(1)).as("rows"), max($"id").as("maxid"))
+   * observed_ds.write.parquet("ds.parquet")
+   * val metrics = observation.get
+   * }}}
+   *
+   * @throws IllegalArgumentException If this is a streaming Dataset (this.isStreaming == true)
+   *
+   * @group typedrel
+   * @since 3.3.0
+   */
+  @varargs
+  def observe(observation: Observation, expr: Column, exprs: Column*): Dataset[T] = {
+    observation.on(this, expr, exprs: _*)
+  }
+
   /**
    * Returns a new Dataset by taking the first `n` rows. The difference between this function
    * and `head` is that `head` is an action and returns an array (by triggering query execution)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
new file mode 100644
index 0000000000..807d72acb9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.UUID
+
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.util.QueryExecutionListener
+
+
+/**
+ * Helper class to simplify usage of `Dataset.observe(String, Column, Column*)`:
+ *
+ * {{{
+ * // Observe row count (rows) and highest id (maxid) in the Dataset while writing it
+ * val observation = Observation("my metrics")
+ * val observed_ds = ds.observe(observation, count(lit(1)).as("rows"), max($"id").as("maxid"))
+ * observed_ds.write.parquet("ds.parquet")
+ * val metrics = observation.get
+ * }}}
+ *
+ * This collects the metrics while the first action is executed on the observed dataset. Subsequent
+ * actions do not modify the metrics returned by [[get]]. Retrieval of the metric via [[get]]
+ * blocks until the first action has finished and metrics become available.
+ *
+ * This class does not support streaming datasets.
+ *
+ * @param name name of the metric
+ * @since 3.3.0
+ */
+class Observation(name: String) {
+
+  /**
+   * Create an Observation instance without providing a name. This generates a random name.
+   */
+  def this() = this(UUID.randomUUID().toString)
+
+  private val listener: ObservationListener = ObservationListener(this)
+
+  @volatile private var sparkSession: Option[SparkSession] = None
+
+  @volatile private var row: Option[Row] = None
+
+  /**
+   * Attach this observation to the given [[Dataset]] to observe aggregation expressions.
+   *
+   * @param ds dataset
+   * @param expr first aggregation expression
+   * @param exprs more aggregation expressions
+   * @tparam T dataset type
+   * @return observed dataset
+   * @throws IllegalArgumentException If this is a streaming Dataset (ds.isStreaming == true)
+   */
+  private[spark] def on[T](ds: Dataset[T], expr: Column, exprs: Column*): Dataset[T] = {
+    if (ds.isStreaming) {
+      throw new IllegalArgumentException("Observation does not support streaming Datasets")
+    }
+    register(ds.sparkSession)
+    ds.observe(name, expr, exprs: _*)
+  }
+
+  /**
+   * Get the observed metrics. This waits for the observed dataset to finish its first action.
+   * Only the result of the first action is available. Subsequent actions do not modify the result.
+   *
+   * @return the observed metrics as a [[Row]]
+   * @throws InterruptedException interrupted while waiting
+   */
+  @throws[InterruptedException]
+  def get: Row = {
+    synchronized {
+      // we need to loop as wait might return without us calling notify
+      // https://en.wikipedia.org/w/index.php?title=Spurious_wakeup&oldid=992601610
+      while (this.row.isEmpty) {
+        wait()
+      }
+    }
+
+    this.row.get
+  }
+
+  private def register(sparkSession: SparkSession): Unit = {
+    // makes this class thread-safe:
+    // only the first thread entering this block can set sparkSession
+    // all other threads will see the exception, as it is only allowed to do this once
+    synchronized {
+      if (this.sparkSession.isDefined) {
+        throw new IllegalArgumentException("An Observation can be used with a Dataset only once")
+      }
+      this.sparkSession = Some(sparkSession)
+    }
+
+    sparkSession.listenerManager.register(this.listener)
+  }
+
+  private def unregister(): Unit = {
+    this.sparkSession.foreach(_.listenerManager.unregister(this.listener))
+  }
+
+  private[spark] def onFinish(qe: QueryExecution): Unit = {
+    synchronized {
+      if (this.row.isEmpty) {
+        this.row = qe.observedMetrics.get(name)
+        if (this.row.isDefined) {
+          notifyAll()
+          unregister()
+        }
+      }
+    }
+  }
+
+}
+
+private[sql] case class ObservationListener(observation: Observation)
+  extends QueryExecutionListener {
+
+  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
+    observation.onFinish(qe)
+
+  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
+    observation.onFinish(qe)
+
+}
+
+/**
+ * (Scala-specific) Create instances of Observation via Scala `apply`.
+ * @since 3.3.0
+ */
+object Observation {
+
+  /**
+   * Observation constructor for creating an anonymous observation.
+   */
+  def apply(): Observation = new Observation()
+
+  /**
+   * Observation constructor for creating a named observation.
+   */
+  def apply(name: String): Observation = new Observation(name)
+
+}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index da7c62251b..e469792b75 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -34,6 +34,7 @@ import org.junit.*;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Observation;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
 import org.apache.spark.sql.expressions.UserDefinedFunction;
@@ -523,4 +524,50 @@ public class JavaDataFrameSuite {
       .map(row -> row.get(0).toString() + row.getString(1)).toArray(String[]::new);
     Assert.assertArrayEquals(expected, result);
   }
+
+  /**
+   * Tests the Java API of Observation and Dataset.observe(Observation, Column, Column*).
+   */
+  @Test
+  public void testObservation() {
+    Observation namedObservation = new Observation("named");
+    Observation unnamedObservation = new Observation();
+
+    Dataset<Long> df = spark
+      .range(100)
+      .observe(
+        namedObservation,
+        min(col("id")).as("min_val"),
+        max(col("id")).as("max_val"),
+        sum(col("id")).as("sum_val"),
+        count(when(pmod(col("id"), lit(2)).equalTo(0), 1)).as("num_even")
+      )
+      .observe(
+        unnamedObservation,
+        avg(col("id")).cast("int").as("avg_val")
+      );
+
+    df.collect();
+    List<Object> namedMetrics = null;
+    List<Object> unnamedMetrics = null;
+
+    try {
+      namedMetrics = JavaConverters.seqAsJavaList(namedObservation.get().toSeq());
+      unnamedMetrics = JavaConverters.seqAsJavaList(unnamedObservation.get().toSeq());
+    } catch (InterruptedException e) {
+      Assert.fail();
+    }
+    Assert.assertEquals(Arrays.asList(0L, 99L, 4950L, 50L), namedMetrics);
+    Assert.assertEquals(Arrays.asList(49), unnamedMetrics);
+
+    // we can get the result multiple times
+    try {
+      namedMetrics = JavaConverters.seqAsJavaList(namedObservation.get().toSeq());
+      unnamedMetrics = JavaConverters.seqAsJavaList(unnamedObservation.get().toSeq());
+    } catch (InterruptedException e) {
+      Assert.fail();
+    }
+    Assert.assertEquals(Arrays.asList(0L, 99L, 4950L, 50L), namedMetrics);
+    Assert.assertEquals(Arrays.asList(49), unnamedMetrics);
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 1e3d219220..80416f5933 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -42,6 +42,7 @@ import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCod
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.expressions.{Aggregator, Window}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -2382,6 +2383,49 @@ class DataFrameSuite extends QueryTest
+  test("SPARK-34806: observation on datasets") {
+    val namedObservation = Observation("named")
+    val unnamedObservation = Observation()
+
+    val df = spark
+      .range(100)
+      .observe(
+        namedObservation,
+        min($"id").as("min_val"),
+        max($"id").as("max_val"),
+        sum($"id").as("sum_val"),
+        count(when($"id" % 2 === 0, 1)).as("num_even")
+      )
+      .observe(
+        unnamedObservation,
+        avg($"id").cast("int").as("avg_val")
+      )
+
+    def checkMetrics(namedMetric: Row, unnamedMetric: Row): Unit = {
+      assert(namedMetric === Row(0L, 99L, 4950L, 50L))
+      assert(unnamedMetric === Row(49))
+    }
+
+    df.collect()
+    // we can get the result multiple times
+    checkMetrics(namedObservation.get, unnamedObservation.get)
+    checkMetrics(namedObservation.get, unnamedObservation.get)
+
+    // an observation can be used only once
+    val err = intercept[IllegalArgumentException] {
+      spark.range(100).observe(namedObservation, sum($"id").as("sum_val"))
+    }
+    assert(err.getMessage.contains("An Observation can be used with a Dataset only once"))
+
+    // streaming datasets are not supported
+    val streamDf = new MemoryStream[Int](0, sqlContext).toDF()
+    val streamObservation = Observation("stream")
+    val streamErr = intercept[IllegalArgumentException] {
+      streamDf.observe(streamObservation,
avg($"value").cast("int").as("avg_val")) + } + assert(streamErr.getMessage.contains("Observation does not support streaming Datasets")) + } + test("SPARK-25159: json schema inference should only trigger one job") { withTempPath { path => // This test is to prove that the `JsonInferSchema` does not use `RDD#toLocalIterator` which