[SPARK-19447] Make Range operator generate "recordsRead" metric

## What changes were proposed in this pull request?

The Range was modified to produce "recordsRead" metric instead of "generated rows". The tests were updated and partially moved to SQLMetricsSuite.

## How was this patch tested?

Unit tests.

Author: Ala Luszczak <ala@databricks.com>

Closes #16960 from ala/range-records-read.
This commit is contained in:
Ala Luszczak 2017-02-18 07:51:41 -08:00 committed by Reynold Xin
parent 729ce37032
commit b486ffc86d
6 changed files with 125 additions and 155 deletions

View file

@ -31,6 +31,7 @@ import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, SimpleComp
import org.codehaus.janino.util.ClassFile
import org.apache.spark.{SparkEnv, TaskContext, TaskKilledException}
import org.apache.spark.executor.InputMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.metrics.source.CodegenMetrics
import org.apache.spark.sql.catalyst.InternalRow
@ -933,7 +934,8 @@ object CodeGenerator extends Logging {
classOf[UnsafeMapData].getName,
classOf[Expression].getName,
classOf[TaskContext].getName,
classOf[TaskKilledException].getName
classOf[TaskKilledException].getName,
classOf[InputMetrics].getName
))
evaluator.setExtendedClass(classOf[GeneratedClass])

View file

@ -365,6 +365,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
val taskContext = ctx.freshName("taskContext")
ctx.addMutableState("TaskContext", taskContext, s"$taskContext = TaskContext.get();")
val inputMetrics = ctx.freshName("inputMetrics")
ctx.addMutableState("InputMetrics", inputMetrics,
s"$inputMetrics = $taskContext.taskMetrics().inputMetrics();")
// In order to periodically update the metrics without inflicting performance penalty, this
// operator produces elements in batches. After a batch is complete, the metrics are updated
@ -460,7 +463,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
| if ($nextBatchTodo == 0) break;
| }
| $numOutput.add($nextBatchTodo);
| $numGenerated.add($nextBatchTodo);
| $inputMetrics.incRecordsRead($nextBatchTodo);
|
| $batchEnd += $nextBatchTodo * ${step}L;
| }
@ -469,7 +472,6 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
val numGeneratedRows = longMetric("numGeneratedRows")
sqlContext
.sparkContext
.parallelize(0 until numSlices, numSlices)
@ -488,10 +490,12 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
val safePartitionEnd = getSafeMargin(partitionEnd)
val rowSize = UnsafeRow.calculateBitSetWidthInBytes(1) + LongType.defaultSize
val unsafeRow = UnsafeRow.createFromByteArray(rowSize, 1)
val taskContext = TaskContext.get()
val iter = new Iterator[InternalRow] {
private[this] var number: Long = safePartitionStart
private[this] var overflow: Boolean = false
private[this] val inputMetrics = taskContext.taskMetrics().inputMetrics
override def hasNext =
if (!overflow) {
@ -513,12 +517,12 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
}
numOutputRows += 1
numGeneratedRows += 1
inputMetrics.incRecordsRead(1)
unsafeRow.setLong(0, ret)
unsafeRow
}
}
new InterruptibleIterator(TaskContext.get(), iter)
new InterruptibleIterator(taskContext, iter)
}
}

View file

@ -1,131 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution
import java.io.File
import org.scalatest.concurrent.Eventually
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.{DataFrame, QueryTest}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils
class InputGeneratedOutputMetricsSuite extends QueryTest with SharedSQLContext with Eventually {
test("Range query input/output/generated metrics") {
val numRows = 150L
val numSelectedRows = 100L
val res = MetricsTestHelper.runAndGetMetrics(spark.range(0, numRows, 1).
filter(x => x < numSelectedRows).toDF())
assert(res.recordsRead.sum === 0)
assert(res.shuffleRecordsRead.sum === 0)
assert(res.generatedRows === numRows :: Nil)
assert(res.outputRows === numSelectedRows :: numRows :: Nil)
}
test("Input/output/generated metrics with repartitioning") {
val numRows = 100L
val res = MetricsTestHelper.runAndGetMetrics(
spark.range(0, numRows).repartition(3).filter(x => x % 5 == 0).toDF())
assert(res.recordsRead.sum === 0)
assert(res.shuffleRecordsRead.sum === numRows)
assert(res.generatedRows === numRows :: Nil)
assert(res.outputRows === 20 :: numRows :: Nil)
}
test("Input/output/generated metrics with more repartitioning") {
withTempDir { tempDir =>
val dir = new File(tempDir, "pqS").getCanonicalPath
spark.range(10).write.parquet(dir)
spark.read.parquet(dir).createOrReplaceTempView("pqS")
val res = MetricsTestHelper.runAndGetMetrics(
spark.range(0, 30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2)
.toDF()
)
assert(res.recordsRead.sum == 10)
assert(res.shuffleRecordsRead.sum == 3 * 10 + 2 * 150)
assert(res.generatedRows == 30 :: Nil)
assert(res.outputRows == 10 :: 30 :: 300 :: Nil)
}
}
}
object MetricsTestHelper {
case class AggregatedMetricsResult(
recordsRead: List[Long],
shuffleRecordsRead: List[Long],
generatedRows: List[Long],
outputRows: List[Long])
private[this] def extractMetricValues(
df: DataFrame,
metricValues: Map[Long, String],
metricName: String): List[Long] = {
df.queryExecution.executedPlan.collect {
case plan if plan.metrics.contains(metricName) =>
metricValues(plan.metrics(metricName).id).toLong
}.toList.sorted
}
def runAndGetMetrics(df: DataFrame, useWholeStageCodeGen: Boolean = false):
AggregatedMetricsResult = {
val spark = df.sparkSession
val sparkContext = spark.sparkContext
var recordsRead = List[Long]()
var shuffleRecordsRead = List[Long]()
val listener = new SparkListener() {
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
if (taskEnd.taskMetrics != null) {
recordsRead = taskEnd.taskMetrics.inputMetrics.recordsRead ::
recordsRead
shuffleRecordsRead = taskEnd.taskMetrics.shuffleReadMetrics.recordsRead ::
shuffleRecordsRead
}
}
}
val oldExecutionIds = spark.sharedState.listener.executionIdToData.keySet
val prevUseWholeStageCodeGen =
spark.sessionState.conf.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED)
try {
spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, useWholeStageCodeGen)
sparkContext.listenerBus.waitUntilEmpty(10000)
sparkContext.addSparkListener(listener)
df.collect()
sparkContext.listenerBus.waitUntilEmpty(10000)
} finally {
spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, prevUseWholeStageCodeGen)
}
val executionId = spark.sharedState.listener.executionIdToData.keySet.diff(oldExecutionIds).head
val metricValues = spark.sharedState.listener.getExecutionMetrics(executionId)
val outputRes = extractMetricValues(df, metricValues, "numOutputRows")
val generatedRes = extractMetricValues(df, metricValues, "numGeneratedRows")
AggregatedMetricsResult(recordsRead.sorted, shuffleRecordsRead.sorted, generatedRes, outputRes)
}
}

View file

@ -17,7 +17,12 @@
package org.apache.spark.sql.execution.metric
import java.io.File
import scala.collection.mutable.HashMap
import org.apache.spark.SparkFunSuite
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.execution.SparkPlanInfo
@ -309,4 +314,103 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
assert(metricInfoDeser.metadata === Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER))
}
test("range metrics") {
val res1 = InputOutputMetricsHelper.run(
spark.range(30).filter(x => x % 3 == 0).toDF()
)
assert(res1 === (30L, 0L, 30L) :: Nil)
val res2 = InputOutputMetricsHelper.run(
spark.range(150).repartition(4).filter(x => x < 10).toDF()
)
assert(res2 === (150L, 0L, 150L) :: (0L, 150L, 10L) :: Nil)
withTempDir { tempDir =>
val dir = new File(tempDir, "pqS").getCanonicalPath
spark.range(10).write.parquet(dir)
spark.read.parquet(dir).createOrReplaceTempView("pqS")
val res3 = InputOutputMetricsHelper.run(
spark.range(30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2).toDF()
)
// The query above is executed in the following stages:
// 1. sql("select * from pqS") => (10, 0, 10)
// 2. range(30) => (30, 0, 30)
// 3. crossJoin(...) of 1. and 2. => (0, 30, 300)
// 4. shuffle & return results => (0, 300, 0)
assert(res3 === (10L, 0L, 10L) :: (30L, 0L, 30L) :: (0L, 30L, 300L) :: (0L, 300L, 0L) :: Nil)
}
}
}
object InputOutputMetricsHelper {
private class InputOutputMetricsListener extends SparkListener {
private case class MetricsResult(
var recordsRead: Long = 0L,
var shuffleRecordsRead: Long = 0L,
var sumMaxOutputRows: Long = 0L)
private[this] val stageIdToMetricsResult = HashMap.empty[Int, MetricsResult]
def reset(): Unit = {
stageIdToMetricsResult.clear()
}
/**
* Return a list of recorded metrics aggregated per stage.
*
* The list is sorted in the ascending order on the stageId.
* For each recorded stage, the following tuple is returned:
* - sum of inputMetrics.recordsRead for all the tasks in the stage
* - sum of shuffleReadMetrics.recordsRead for all the tasks in the stage
* - sum of the highest values of "number of output rows" metric for all the tasks in the stage
*/
def getResults(): List[(Long, Long, Long)] = {
stageIdToMetricsResult.keySet.toList.sorted.map { stageId =>
val res = stageIdToMetricsResult(stageId)
(res.recordsRead, res.shuffleRecordsRead, res.sumMaxOutputRows)
}
}
override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
val res = stageIdToMetricsResult.getOrElseUpdate(taskEnd.stageId, MetricsResult())
res.recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead
res.shuffleRecordsRead += taskEnd.taskMetrics.shuffleReadMetrics.recordsRead
var maxOutputRows = 0L
for (accum <- taskEnd.taskMetrics.externalAccums) {
val info = accum.toInfo(Some(accum.value), None)
if (info.name.toString.contains("number of output rows")) {
info.update match {
case Some(n: Number) =>
if (n.longValue() > maxOutputRows) {
maxOutputRows = n.longValue()
}
case _ => // Ignore.
}
}
}
res.sumMaxOutputRows += maxOutputRows
}
}
// Run df.collect() and return aggregated metrics for each stage.
def run(df: DataFrame): List[(Long, Long, Long)] = {
val spark = df.sparkSession
val sparkContext = spark.sparkContext
val listener = new InputOutputMetricsListener()
sparkContext.addSparkListener(listener)
try {
sparkContext.listenerBus.waitUntilEmpty(5000)
listener.reset()
df.collect()
sparkContext.listenerBus.waitUntilEmpty(5000)
} finally {
sparkContext.removeSparkListener(listener)
}
listener.getResults()
}
}

View file

@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.DataSourceScanExec
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils}
import org.apache.spark.sql.execution.MetricsTestHelper
import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@ -917,13 +917,10 @@ class JDBCSuite extends SparkFunSuite
assert(e2.contains("User specified schema not supported with `jdbc`"))
}
test("Input/generated/output metrics on JDBC") {
test("Checking metrics correctness with JDBC") {
val foobarCnt = spark.table("foobar").count()
val res = MetricsTestHelper.runAndGetMetrics(sql("SELECT * FROM foobar").toDF())
assert(res.recordsRead === foobarCnt :: Nil)
assert(res.shuffleRecordsRead.sum === 0)
assert(res.generatedRows.isEmpty)
assert(res.outputRows === foobarCnt :: Nil)
val res = InputOutputMetricsHelper.run(sql("SELECT * FROM foobar").toDF())
assert(res === (foobarCnt, 0L, foobarCnt) :: Nil)
}
test("SPARK-19318: Connection properties keys should be case-sensitive.") {

View file

@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.execution
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql.execution.MetricsTestHelper
import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
import org.apache.spark.sql.hive.test.TestHive
/**
@ -49,21 +49,15 @@ class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll {
createQueryTest("Read Partitioned with AvroSerDe", "SELECT * FROM episodes_part")
test("Test input/generated/output metrics") {
test("Checking metrics correctness") {
import TestHive._
val episodesCnt = sql("select * from episodes").count()
val episodesRes = MetricsTestHelper.runAndGetMetrics(sql("select * from episodes").toDF())
assert(episodesRes.recordsRead === episodesCnt :: Nil)
assert(episodesRes.shuffleRecordsRead.sum === 0)
assert(episodesRes.generatedRows.isEmpty)
assert(episodesRes.outputRows === episodesCnt :: Nil)
val episodesRes = InputOutputMetricsHelper.run(sql("select * from episodes").toDF())
assert(episodesRes === (episodesCnt, 0L, episodesCnt) :: Nil)
val serdeinsCnt = sql("select * from serdeins").count()
val serdeinsRes = MetricsTestHelper.runAndGetMetrics(sql("select * from serdeins").toDF())
assert(serdeinsRes.recordsRead === serdeinsCnt :: Nil)
assert(serdeinsRes.shuffleRecordsRead.sum === 0)
assert(serdeinsRes.generatedRows.isEmpty)
assert(serdeinsRes.outputRows === serdeinsCnt :: Nil)
val serdeinsRes = InputOutputMetricsHelper.run(sql("select * from serdeins").toDF())
assert(serdeinsRes === (serdeinsCnt, 0L, serdeinsCnt) :: Nil)
}
}