diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 428f019b02..979d178c35 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1030,28 +1030,40 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Support function for API backtraces. + * Set the thread-local property for overriding the call sites + * of actions and RDDs. */ - def setCallSite(site: String) { - setLocalProperty("externalCallSite", site) + def setCallSite(shortCallSite: String) { + setLocalProperty(CallSite.SHORT_FORM, shortCallSite) } /** - * Support function for API backtraces. + * Set the thread-local property for overriding the call sites + * of actions and RDDs. + */ + private[spark] def setCallSite(callSite: CallSite) { + setLocalProperty(CallSite.SHORT_FORM, callSite.shortForm) + setLocalProperty(CallSite.LONG_FORM, callSite.longForm) + } + + /** + * Clear the thread-local property for overriding the call sites + * of actions and RDDs. */ def clearCallSite() { - setLocalProperty("externalCallSite", null) + setLocalProperty(CallSite.SHORT_FORM, null) + setLocalProperty(CallSite.LONG_FORM, null) } /** * Capture the current user callsite and return a formatted version for printing. If the user - * has overridden the call site, this will return the user's version. + * has overridden the call site using `setCallSite()`, this will return the user's version. */ private[spark] def getCallSite(): CallSite = { - Option(getLocalProperty("externalCallSite")) match { - case Some(callSite) => CallSite(callSite, longForm = "") - case None => Utils.getCallSite - } + Option(getLocalProperty(CallSite.SHORT_FORM)).map { case shortCallSite => + val longCallSite = Option(getLocalProperty(CallSite.LONG_FORM)).getOrElse("") + CallSite(shortCallSite, longCallSite) + }.getOrElse(Utils.getCallSite()) } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a9b905b0d1..0e90caa5c9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import java.util.Random +import java.util.{Properties, Random} import scala.collection.{mutable, Map} import scala.collection.mutable.ArrayBuffer @@ -41,7 +41,7 @@ import org.apache.spark.partial.CountEvaluator import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{BoundedPriorityQueue, Utils} +import org.apache.spark.util.{BoundedPriorityQueue, Utils, CallSite} import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils} @@ -1224,7 +1224,8 @@ abstract class RDD[T: ClassTag]( private var storageLevel: StorageLevel = StorageLevel.NONE /** User code that created this RDD (e.g. `textFile`, `parallelize`). 
*/ - @transient private[spark] val creationSite = Utils.getCallSite + @transient private[spark] val creationSite = sc.getCallSite() + private[spark] def getCreationSite: String = Option(creationSite).map(_.shortForm).getOrElse("") private[spark] def elementClassTag: ClassTag[T] = classTag[T] diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index ed06384432..2755887fee 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -49,6 +49,11 @@ import org.apache.spark.serializer.{DeserializationStream, SerializationStream, /** CallSite represents a place in user code. It can have a short and a long form. */ private[spark] case class CallSite(shortForm: String, longForm: String) +private[spark] object CallSite { + val SHORT_FORM = "callSite.short" + val LONG_FORM = "callSite.long" +} + /** * Various utility methods used by Spark. */ @@ -859,18 +864,26 @@ private[spark] object Utils extends Logging { } } - /** - * A regular expression to match classes of the "core" Spark API that we want to skip when - * finding the call site of a method. - */ - private val SPARK_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r + /** Default filtering function for finding call sites using `getCallSite`. */ + private def coreExclusionFunction(className: String): Boolean = { + // A regular expression to match classes of the "core" Spark API that we want to skip when + // finding the call site of a method. + val SPARK_CORE_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r + val SCALA_CLASS_REGEX = """^scala""".r + val isSparkCoreClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined + val isScalaClass = SCALA_CLASS_REGEX.findFirstIn(className).isDefined + // If the class is a Spark internal class or a Scala class, then exclude. + isSparkCoreClass || isScalaClass + } /** * When called inside a class in the spark package, returns the name of the user code class * (outside the spark package) that called into Spark, as well as which Spark method they called. * This is used, for example, to tell users where in their code each RDD got created. + * + * @param skipClass Function that is used to exclude non-user-code classes. 
*/ - def getCallSite: CallSite = { + def getCallSite(skipClass: String => Boolean = coreExclusionFunction): CallSite = { val trace = Thread.currentThread.getStackTrace() .filterNot { ste:StackTraceElement => // When running under some profilers, the current stack trace might contain some bogus @@ -891,7 +904,7 @@ private[spark] object Utils extends Logging { for (el <- trace) { if (insideSpark) { - if (SPARK_CLASS_REGEX.findFirstIn(el.getClassName).isDefined) { + if (skipClass(el.getClassName)) { lastSparkMethod = if (el.getMethodName == "") { // Spark method is a constructor; get its class name el.getClassName.substring(el.getClassName.lastIndexOf('.') + 1) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index f63560dcb5..5a8eef1372 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -35,10 +35,9 @@ import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream._ -import org.apache.spark.streaming.receiver.{ActorSupervisorStrategy, ActorReceiver, Receiver} +import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} import org.apache.spark.streaming.scheduler._ import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} -import org.apache.spark.util.MetadataCleaner /** * Main entry point for Spark Streaming functionality. It provides methods used to create @@ -448,6 +447,7 @@ class StreamingContext private[streaming] ( throw new SparkException("StreamingContext has already been stopped") } validate() + sparkContext.setCallSite(DStream.getCreationSite()) scheduler.start() state = Started } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index e05db236ad..65f7ccd318 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -23,6 +23,7 @@ import java.io.{IOException, ObjectInputStream, ObjectOutputStream} import scala.deprecated import scala.collection.mutable.HashMap import scala.reflect.ClassTag +import scala.util.matching.Regex import org.apache.spark.{Logging, SparkException} import org.apache.spark.rdd.{BlockRDD, RDD} @@ -30,7 +31,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.scheduler.Job -import org.apache.spark.util.MetadataCleaner +import org.apache.spark.util.{CallSite, MetadataCleaner} /** * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous @@ -106,6 +107,9 @@ abstract class DStream[T: ClassTag] ( /** Return the StreamingContext associated with this DStream */ def context = ssc + /* Set the creation call site */ + private[streaming] val creationSite = DStream.getCreationSite() + /** Persist the RDDs of this DStream with the given storage level */ def persist(level: StorageLevel): DStream[T] = { if (this.isInitialized) { @@ -272,43 +276,41 @@ abstract class DStream[T: ClassTag] ( } /** - * Retrieve a precomputed RDD of this DStream, or computes the RDD. This is an internal - * method that should not be called directly. 
+ * Get the RDD corresponding to the given time; either retrieve it from cache + * or compute-and-cache it. */ private[streaming] def getOrCompute(time: Time): Option[RDD[T]] = { - // If this DStream was not initialized (i.e., zeroTime not set), then do it - // If RDD was already generated, then retrieve it from HashMap - generatedRDDs.get(time) match { + // If RDD was already generated, then retrieve it from HashMap, + // or else compute the RDD + generatedRDDs.get(time).orElse { + // Compute the RDD if the time is valid (e.g. a correct time in a sliding window) + // for RDD generation, else generate nothing. + if (isTimeValid(time)) { + // Set the thread-local property for call sites to this DStream's creation site + // such that RDDs generated by compute get that as their creation site. + // Note that this `getOrCompute` may get called from another DStream which may have + // set its own call site. So we store its call site in a temporary variable, + // set this DStream's creation site, generate RDDs and then restore the previous call site. + val prevCallSite = ssc.sparkContext.getCallSite() + ssc.sparkContext.setCallSite(creationSite) + val rddOption = compute(time) + ssc.sparkContext.setCallSite(prevCallSite) - // If an RDD was already generated and is being reused, then - // probably all RDDs in this DStream will be reused and hence should be cached - case Some(oldRDD) => Some(oldRDD) - - // if RDD was not generated, and if the time is valid - // (based on sliding time of this DStream), then generate the RDD - case None => { - if (isTimeValid(time)) { - compute(time) match { - case Some(newRDD) => - if (storageLevel != StorageLevel.NONE) { - newRDD.persist(storageLevel) - logInfo("Persisting RDD " + newRDD.id + " for time " + - time + " to " + storageLevel + " at time " + time) - } - if (checkpointDuration != null && - (time - zeroTime).isMultipleOf(checkpointDuration)) { - newRDD.checkpoint() - logInfo("Marking RDD " + newRDD.id + " for time " + time + - " for checkpointing at time " + time) - } - generatedRDDs.put(time, newRDD) - Some(newRDD) - case None => - None + rddOption.foreach { case newRDD => + // Register the generated RDD for caching and checkpointing + if (storageLevel != StorageLevel.NONE) { + newRDD.persist(storageLevel) + logDebug(s"Persisting RDD ${newRDD.id} for time $time to $storageLevel") } - } else { - None + if (checkpointDuration != null && (time - zeroTime).isMultipleOf(checkpointDuration)) { + newRDD.checkpoint() + logInfo(s"Marking RDD ${newRDD.id} for time $time for checkpointing") + } + generatedRDDs.put(time, newRDD) } + rddOption + } else { + None } } } @@ -799,3 +801,29 @@ abstract class DStream[T: ClassTag] ( this } } + +private[streaming] object DStream { + + /** Get the creation site of a DStream from the stack trace of when the DStream is created.
*/ + def getCreationSite(): CallSite = { + val SPARK_CLASS_REGEX = """^org\.apache\.spark""".r + val SPARK_STREAMING_TESTCLASS_REGEX = """^org\.apache\.spark\.streaming\.test""".r + val SPARK_EXAMPLES_CLASS_REGEX = """^org\.apache\.spark\.examples""".r + val SCALA_CLASS_REGEX = """^scala""".r + + /** Filtering function that excludes non-user classes for a streaming application */ + def streamingExclusionFunction(className: String): Boolean = { + def doesMatch(r: Regex) = r.findFirstIn(className).isDefined + val isSparkClass = doesMatch(SPARK_CLASS_REGEX) + val isSparkExampleClass = doesMatch(SPARK_EXAMPLES_CLASS_REGEX) + val isSparkStreamingTestClass = doesMatch(SPARK_STREAMING_TESTCLASS_REGEX) + val isScalaClass = doesMatch(SCALA_CLASS_REGEX) + + // If the class is a Spark example class or a streaming test class, then it is considered + // a streaming application class and is not excluded. Otherwise, exclude any Spark or + // Scala class, as everything else is assumed to be a streaming application class. + (isSparkClass || isScalaClass) && !isSparkExampleClass && !isSparkStreamingTestClass + } + org.apache.spark.util.Utils.getCallSite(streamingExclusionFunction) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index a3cabd6be0..ebf83748ff 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -19,13 +19,16 @@ package org.apache.spark.streaming import java.util.concurrent.atomic.AtomicInteger +import scala.language.postfixOps + import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.util.{MetadataCleaner, Utils} -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.apache.spark.util.Utils +import org.scalatest.{Assertions, BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Timeouts +import org.scalatest.concurrent.Eventually._ import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ @@ -257,6 +260,10 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w assert(exception.getMessage.contains("transform"), "Expected exception not thrown") } + test("DStream and generated RDD creation sites") { + testPackage.test() + } + def addInputStream(s: StreamingContext): DStream[Int] = { val input = (1 to 100).map(i => (1 to i)) val inputStream = new TestInputStream(s, input, 1) @@ -293,3 +300,37 @@ class TestReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging object TestReceiver { val counter = new AtomicInteger(1) } + +/** Streaming application for testing DStream and RDD creation sites */ +package object testPackage extends Assertions { + def test() { + val conf = new SparkConf().setMaster("local").setAppName("CreationSite test") + val ssc = new StreamingContext(conf, Milliseconds(100)) + try { + val inputStream = ssc.receiverStream(new TestReceiver) + + // Verify creation site of DStream + val creationSite = inputStream.creationSite + assert(creationSite.shortForm.contains("receiverStream") && + creationSite.shortForm.contains("StreamingContextSuite") + ) + assert(creationSite.longForm.contains("testPackage")) + + // Verify creation site of generated RDDs + 
var rddGenerated = false + var rddCreationSiteCorrect = true + + inputStream.foreachRDD { rdd => + rddCreationSiteCorrect = rdd.creationSite == creationSite + rddGenerated = true + } + ssc.start() + + eventually(timeout(10000 millis), interval(10 millis)) { + assert(rddGenerated && rddCreationSiteCorrect, "RDD creation site was not correct") + } + } finally { + ssc.stop() + } + } +}
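
For reference, a hedged usage sketch of the public call-site override this patch builds on. The object name, master/app settings, and the "loading input" label are illustrative and not part of the patch; the short form passed to setCallSite() becomes the creationSite of RDDs created while it is set, which is what the web UI shows for the corresponding stages.

import org.apache.spark.{SparkConf, SparkContext}

object CallSiteExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("callsite-demo"))
    // Override the automatically captured call site; RDDs created while this is set
    // record "loading input" as their short-form creation site.
    sc.setCallSite("loading input")
    val data = sc.parallelize(1 to 100)
    data.count()
    // Clear the override so later RDDs capture their real call site again.
    sc.clearCallSite()
    sc.stop()
  }
}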
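
The save-and-restore dance in DStream.getOrCompute can be summarized by a small helper like the one below. This is a standalone sketch under the patch's property names, not code from the patch: CallSiteUtil and withCallSite are hypothetical, and a try/finally is added so the previous value is restored even if the body throws.

import org.apache.spark.SparkContext

object CallSiteUtil {
  // Remember the caller's short-form call site, install a new one while `body` runs,
  // then put the previous value back. "callSite.short" matches CallSite.SHORT_FORM.
  def withCallSite[T](sc: SparkContext, shortForm: String)(body: => T): T = {
    val previous = sc.getLocalProperty("callSite.short") // may be null if nothing was set
    sc.setCallSite(shortForm)
    try {
      body
    } finally {
      // Setting the property back to null effectively clears the override again.
      sc.setLocalProperty("callSite.short", previous)
    }
  }
}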