[SPARK-19447] Make Range operator generate "recordsRead" metric
## What changes were proposed in this pull request?

The Range operator was modified to produce the "recordsRead" metric instead of "generated rows". The tests were updated and partially moved to SQLMetricsSuite.

## How was this patch tested?

Unit tests.

Author: Ala Luszczak <ala@databricks.com>

Closes #16960 from ala/range-records-read.
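As a quick, hypothetical illustration of the effect (not part of this patch): after this change a plain `spark.range(...)` scan reports its row count through the task-level input metrics, which can be observed with a `SparkListener`. A minimal sketch, assuming a local `SparkSession`:

```scala
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("range-records-read").getOrCreate()

// Accumulate inputMetrics.recordsRead over all finished tasks.
var recordsRead = 0L
val listener = new SparkListener {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    if (taskEnd.taskMetrics != null) {
      recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead
    }
  }
}
spark.sparkContext.addSparkListener(listener)

spark.range(0, 150).collect()
// Listener events are delivered asynchronously; give the bus a moment to drain.
Thread.sleep(2000)

// With this patch the Range operator reports 150 records read; previously it reported 0.
println(s"recordsRead = $recordsRead")
```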
Commit: b486ffc86d
Parent: 729ce37032
@@ -31,6 +31,7 @@ import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, SimpleComp
 import org.codehaus.janino.util.ClassFile

 import org.apache.spark.{SparkEnv, TaskContext, TaskKilledException}
+import org.apache.spark.executor.InputMetrics
 import org.apache.spark.internal.Logging
 import org.apache.spark.metrics.source.CodegenMetrics
 import org.apache.spark.sql.catalyst.InternalRow
@@ -933,7 +934,8 @@ object CodeGenerator extends Logging {
       classOf[UnsafeMapData].getName,
       classOf[Expression].getName,
       classOf[TaskContext].getName,
-      classOf[TaskKilledException].getName
+      classOf[TaskKilledException].getName,
+      classOf[InputMetrics].getName
     ))
     evaluator.setExtendedClass(classOf[GeneratedClass])
@@ -365,6 +365,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)

     val taskContext = ctx.freshName("taskContext")
     ctx.addMutableState("TaskContext", taskContext, s"$taskContext = TaskContext.get();")
+    val inputMetrics = ctx.freshName("inputMetrics")
+    ctx.addMutableState("InputMetrics", inputMetrics,
+      s"$inputMetrics = $taskContext.taskMetrics().inputMetrics();")

     // In order to periodically update the metrics without inflicting performance penalty, this
     // operator produces elements in batches. After a batch is complete, the metrics are updated
@@ -460,7 +463,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
        |     if ($nextBatchTodo == 0) break;
        |   }
        |   $numOutput.add($nextBatchTodo);
-       |   $numGenerated.add($nextBatchTodo);
+       |   $inputMetrics.incRecordsRead($nextBatchTodo);
        |
        |   $batchEnd += $nextBatchTodo * ${step}L;
        | }
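The comment in doProduce above explains that rows are produced in batches so that metrics are updated only once per batch, and the generated loop in this hunk bumps both counters by the batch size. A standalone Scala sketch of the same pattern (hypothetical `Counter` and `produceRange`, not Spark's generated code):

```scala
// Self-contained sketch of per-batch metric updates; `Counter` stands in for both the
// SQL metric ("number of output rows") and InputMetrics.recordsRead used by RangeExec.
final class Counter {
  var value = 0L
  def add(n: Long): Unit = value += n
}

def produceRange(start: Long, end: Long, step: Long, batchSize: Long,
    numOutputRows: Counter, recordsRead: Counter): Unit = {
  require(step > 0 && batchSize > 0)
  var number = start
  while (number < end) {
    var producedInBatch = 0L
    while (number < end && producedInBatch < batchSize) {
      // a real implementation would emit a row for `number` here
      number += step
      producedInBatch += 1
    }
    // counters are bumped once per completed batch, not once per row
    numOutputRows.add(producedInBatch)  // "number of output rows" SQL metric
    recordsRead.add(producedInBatch)    // task-level records read after this patch
  }
}

val out = new Counter
val read = new Counter
produceRange(0L, 150L, 1L, 32L, out, read)
assert(out.value == 150L && read.value == 150L)
```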
@@ -469,7 +472,6 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)

   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
-    val numGeneratedRows = longMetric("numGeneratedRows")
     sqlContext
       .sparkContext
       .parallelize(0 until numSlices, numSlices)
@@ -488,10 +490,12 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
       val safePartitionEnd = getSafeMargin(partitionEnd)
       val rowSize = UnsafeRow.calculateBitSetWidthInBytes(1) + LongType.defaultSize
       val unsafeRow = UnsafeRow.createFromByteArray(rowSize, 1)
+      val taskContext = TaskContext.get()

       val iter = new Iterator[InternalRow] {
         private[this] var number: Long = safePartitionStart
         private[this] var overflow: Boolean = false
+        private[this] val inputMetrics = taskContext.taskMetrics().inputMetrics

         override def hasNext =
           if (!overflow) {
@@ -513,12 +517,12 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
           }

           numOutputRows += 1
-          numGeneratedRows += 1
+          inputMetrics.incRecordsRead(1)
           unsafeRow.setLong(0, ret)
           unsafeRow
         }
       }
-      new InterruptibleIterator(TaskContext.get(), iter)
+      new InterruptibleIterator(taskContext, iter)
     }
   }
@@ -1,131 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import java.io.File
-
-import org.scalatest.concurrent.Eventually
-
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
-import org.apache.spark.sql.{DataFrame, QueryTest}
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.util.Utils
-
-class InputGeneratedOutputMetricsSuite extends QueryTest with SharedSQLContext with Eventually {
-
-  test("Range query input/output/generated metrics") {
-    val numRows = 150L
-    val numSelectedRows = 100L
-    val res = MetricsTestHelper.runAndGetMetrics(spark.range(0, numRows, 1).
-      filter(x => x < numSelectedRows).toDF())
-
-    assert(res.recordsRead.sum === 0)
-    assert(res.shuffleRecordsRead.sum === 0)
-    assert(res.generatedRows === numRows :: Nil)
-    assert(res.outputRows === numSelectedRows :: numRows :: Nil)
-  }
-
-  test("Input/output/generated metrics with repartitioning") {
-    val numRows = 100L
-    val res = MetricsTestHelper.runAndGetMetrics(
-      spark.range(0, numRows).repartition(3).filter(x => x % 5 == 0).toDF())
-
-    assert(res.recordsRead.sum === 0)
-    assert(res.shuffleRecordsRead.sum === numRows)
-    assert(res.generatedRows === numRows :: Nil)
-    assert(res.outputRows === 20 :: numRows :: Nil)
-  }
-
-  test("Input/output/generated metrics with more repartitioning") {
-    withTempDir { tempDir =>
-      val dir = new File(tempDir, "pqS").getCanonicalPath
-
-      spark.range(10).write.parquet(dir)
-      spark.read.parquet(dir).createOrReplaceTempView("pqS")
-
-      val res = MetricsTestHelper.runAndGetMetrics(
-        spark.range(0, 30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2)
-          .toDF()
-      )
-
-      assert(res.recordsRead.sum == 10)
-      assert(res.shuffleRecordsRead.sum == 3 * 10 + 2 * 150)
-      assert(res.generatedRows == 30 :: Nil)
-      assert(res.outputRows == 10 :: 30 :: 300 :: Nil)
-    }
-  }
-}
-
-object MetricsTestHelper {
-  case class AggregatedMetricsResult(
-      recordsRead: List[Long],
-      shuffleRecordsRead: List[Long],
-      generatedRows: List[Long],
-      outputRows: List[Long])
-
-  private[this] def extractMetricValues(
-      df: DataFrame,
-      metricValues: Map[Long, String],
-      metricName: String): List[Long] = {
-    df.queryExecution.executedPlan.collect {
-      case plan if plan.metrics.contains(metricName) =>
-        metricValues(plan.metrics(metricName).id).toLong
-    }.toList.sorted
-  }
-
-  def runAndGetMetrics(df: DataFrame, useWholeStageCodeGen: Boolean = false):
-      AggregatedMetricsResult = {
-    val spark = df.sparkSession
-    val sparkContext = spark.sparkContext
-
-    var recordsRead = List[Long]()
-    var shuffleRecordsRead = List[Long]()
-    val listener = new SparkListener() {
-      override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
-        if (taskEnd.taskMetrics != null) {
-          recordsRead = taskEnd.taskMetrics.inputMetrics.recordsRead ::
-            recordsRead
-          shuffleRecordsRead = taskEnd.taskMetrics.shuffleReadMetrics.recordsRead ::
-            shuffleRecordsRead
-        }
-      }
-    }
-
-    val oldExecutionIds = spark.sharedState.listener.executionIdToData.keySet
-
-    val prevUseWholeStageCodeGen =
-      spark.sessionState.conf.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED)
-    try {
-      spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, useWholeStageCodeGen)
-      sparkContext.listenerBus.waitUntilEmpty(10000)
-      sparkContext.addSparkListener(listener)
-      df.collect()
-      sparkContext.listenerBus.waitUntilEmpty(10000)
-    } finally {
-      spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, prevUseWholeStageCodeGen)
-    }
-
-    val executionId = spark.sharedState.listener.executionIdToData.keySet.diff(oldExecutionIds).head
-    val metricValues = spark.sharedState.listener.getExecutionMetrics(executionId)
-    val outputRes = extractMetricValues(df, metricValues, "numOutputRows")
-    val generatedRes = extractMetricValues(df, metricValues, "numGeneratedRows")
-
-    AggregatedMetricsResult(recordsRead.sorted, shuffleRecordsRead.sorted, generatedRes, outputRes)
-  }
-}
@@ -17,7 +17,12 @@
 package org.apache.spark.sql.execution.metric

+import java.io.File
+
+import scala.collection.mutable.HashMap
+
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.execution.SparkPlanInfo
@@ -309,4 +314,103 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
     assert(metricInfoDeser.metadata === Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER))
   }
+
+  test("range metrics") {
+    val res1 = InputOutputMetricsHelper.run(
+      spark.range(30).filter(x => x % 3 == 0).toDF()
+    )
+    assert(res1 === (30L, 0L, 30L) :: Nil)
+
+    val res2 = InputOutputMetricsHelper.run(
+      spark.range(150).repartition(4).filter(x => x < 10).toDF()
+    )
+    assert(res2 === (150L, 0L, 150L) :: (0L, 150L, 10L) :: Nil)
+
+    withTempDir { tempDir =>
+      val dir = new File(tempDir, "pqS").getCanonicalPath
+
+      spark.range(10).write.parquet(dir)
+      spark.read.parquet(dir).createOrReplaceTempView("pqS")
+
+      val res3 = InputOutputMetricsHelper.run(
+        spark.range(30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2).toDF()
+      )
+      // The query above is executed in the following stages:
+      // 1. sql("select * from pqS")    => (10, 0, 10)
+      // 2. range(30)                   => (30, 0, 30)
+      // 3. crossJoin(...) of 1. and 2. => (0, 30, 300)
+      // 4. shuffle & return results    => (0, 300, 0)
+      assert(res3 === (10L, 0L, 10L) :: (30L, 0L, 30L) :: (0L, 30L, 300L) :: (0L, 300L, 0L) :: Nil)
+    }
+  }
 }
+
+object InputOutputMetricsHelper {
+  private class InputOutputMetricsListener extends SparkListener {
+    private case class MetricsResult(
+        var recordsRead: Long = 0L,
+        var shuffleRecordsRead: Long = 0L,
+        var sumMaxOutputRows: Long = 0L)
+
+    private[this] val stageIdToMetricsResult = HashMap.empty[Int, MetricsResult]
+
+    def reset(): Unit = {
+      stageIdToMetricsResult.clear()
+    }
+
+    /**
+     * Return a list of recorded metrics aggregated per stage.
+     *
+     * The list is sorted in the ascending order on the stageId.
+     * For each recorded stage, the following tuple is returned:
+     *  - sum of inputMetrics.recordsRead for all the tasks in the stage
+     *  - sum of shuffleReadMetrics.recordsRead for all the tasks in the stage
+     *  - sum of the highest values of "number of output rows" metric for all the tasks in the stage
+     */
+    def getResults(): List[(Long, Long, Long)] = {
+      stageIdToMetricsResult.keySet.toList.sorted.map { stageId =>
+        val res = stageIdToMetricsResult(stageId)
+        (res.recordsRead, res.shuffleRecordsRead, res.sumMaxOutputRows)
+      }
+    }
+
+    override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
+      val res = stageIdToMetricsResult.getOrElseUpdate(taskEnd.stageId, MetricsResult())
+
+      res.recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead
+      res.shuffleRecordsRead += taskEnd.taskMetrics.shuffleReadMetrics.recordsRead
+
+      var maxOutputRows = 0L
+      for (accum <- taskEnd.taskMetrics.externalAccums) {
+        val info = accum.toInfo(Some(accum.value), None)
+        if (info.name.toString.contains("number of output rows")) {
+          info.update match {
+            case Some(n: Number) =>
+              if (n.longValue() > maxOutputRows) {
+                maxOutputRows = n.longValue()
+              }
+            case _ => // Ignore.
+          }
+        }
+      }
+      res.sumMaxOutputRows += maxOutputRows
+    }
+  }
+
+  // Run df.collect() and return aggregated metrics for each stage.
+  def run(df: DataFrame): List[(Long, Long, Long)] = {
+    val spark = df.sparkSession
+    val sparkContext = spark.sparkContext
+    val listener = new InputOutputMetricsListener()
+    sparkContext.addSparkListener(listener)
+
+    try {
+      sparkContext.listenerBus.waitUntilEmpty(5000)
+      listener.reset()
+      df.collect()
+      sparkContext.listenerBus.waitUntilEmpty(5000)
+    } finally {
+      sparkContext.removeSparkListener(listener)
+    }
+    listener.getResults()
+  }
+}
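For reference (not part of the diff), a hypothetical usage of the helper added above, assuming a test with a running `spark` session as in `SharedSQLContext`. Each stage yields the tuple (records read, shuffle records read, sum of per-task maxima of "number of output rows"), as documented in `getResults()`:

```scala
import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper

// Mirrors the first assertion of the "range metrics" test above:
// a single stage that reads 30 records, shuffles nothing, and outputs 30 rows.
val perStage = InputOutputMetricsHelper.run(spark.range(30).filter(x => x % 3 == 0).toDF())
assert(perStage == (30L, 0L, 30L) :: Nil)
```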
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.DataSourceScanExec
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils}
-import org.apache.spark.sql.execution.MetricsTestHelper
+import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -917,13 +917,10 @@ class JDBCSuite extends SparkFunSuite
     assert(e2.contains("User specified schema not supported with `jdbc`"))
   }

-  test("Input/generated/output metrics on JDBC") {
+  test("Checking metrics correctness with JDBC") {
     val foobarCnt = spark.table("foobar").count()
-    val res = MetricsTestHelper.runAndGetMetrics(sql("SELECT * FROM foobar").toDF())
-    assert(res.recordsRead === foobarCnt :: Nil)
-    assert(res.shuffleRecordsRead.sum === 0)
-    assert(res.generatedRows.isEmpty)
-    assert(res.outputRows === foobarCnt :: Nil)
+    val res = InputOutputMetricsHelper.run(sql("SELECT * FROM foobar").toDF())
+    assert(res === (foobarCnt, 0L, foobarCnt) :: Nil)
   }

   test("SPARK-19318: Connection properties keys should be case-sensitive.") {
@@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.execution

 import org.scalatest.BeforeAndAfterAll

-import org.apache.spark.sql.execution.MetricsTestHelper
+import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
 import org.apache.spark.sql.hive.test.TestHive

 /**
@@ -49,21 +49,15 @@ class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll {

   createQueryTest("Read Partitioned with AvroSerDe", "SELECT * FROM episodes_part")

-  test("Test input/generated/output metrics") {
+  test("Checking metrics correctness") {
     import TestHive._

     val episodesCnt = sql("select * from episodes").count()
-    val episodesRes = MetricsTestHelper.runAndGetMetrics(sql("select * from episodes").toDF())
-    assert(episodesRes.recordsRead === episodesCnt :: Nil)
-    assert(episodesRes.shuffleRecordsRead.sum === 0)
-    assert(episodesRes.generatedRows.isEmpty)
-    assert(episodesRes.outputRows === episodesCnt :: Nil)
+    val episodesRes = InputOutputMetricsHelper.run(sql("select * from episodes").toDF())
+    assert(episodesRes === (episodesCnt, 0L, episodesCnt) :: Nil)

     val serdeinsCnt = sql("select * from serdeins").count()
-    val serdeinsRes = MetricsTestHelper.runAndGetMetrics(sql("select * from serdeins").toDF())
-    assert(serdeinsRes.recordsRead === serdeinsCnt :: Nil)
-    assert(serdeinsRes.shuffleRecordsRead.sum === 0)
-    assert(serdeinsRes.generatedRows.isEmpty)
-    assert(serdeinsRes.outputRows === serdeinsCnt :: Nil)
+    val serdeinsRes = InputOutputMetricsHelper.run(sql("select * from serdeins").toDF())
+    assert(serdeinsRes === (serdeinsCnt, 0L, serdeinsCnt) :: Nil)
   }
 }