[SPARK-9018] [MLLIB] add stopwatches
Add stopwatches for easy instrumentation of MLlib algorithms. This is based on the `TimeTracker` used in decision trees. The distributed version uses a Spark accumulator. jkbradley Author: Xiangrui Meng <meng@databricks.com> Closes #7415 from mengxr/SPARK-9018 and squashes the following commits: 40b4347 [Xiangrui Meng] == -> === c477745 [Xiangrui Meng] address Joseph's comments f981a49 [Xiangrui Meng] add stopwatches
This commit is contained in:
parent
6960a7938c
commit
73d92b00b9
151
mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
Normal file
151
mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
Normal file
|
@ -0,0 +1,151 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.util
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import org.apache.spark.{Accumulator, SparkContext}
|
||||
|
||||
/**
 * Base class for stopwatches that accumulate the durations of start/stop sessions.
 *
 * Subclasses provide the name and decide where the accumulated time is kept by implementing
 * `elapsed()` and `add()`.
 */
private[spark] abstract class Stopwatch extends Serializable {

  // The flag is transient, so it is not serialized: a deserialized copy reports "not running".
  @transient private var running: Boolean = false
  // Clock reading (see `now`) taken by the most recent call to start().
  private var startTime: Long = _

  /**
   * Name of the stopwatch.
   */
  val name: String

  /**
   * Starts the stopwatch.
   * Throws an `AssertionError` (via `assume`) if the stopwatch is already running.
   */
  def start(): Unit = {
    assume(!running, "start() called but the stopwatch is already running.")
    running = true
    startTime = now
  }

  /**
   * Stops the stopwatch and returns the duration of the last session in milliseconds.
   * Throws an `AssertionError` (via `assume`) if the stopwatch is not running.
   *
   * @return duration of the session that was just stopped, in milliseconds
   */
  def stop(): Long = {
    assume(running, "stop() called but the stopwatch is not running.")
    val duration = now - startTime
    // Leave the running state before recording the duration: if add() throws, the stopwatch
    // must not stay stuck in "running", which would make every later start() fail.
    running = false
    add(duration)
    duration
  }

  /**
   * Checks whether the stopwatch is running.
   */
  def isRunning: Boolean = running

  /**
   * Returns total elapsed time in milliseconds, not counting the current session if the stopwatch
   * is running.
   */
  def elapsed(): Long

  /**
   * Gets the current reading of a monotonic clock, in milliseconds.
   *
   * The value has an arbitrary origin, so only the difference between two readings is
   * meaningful. `System.nanoTime()` is used rather than `System.currentTimeMillis()` because
   * the wall clock can be adjusted while a session is running, which would yield negative or
   * inflated durations.
   */
  protected def now: Long = System.nanoTime() / 1000000L

  /**
   * Adds input duration to total elapsed time.
   *
   * @param duration session duration in milliseconds
   */
  protected def add(duration: Long): Unit
}
|
||||
|
||||
/**
 * A [[Stopwatch]] whose accumulated time is kept in a plain field of the instance.
 *
 * @param name stopwatch name
 */
private[spark] class LocalStopwatch(override val name: String) extends Stopwatch {

  // Running sum of every duration handed to add().
  private var totalTime: Long = 0L

  /** Returns the sum of all durations recorded so far. */
  override def elapsed(): Long = totalTime

  /** Adds `duration` to the running sum. */
  override protected def add(duration: Long): Unit = {
    totalTime = totalTime + duration
  }
}
|
||||
|
||||
/**
 * A [[Stopwatch]] that keeps its accumulated time in a Spark accumulator.
 *
 * @param sc SparkContext used to create the accumulator
 * @param name stopwatch name; it is also embedded in the accumulator's name
 */
private[spark] class DistributedStopwatch(
    sc: SparkContext,
    override val name: String) extends Stopwatch {

  // `sc` is only touched here, while the instance is being constructed.
  private val elapsedTime: Accumulator[Long] = {
    val accumulatorName = s"DistributedStopwatch($name)"
    sc.accumulator(0L, accumulatorName)
  }

  /** Returns the current value of the backing accumulator. */
  override def elapsed(): Long = elapsedTime.value

  /** Adds `duration` to the backing accumulator. */
  override protected def add(duration: Long): Unit = elapsedTime += duration
}
|
||||
|
||||
/**
 * A named collection of stopwatches that can hold both local and distributed ones.
 *
 * @param sc SparkContext handed to every distributed stopwatch created by this instance
 */
private[spark] class MultiStopwatch(@transient private val sc: SparkContext) extends Serializable {

  // Registered stopwatches, keyed by their name.
  private val stopwatches: mutable.Map[String, Stopwatch] = mutable.Map.empty

  /**
   * Registers a new [[LocalStopwatch]] under the given name.
   * Fails with an `IllegalArgumentException` (via `require`) if the name is already taken.
   *
   * @param name stopwatch name
   * @return this instance, so that registrations can be chained
   */
  def addLocal(name: String): this.type =
    register(name, new LocalStopwatch(name))

  /**
   * Registers a new [[DistributedStopwatch]] under the given name.
   * Fails with an `IllegalArgumentException` (via `require`) if the name is already taken.
   *
   * @param name stopwatch name
   * @return this instance, so that registrations can be chained
   */
  def addDistributed(name: String): this.type =
    register(name, new DistributedStopwatch(sc, name))

  /**
   * Looks up a registered stopwatch by name.
   *
   * @param name stopwatch name
   */
  def apply(name: String): Stopwatch = stopwatches(name)

  /** Renders every stopwatch as `name: <elapsed>ms`, ordered by name, one entry per line. */
  override def toString: String = {
    val ordered = stopwatches.values.toArray.sortBy(_.name)
    val entries = for (sw <- ordered) yield s" ${sw.name}: ${sw.elapsed()}ms"
    entries.mkString("{\n", ",\n", "\n}")
  }

  // Shared registration path. `create` is by-name so that the stopwatch is only constructed
  // after the uniqueness check has passed.
  private def register(name: String, create: => Stopwatch): this.type = {
    require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.")
    stopwatches.update(name, create)
    this
  }
}
|
|
@ -0,0 +1,109 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.util
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
|
||||
/**
 * Tests for [[LocalStopwatch]], [[DistributedStopwatch]] and [[MultiStopwatch]].
 *
 * Measured times are only checked against lower bounds. `Thread.sleep` pauses for roughly the
 * requested time at best, and scheduling delays or GC pauses can stretch a session arbitrarily,
 * so any fixed upper bound on a measured duration makes the suite flaky on a loaded machine.
 */
class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext {

  /**
   * Exercises the full start/stop life cycle of a stopwatch named "sw" on the driver.
   * The stopwatch is left running when this method returns.
   */
  private def testStopwatchOnDriver(sw: Stopwatch): Unit = {
    assert(sw.name === "sw")
    assert(sw.elapsed() === 0L)
    assert(!sw.isRunning)
    // Stopping a stopwatch that was never started must fail.
    intercept[AssertionError] {
      sw.stop()
    }
    sw.start()
    Thread.sleep(50)
    val duration = sw.stop()
    assert(duration >= 50)
    val elapsed = sw.elapsed()
    assert(elapsed === duration)
    sw.start()
    Thread.sleep(50)
    val duration2 = sw.stop()
    assert(duration2 >= 50)
    // The total is the sum of the individual sessions.
    val elapsed2 = sw.elapsed()
    assert(elapsed2 === duration + duration2)
    sw.start()
    assert(sw.isRunning)
    // Starting a stopwatch that is already running must fail.
    intercept[AssertionError] {
      sw.start()
    }
  }

  test("LocalStopwatch") {
    val sw = new LocalStopwatch("sw")
    testStopwatchOnDriver(sw)
  }

  test("DistributedStopwatch on driver") {
    val sw = new DistributedStopwatch(sc, "sw")
    testStopwatchOnDriver(sw)
  }

  test("DistributedStopwatch on executors") {
    val sw = new DistributedStopwatch(sc, "sw")
    val rdd = sc.parallelize(0 until 4, 4)
    rdd.foreach { i =>
      sw.start()
      Thread.sleep(50)
      sw.stop()
    }
    // The driver-side instance itself was never started.
    assert(!sw.isRunning)
    // Four elements, each contributing a session of at least 50ms.
    val elapsed = sw.elapsed()
    assert(elapsed >= 200)
  }

  test("MultiStopwatch") {
    val sw = new MultiStopwatch(sc)
      .addLocal("local")
      .addDistributed("spark")
    assert(sw("local").name === "local")
    assert(sw("spark").name === "spark")
    intercept[NoSuchElementException] {
      sw("some")
    }
    assert(sw.toString === "{\n local: 0ms,\n spark: 0ms\n}")
    sw("local").start()
    sw("spark").start()
    Thread.sleep(50)
    sw("local").stop()
    Thread.sleep(50)
    sw("spark").stop()
    val localElapsed = sw("local").elapsed()
    val sparkElapsed = sw("spark").elapsed()
    assert(localElapsed >= 50)
    assert(sparkElapsed >= 100)
    assert(sw.toString ===
      s"{\n local: ${localElapsed}ms,\n spark: ${sparkElapsed}ms\n}")
    val rdd = sc.parallelize(0 until 4, 4)
    rdd.foreach { i =>
      sw("local").start()
      sw("spark").start()
      Thread.sleep(50)
      sw("spark").stop()
      sw("local").stop()
    }
    // Sessions recorded inside the job do not change the driver's local stopwatch ...
    val localElapsed2 = sw("local").elapsed()
    assert(localElapsed2 === localElapsed)
    // ... while the distributed stopwatch gains at least 4 * 50ms on top of the earlier 100ms.
    val sparkElapsed2 = sw("spark").elapsed()
    assert(sparkElapsed2 >= 300)
  }
}
|
Loading…
Reference in a new issue