From 73d92b00b9a6f5dfc2f8116447d17b381cd74f80 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 15 Jul 2015 21:02:42 -0700 Subject: [PATCH] [SPARK-9018] [MLLIB] add stopwatches Add stopwatches for easy instrumentation of MLlib algorithms. This is based on the `TimeTracker` used in decision trees. The distributed version uses Spark accumulator. jkbradley Author: Xiangrui Meng Closes #7415 from mengxr/SPARK-9018 and squashes the following commits: 40b4347 [Xiangrui Meng] == -> === c477745 [Xiangrui Meng] address Joseph's comments f981a49 [Xiangrui Meng] add stopwatches --- .../apache/spark/ml/util/stopwatches.scala | 151 ++++++++++++++++++ .../apache/spark/ml/util/StopwatchSuite.scala | 109 +++++++++++++ 2 files changed, 260 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala new file mode 100644 index 0000000000..5fdf878a3d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import scala.collection.mutable + +import org.apache.spark.{Accumulator, SparkContext} + +/** + * Abstract class for stopwatches. + */ +private[spark] abstract class Stopwatch extends Serializable { + + @transient private var running: Boolean = false + private var startTime: Long = _ + + /** + * Name of the stopwatch. + */ + val name: String + + /** + * Starts the stopwatch. + * Throws an exception if the stopwatch is already running. + */ + def start(): Unit = { + assume(!running, "start() called but the stopwatch is already running.") + running = true + startTime = now + } + + /** + * Stops the stopwatch and returns the duration of the last session in milliseconds. + * Throws an exception if the stopwatch is not running. + */ + def stop(): Long = { + assume(running, "stop() called but the stopwatch is not running.") + val duration = now - startTime + add(duration) + running = false + duration + } + + /** + * Checks whether the stopwatch is running. + */ + def isRunning: Boolean = running + + /** + * Returns total elapsed time in milliseconds, not counting the current session if the stopwatch + * is running. + */ + def elapsed(): Long + + /** + * Gets the current time in milliseconds. + */ + protected def now: Long = System.currentTimeMillis() + + /** + * Adds input duration to total elapsed time. + */ + protected def add(duration: Long): Unit +} + +/** + * A local [[Stopwatch]]. + */ +private[spark] class LocalStopwatch(override val name: String) extends Stopwatch { + + private var elapsedTime: Long = 0L + + override def elapsed(): Long = elapsedTime + + override protected def add(duration: Long): Unit = { + elapsedTime += duration + } +} + +/** + * A distributed [[Stopwatch]] using Spark accumulator. + * @param sc SparkContext + */ +private[spark] class DistributedStopwatch( + sc: SparkContext, + override val name: String) extends Stopwatch { + + private val elapsedTime: Accumulator[Long] = sc.accumulator(0L, s"DistributedStopwatch($name)") + + override def elapsed(): Long = elapsedTime.value + + override protected def add(duration: Long): Unit = { + elapsedTime += duration + } +} + +/** + * A multiple stopwatch that contains local and distributed stopwatches. + * @param sc SparkContext + */ +private[spark] class MultiStopwatch(@transient private val sc: SparkContext) extends Serializable { + + private val stopwatches: mutable.Map[String, Stopwatch] = mutable.Map.empty + + /** + * Adds a local stopwatch. + * @param name stopwatch name + */ + def addLocal(name: String): this.type = { + require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.") + stopwatches(name) = new LocalStopwatch(name) + this + } + + /** + * Adds a distributed stopwatch. + * @param name stopwatch name + */ + def addDistributed(name: String): this.type = { + require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.") + stopwatches(name) = new DistributedStopwatch(sc, name) + this + } + + /** + * Gets a stopwatch. + * @param name stopwatch name + */ + def apply(name: String): Stopwatch = stopwatches(name) + + override def toString: String = { + stopwatches.values.toArray.sortBy(_.name) + .map(c => s" ${c.name}: ${c.elapsed()}ms") + .mkString("{\n", ",\n", "\n}") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala new file mode 100644 index 0000000000..8df6617fe0 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { + + private def testStopwatchOnDriver(sw: Stopwatch): Unit = { + assert(sw.name === "sw") + assert(sw.elapsed() === 0L) + assert(!sw.isRunning) + intercept[AssertionError] { + sw.stop() + } + sw.start() + Thread.sleep(50) + val duration = sw.stop() + assert(duration >= 50 && duration < 100) // using a loose upper bound + val elapsed = sw.elapsed() + assert(elapsed === duration) + sw.start() + Thread.sleep(50) + val duration2 = sw.stop() + assert(duration2 >= 50 && duration2 < 100) + val elapsed2 = sw.elapsed() + assert(elapsed2 === duration + duration2) + sw.start() + assert(sw.isRunning) + intercept[AssertionError] { + sw.start() + } + } + + test("LocalStopwatch") { + val sw = new LocalStopwatch("sw") + testStopwatchOnDriver(sw) + } + + test("DistributedStopwatch on driver") { + val sw = new DistributedStopwatch(sc, "sw") + testStopwatchOnDriver(sw) + } + + test("DistributedStopwatch on executors") { + val sw = new DistributedStopwatch(sc, "sw") + val rdd = sc.parallelize(0 until 4, 4) + rdd.foreach { i => + sw.start() + Thread.sleep(50) + sw.stop() + } + assert(!sw.isRunning) + val elapsed = sw.elapsed() + assert(elapsed >= 200 && elapsed < 400) // using a loose upper bound + } + + test("MultiStopwatch") { + val sw = new MultiStopwatch(sc) + .addLocal("local") + .addDistributed("spark") + assert(sw("local").name === "local") + assert(sw("spark").name === "spark") + intercept[NoSuchElementException] { + sw("some") + } + assert(sw.toString === "{\n local: 0ms,\n spark: 0ms\n}") + sw("local").start() + sw("spark").start() + Thread.sleep(50) + sw("local").stop() + Thread.sleep(50) + sw("spark").stop() + val localElapsed = sw("local").elapsed() + val sparkElapsed = sw("spark").elapsed() + assert(localElapsed >= 50 && localElapsed < 100) + assert(sparkElapsed >= 100 && sparkElapsed < 200) + assert(sw.toString === + s"{\n local: ${localElapsed}ms,\n spark: ${sparkElapsed}ms\n}") + val rdd = sc.parallelize(0 until 4, 4) + rdd.foreach { i => + sw("local").start() + sw("spark").start() + Thread.sleep(50) + sw("spark").stop() + sw("local").stop() + } + val localElapsed2 = sw("local").elapsed() + assert(localElapsed2 === localElapsed) + val sparkElapsed2 = sw("spark").elapsed() + assert(sparkElapsed2 >= 300 && sparkElapsed2 < 600) + } +}