[SPARK-9018] [MLLIB] add stopwatches

Add stopwatches for easy instrumentation of MLlib algorithms. This is based on the `TimeTracker` used in decision trees. The distributed version uses Spark accumulator. jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #7415 from mengxr/SPARK-9018 and squashes the following commits:

40b4347 [Xiangrui Meng] == -> ===
c477745 [Xiangrui Meng] address Joseph's comments
f981a49 [Xiangrui Meng] add stopwatches
This commit is contained in:
Xiangrui Meng 2015-07-15 21:02:42 -07:00
parent 6960a7938c
commit 73d92b00b9
2 changed files with 260 additions and 0 deletions

View file

@ -0,0 +1,151 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.util
import scala.collection.mutable
import org.apache.spark.{Accumulator, SparkContext}
/**
* Abstract class for stopwatches.
*/
private[spark] abstract class Stopwatch extends Serializable {
@transient private var running: Boolean = false
private var startTime: Long = _
/**
* Name of the stopwatch.
*/
val name: String
/**
* Starts the stopwatch.
* Throws an exception if the stopwatch is already running.
*/
def start(): Unit = {
assume(!running, "start() called but the stopwatch is already running.")
running = true
startTime = now
}
/**
* Stops the stopwatch and returns the duration of the last session in milliseconds.
* Throws an exception if the stopwatch is not running.
*/
def stop(): Long = {
assume(running, "stop() called but the stopwatch is not running.")
val duration = now - startTime
add(duration)
running = false
duration
}
/**
* Checks whether the stopwatch is running.
*/
def isRunning: Boolean = running
/**
* Returns total elapsed time in milliseconds, not counting the current session if the stopwatch
* is running.
*/
def elapsed(): Long
/**
* Gets the current time in milliseconds.
*/
protected def now: Long = System.currentTimeMillis()
/**
* Adds input duration to total elapsed time.
*/
protected def add(duration: Long): Unit
}
/**
* A local [[Stopwatch]].
*/
private[spark] class LocalStopwatch(override val name: String) extends Stopwatch {
private var elapsedTime: Long = 0L
override def elapsed(): Long = elapsedTime
override protected def add(duration: Long): Unit = {
elapsedTime += duration
}
}
/**
* A distributed [[Stopwatch]] using Spark accumulator.
* @param sc SparkContext
*/
private[spark] class DistributedStopwatch(
sc: SparkContext,
override val name: String) extends Stopwatch {
private val elapsedTime: Accumulator[Long] = sc.accumulator(0L, s"DistributedStopwatch($name)")
override def elapsed(): Long = elapsedTime.value
override protected def add(duration: Long): Unit = {
elapsedTime += duration
}
}
/**
* A multiple stopwatch that contains local and distributed stopwatches.
* @param sc SparkContext
*/
private[spark] class MultiStopwatch(@transient private val sc: SparkContext) extends Serializable {
private val stopwatches: mutable.Map[String, Stopwatch] = mutable.Map.empty
/**
* Adds a local stopwatch.
* @param name stopwatch name
*/
def addLocal(name: String): this.type = {
require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.")
stopwatches(name) = new LocalStopwatch(name)
this
}
/**
* Adds a distributed stopwatch.
* @param name stopwatch name
*/
def addDistributed(name: String): this.type = {
require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.")
stopwatches(name) = new DistributedStopwatch(sc, name)
this
}
/**
* Gets a stopwatch.
* @param name stopwatch name
*/
def apply(name: String): Stopwatch = stopwatches(name)
override def toString: String = {
stopwatches.values.toArray.sortBy(_.name)
.map(c => s" ${c.name}: ${c.elapsed()}ms")
.mkString("{\n", ",\n", "\n}")
}
}

View file

@ -0,0 +1,109 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.util
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext {
private def testStopwatchOnDriver(sw: Stopwatch): Unit = {
assert(sw.name === "sw")
assert(sw.elapsed() === 0L)
assert(!sw.isRunning)
intercept[AssertionError] {
sw.stop()
}
sw.start()
Thread.sleep(50)
val duration = sw.stop()
assert(duration >= 50 && duration < 100) // using a loose upper bound
val elapsed = sw.elapsed()
assert(elapsed === duration)
sw.start()
Thread.sleep(50)
val duration2 = sw.stop()
assert(duration2 >= 50 && duration2 < 100)
val elapsed2 = sw.elapsed()
assert(elapsed2 === duration + duration2)
sw.start()
assert(sw.isRunning)
intercept[AssertionError] {
sw.start()
}
}
test("LocalStopwatch") {
val sw = new LocalStopwatch("sw")
testStopwatchOnDriver(sw)
}
test("DistributedStopwatch on driver") {
val sw = new DistributedStopwatch(sc, "sw")
testStopwatchOnDriver(sw)
}
test("DistributedStopwatch on executors") {
val sw = new DistributedStopwatch(sc, "sw")
val rdd = sc.parallelize(0 until 4, 4)
rdd.foreach { i =>
sw.start()
Thread.sleep(50)
sw.stop()
}
assert(!sw.isRunning)
val elapsed = sw.elapsed()
assert(elapsed >= 200 && elapsed < 400) // using a loose upper bound
}
test("MultiStopwatch") {
val sw = new MultiStopwatch(sc)
.addLocal("local")
.addDistributed("spark")
assert(sw("local").name === "local")
assert(sw("spark").name === "spark")
intercept[NoSuchElementException] {
sw("some")
}
assert(sw.toString === "{\n local: 0ms,\n spark: 0ms\n}")
sw("local").start()
sw("spark").start()
Thread.sleep(50)
sw("local").stop()
Thread.sleep(50)
sw("spark").stop()
val localElapsed = sw("local").elapsed()
val sparkElapsed = sw("spark").elapsed()
assert(localElapsed >= 50 && localElapsed < 100)
assert(sparkElapsed >= 100 && sparkElapsed < 200)
assert(sw.toString ===
s"{\n local: ${localElapsed}ms,\n spark: ${sparkElapsed}ms\n}")
val rdd = sc.parallelize(0 until 4, 4)
rdd.foreach { i =>
sw("local").start()
sw("spark").start()
Thread.sleep(50)
sw("spark").stop()
sw("local").stop()
}
val localElapsed2 = sw("local").elapsed()
assert(localElapsed2 === localElapsed)
val sparkElapsed2 = sw("spark").elapsed()
assert(sparkElapsed2 >= 300 && sparkElapsed2 < 600)
}
}