[SPARK-27323][CORE][SQL][STREAMING] Use Single-Abstract-Method support in Scala 2.12 to simplify code

## What changes were proposed in this pull request?

Use Single Abstract Method (SAM) syntax where possible, plus minor related cleanup; see comments below. No logic should change here.
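
For context, a minimal sketch of the pattern (not taken from the diff itself): under Scala 2.12, a function literal can be used anywhere a Java interface with a single abstract method, such as `Runnable`, is expected, so the explicit anonymous class can be dropped.

```scala
object SamSketch {
  def main(args: Array[String]): Unit = {
    // Before: an explicit anonymous class implementing the single abstract method.
    val before = new Thread(new Runnable {
      override def run(): Unit = println("work (anonymous class)")
    })
    // After: Scala 2.12 converts the function literal to a Runnable automatically.
    val after = new Thread(() => println("work (lambda)"))
    before.start(); after.start()
    before.join(); after.join()
  }
}
```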

## How was this patch tested?

Existing tests.

Closes #24241 from srowen/SPARK-27323.

Authored-by: Sean Owen <sean.owen@databricks.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
Sean Owen 2019-04-02 07:37:05 -07:00 committed by Dongjoon Hyun
parent d575a453db
commit d4420b455a
78 changed files with 542 additions and 848 deletions

View file

@ -19,7 +19,7 @@ package org.apache.spark
import java.util.{Timer, TimerTask}
import java.util.concurrent.ConcurrentHashMap
import java.util.function.{Consumer, Function}
import java.util.function.Consumer
import scala.collection.mutable.ArrayBuffer
@ -202,10 +202,8 @@ private[spark] class BarrierCoordinator(
case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) =>
// Get or init the ContextBarrierState correspond to the stage attempt.
val barrierId = ContextBarrierId(stageId, stageAttemptId)
states.computeIfAbsent(barrierId, new Function[ContextBarrierId, ContextBarrierState] {
override def apply(key: ContextBarrierId): ContextBarrierState =
new ContextBarrierState(key, numTasks)
})
states.computeIfAbsent(barrierId,
(key: ContextBarrierId) => new ContextBarrierState(key, numTasks))
val barrierState = states.get(barrierId)
barrierState.handleRequest(context, request)
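
For reference, a standalone sketch of the same `computeIfAbsent` conversion, with made-up names rather than Spark's: the function literal is SAM-converted to `java.util.function.Function`, so the explicit `new Function { ... }` wrapper above is no longer needed.

```scala
import java.util.concurrent.ConcurrentHashMap

object ComputeIfAbsentSketch {
  def main(args: Array[String]): Unit = {
    val states = new ConcurrentHashMap[String, Int]()
    // The lambda is converted to java.util.function.Function[String, Int].
    val value = states.computeIfAbsent("stage-1", (key: String) => key.length)
    println(value) // 7
  }
}
```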

View file

@ -123,9 +123,8 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
cleaningThread.setDaemon(true)
cleaningThread.setName("Spark Context Cleaner")
cleaningThread.start()
periodicGCService.scheduleAtFixedRate(new Runnable {
override def run(): Unit = System.gc()
}, periodicGCInterval, periodicGCInterval, TimeUnit.SECONDS)
periodicGCService.scheduleAtFixedRate(() => System.gc(),
periodicGCInterval, periodicGCInterval, TimeUnit.SECONDS)
}
/**

View file

@ -98,11 +98,9 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock)
private val killExecutorThread = ThreadUtils.newDaemonSingleThreadExecutor("kill-executor-thread")
override def onStart(): Unit = {
timeoutCheckingTask = eventLoopThread.scheduleAtFixedRate(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
Option(self).foreach(_.ask[Boolean](ExpireDeadHosts))
}
}, 0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS)
timeoutCheckingTask = eventLoopThread.scheduleAtFixedRate(
() => Utils.tryLogNonFatalError { Option(self).foreach(_.ask[Boolean](ExpireDeadHosts)) },
0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS)
}
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {

View file

@ -62,9 +62,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria
@transient private lazy val reader: ConfigReader = {
val _reader = new ConfigReader(new SparkConfigProvider(settings))
_reader.bindEnv(new ConfigProvider {
override def get(key: String): Option[String] = Option(getenv(key))
})
_reader.bindEnv((key: String) => Option(getenv(key)))
_reader
}
@ -392,7 +390,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria
/** Get an optional value, applying variable substitution. */
private[spark] def getWithSubstitution(key: String): Option[String] = {
getOption(key).map(reader.substitute(_))
getOption(key).map(reader.substitute)
}
/** Get all parameters as a list of pairs */
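
The `reader.substitute` change above is one of the minor cleanups: a method value can be passed directly where a function is expected, without wrapping it in a placeholder lambda. A small sketch with an illustrative `substitute` method (not Spark's `ConfigReader`):

```scala
object MethodValueSketch {
  private def substitute(s: String): String = s.trim.toLowerCase

  def main(args: Array[String]): Unit = {
    val conf = Map("spark.app.name" -> "  Demo  ")
    // Equivalent results; the explicit placeholder lambda is redundant.
    val wrapped = conf.get("spark.app.name").map(substitute(_))
    val direct = conf.get("spark.app.name").map(substitute)
    println(wrapped == direct) // true
  }
}
```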

View file

@ -60,11 +60,7 @@ object PythonRunner {
.javaAddress(localhost)
.callbackClient(py4j.GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
.build()
val thread = new Thread(new Runnable() {
override def run(): Unit = Utils.logUncaughtExceptions {
gatewayServer.start()
}
})
val thread = new Thread(() => Utils.logUncaughtExceptions { gatewayServer.start() })
thread.setName("py4j-gateway-init")
thread.setDaemon(true)
thread.start()

View file

@ -20,7 +20,7 @@ package org.apache.spark.deploy
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, File, IOException}
import java.security.PrivilegedExceptionAction
import java.text.DateFormat
import java.util.{Arrays, Comparator, Date, Locale}
import java.util.{Arrays, Date, Locale}
import scala.collection.JavaConverters._
import scala.collection.immutable.Map
@ -270,11 +270,8 @@ private[spark] class SparkHadoopUtil extends Logging {
name.startsWith(prefix) && !name.endsWith(exclusionSuffix)
}
})
Arrays.sort(fileStatuses, new Comparator[FileStatus] {
override def compare(o1: FileStatus, o2: FileStatus): Int = {
Longs.compare(o1.getModificationTime, o2.getModificationTime)
}
})
Arrays.sort(fileStatuses, (o1: FileStatus, o2: FileStatus) =>
Longs.compare(o1.getModificationTime, o2.getModificationTime))
fileStatuses
} catch {
case NonFatal(e) =>
@ -465,7 +462,7 @@ private[spark] object SparkHadoopUtil {
// scalastyle:on line.size.limit
def createNonECFile(fs: FileSystem, path: Path): FSDataOutputStream = {
try {
// Use reflection as this uses apis only avialable in hadoop 3
// Use reflection as this uses APIs only available in Hadoop 3
val builderMethod = fs.getClass().getMethod("createFile", classOf[Path])
// the builder api does not resolve relative paths, nor does it create parent dirs, while
// the old api does.
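
The `Arrays.sort` change above is a typical `Comparator` conversion. A self-contained sketch with a made-up status class, using `java.lang.Long.compare` instead of Guava's `Longs.compare`:

```scala
import java.util.Arrays

object ComparatorSketch {
  final case class FakeStatus(modificationTime: Long)

  def main(args: Array[String]): Unit = {
    val statuses = Array(FakeStatus(30L), FakeStatus(10L), FakeStatus(20L))
    // The lambda is SAM-converted to java.util.Comparator[FakeStatus].
    Arrays.sort(statuses, (o1: FakeStatus, o2: FakeStatus) =>
      java.lang.Long.compare(o1.modificationTime, o2.modificationTime))
    println(statuses.map(_.modificationTime).mkString(", ")) // 10, 20, 30
  }
}
```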

View file

@ -186,13 +186,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
* Return a runnable that performs the given operation on the event logs.
* This operation is expected to be executed periodically.
*/
private def getRunner(operateFun: () => Unit): Runnable = {
new Runnable() {
override def run(): Unit = Utils.tryOrExit {
operateFun()
}
}
}
private def getRunner(operateFun: () => Unit): Runnable =
() => Utils.tryOrExit { operateFun() }
/**
* Fixed size thread pool to fetch and parse log files.
@ -221,29 +216,25 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
// Cannot probe anything while the FS is in safe mode, so spawn a new thread that will wait
// for the FS to leave safe mode before enabling polling. This allows the main history server
// UI to be shown (so that the user can see the HDFS status).
val initThread = new Thread(new Runnable() {
override def run(): Unit = {
try {
while (isFsInSafeMode()) {
logInfo("HDFS is still in safe mode. Waiting...")
val deadline = clock.getTimeMillis() +
TimeUnit.SECONDS.toMillis(SAFEMODE_CHECK_INTERVAL_S)
clock.waitTillTime(deadline)
}
startPolling()
} catch {
case _: InterruptedException =>
val initThread = new Thread(() => {
try {
while (isFsInSafeMode()) {
logInfo("HDFS is still in safe mode. Waiting...")
val deadline = clock.getTimeMillis() +
TimeUnit.SECONDS.toMillis(SAFEMODE_CHECK_INTERVAL_S)
clock.waitTillTime(deadline)
}
startPolling()
} catch {
case _: InterruptedException =>
}
})
initThread.setDaemon(true)
initThread.setName(s"${getClass().getSimpleName()}-init")
initThread.setUncaughtExceptionHandler(errorHandler.getOrElse(
new Thread.UncaughtExceptionHandler() {
override def uncaughtException(t: Thread, e: Throwable): Unit = {
logError("Error initializing FsHistoryProvider.", e)
System.exit(1)
}
(_: Thread, e: Throwable) => {
logError("Error initializing FsHistoryProvider.", e)
System.exit(1)
}))
initThread.start()
initThread
@ -517,9 +508,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
val tasks = updated.flatMap { entry =>
try {
val task: Future[Unit] = replayExecutor.submit(new Runnable {
override def run(): Unit = mergeApplicationListing(entry, newLastScanTime, true)
}, Unit)
val task: Future[Unit] = replayExecutor.submit(
() => mergeApplicationListing(entry, newLastScanTime, true))
Some(task -> entry.getPath)
} catch {
// let the iteration over the updated entries break, since an exception on

View file

@ -150,11 +150,9 @@ private[deploy] class Master(
logInfo(s"Spark Master is acting as a reverse proxy. Master, Workers and " +
s"Applications UIs are available at $masterWebUiUrl")
}
checkForWorkerTimeOutTask = forwardMessageThread.scheduleAtFixedRate(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
self.send(CheckForWorkerTimeOut)
}
}, 0, workerTimeoutMs, TimeUnit.MILLISECONDS)
checkForWorkerTimeOutTask = forwardMessageThread.scheduleAtFixedRate(
() => Utils.tryLogNonFatalError { self.send(CheckForWorkerTimeOut) },
0, workerTimeoutMs, TimeUnit.MILLISECONDS)
if (restServerEnabled) {
val port = conf.get(MASTER_REST_SERVER_PORT)

View file

@ -325,11 +325,9 @@ private[deploy] class Worker(
if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) {
registrationRetryTimer.foreach(_.cancel(true))
registrationRetryTimer = Some(
forwardMessageScheduler.scheduleAtFixedRate(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
self.send(ReregisterWithMaster)
}
}, PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS,
forwardMessageScheduler.scheduleAtFixedRate(
() => Utils.tryLogNonFatalError { self.send(ReregisterWithMaster) },
PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS,
PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS,
TimeUnit.SECONDS))
}
@ -341,7 +339,7 @@ private[deploy] class Worker(
}
/**
* Cancel last registeration retry, or do nothing if no retry
* Cancel last registration retry, or do nothing if no retry
*/
private def cancelLastRegistrationRetry(): Unit = {
if (registerMasterFutures != null) {
@ -361,11 +359,7 @@ private[deploy] class Worker(
registerMasterFutures = tryRegisterAllMasters()
connectionAttemptCount = 0
registrationRetryTimer = Some(forwardMessageScheduler.scheduleAtFixedRate(
new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
Option(self).foreach(_.send(ReregisterWithMaster))
}
},
() => Utils.tryLogNonFatalError { Option(self).foreach(_.send(ReregisterWithMaster)) },
INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS,
INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS,
TimeUnit.SECONDS))
@ -407,19 +401,15 @@ private[deploy] class Worker(
}
registered = true
changeMaster(masterRef, masterWebUiUrl, masterAddress)
forwardMessageScheduler.scheduleAtFixedRate(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
self.send(SendHeartbeat)
}
}, 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS)
forwardMessageScheduler.scheduleAtFixedRate(
() => Utils.tryLogNonFatalError { self.send(SendHeartbeat) },
0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS)
if (CLEANUP_ENABLED) {
logInfo(
s"Worker cleanup enabled; old application directories will be deleted in: $workDir")
forwardMessageScheduler.scheduleAtFixedRate(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
self.send(WorkDirCleanup)
}
}, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS)
forwardMessageScheduler.scheduleAtFixedRate(
() => Utils.tryLogNonFatalError { self.send(WorkDirCleanup) },
CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS)
}
val execs = executors.values.map { e =>
@ -568,7 +558,7 @@ private[deploy] class Worker(
}
}
case executorStateChanged @ ExecutorStateChanged(appId, execId, state, message, exitStatus) =>
case executorStateChanged: ExecutorStateChanged =>
handleExecutorStateChanged(executorStateChanged)
case KillExecutor(masterUrl, appId, execId) =>
@ -632,7 +622,7 @@ private[deploy] class Worker(
override def onDisconnected(remoteAddress: RpcAddress): Unit = {
if (master.exists(_.address == remoteAddress) ||
masterAddressToConnect.exists(_ == remoteAddress)) {
masterAddressToConnect.contains(remoteAddress)) {
logInfo(s"$remoteAddress Disassociated !")
masterDisconnected()
}
@ -815,7 +805,7 @@ private[deploy] object Worker extends Logging {
val systemName = SYSTEM_NAME + workerNumber.map(_.toString).getOrElse("")
val securityMgr = new SecurityManager(conf)
val rpcEnv = RpcEnv.create(systemName, host, port, conf, securityMgr)
val masterAddresses = masterUrls.map(RpcAddress.fromSparkURL(_))
val masterAddresses = masterUrls.map(RpcAddress.fromSparkURL)
rpcEnv.setupEndpoint(ENDPOINT_NAME, new Worker(rpcEnv, webUiPort, cores, memory,
masterAddresses, ENDPOINT_NAME, workDir, conf, securityMgr))
rpcEnv

View file

@ -89,17 +89,14 @@ private[spark] class Executor(
}
// Start worker thread pool
// Use UninterruptibleThread to run tasks so that we can allow running codes without being
// interrupted by `Thread.interrupt()`. Some issues, such as KAFKA-1894, HADOOP-10622,
// will hang forever if some methods are interrupted.
private val threadPool = {
val threadFactory = new ThreadFactoryBuilder()
.setDaemon(true)
.setNameFormat("Executor task launch worker-%d")
.setThreadFactory(new ThreadFactory {
override def newThread(r: Runnable): Thread =
// Use UninterruptibleThread to run tasks so that we can allow running codes without being
// interrupted by `Thread.interrupt()`. Some issues, such as KAFKA-1894, HADOOP-10622,
// will hang forever if some methods are interrupted.
new UninterruptibleThread(r, "unused") // thread name will be set by ThreadFactoryBuilder
})
.setThreadFactory((r: Runnable) => new UninterruptibleThread(r, "unused"))
.build()
Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor]
}
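
The `setThreadFactory` change relies on `ThreadFactory` being a SAM interface (`newThread(Runnable): Thread`). A minimal sketch without Spark's `UninterruptibleThread` or Guava's `ThreadFactoryBuilder`:

```scala
import java.util.concurrent.{Executors, ThreadFactory, TimeUnit}

object ThreadFactorySketch {
  def main(args: Array[String]): Unit = {
    // The lambda replaces `new ThreadFactory { override def newThread(r: Runnable) = ... }`.
    val factory: ThreadFactory = (r: Runnable) => {
      val t = new Thread(r, "sketch-worker")
      t.setDaemon(true)
      t
    }
    val pool = Executors.newCachedThreadPool(factory)
    pool.execute(() => println(Thread.currentThread().getName)) // prints "sketch-worker"
    pool.shutdown()
    pool.awaitTermination(5, TimeUnit.SECONDS)
  }
}
```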

View file

@ -44,7 +44,7 @@ private[spark] abstract class LauncherBackend {
.map(_.toInt)
val secret = conf.getOption(LauncherProtocol.CONF_LAUNCHER_SECRET)
.orElse(sys.env.get(LauncherProtocol.ENV_LAUNCHER_SECRET))
if (port != None && secret != None) {
if (port.isDefined && secret.isDefined) {
val s = new Socket(InetAddress.getLoopbackAddress(), port.get)
connection = new BackendConnection(s)
connection.send(new Hello(secret.get, SPARK_VERSION))
@ -94,11 +94,8 @@ private[spark] abstract class LauncherBackend {
protected def onDisconnected() : Unit = { }
private def fireStopRequest(): Unit = {
val thread = LauncherBackend.threadFactory.newThread(new Runnable() {
override def run(): Unit = Utils.tryLogNonFatalError {
onStopRequest()
}
})
val thread = LauncherBackend.threadFactory.newThread(
() => Utils.tryLogNonFatalError { onStopRequest() })
thread.start()
}

View file

@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit
import scala.collection.mutable
import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry}
import com.codahale.metrics.{Metric, MetricRegistry}
import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.{SecurityManager, SparkConf}
@ -168,9 +168,7 @@ private[spark] class MetricsSystem private (
def removeSource(source: Source) {
sources -= source
val regName = buildRegistryName(source)
registry.removeMatching(new MetricFilter {
def matches(name: String, metric: Metric): Boolean = name.startsWith(regName)
})
registry.removeMatching((name: String, _: Metric) => name.startsWith(regName))
}
private def registerSources() {

View file

@ -21,7 +21,6 @@ import java.io.NotSerializableException
import java.util.Properties
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger
import java.util.function.BiFunction
import scala.annotation.tailrec
import scala.collection.Map
@ -370,9 +369,10 @@ private[spark] class DAGScheduler(
* 2. An RDD that depends on multiple barrier RDDs (eg. barrierRdd1.zip(barrierRdd2)).
*/
private def checkBarrierStageWithRDDChainPattern(rdd: RDD[_], numTasksInStage: Int): Unit = {
val predicate: RDD[_] => Boolean = (r =>
r.getNumPartitions == numTasksInStage && r.dependencies.filter(_.rdd.isBarrier()).size <= 1)
if (rdd.isBarrier() && !traverseParentRDDsWithinStage(rdd, predicate)) {
if (rdd.isBarrier() &&
!traverseParentRDDsWithinStage(rdd, (r: RDD[_]) =>
r.getNumPartitions == numTasksInStage &&
r.dependencies.count(_.rdd.isBarrier()) <= 1)) {
throw new BarrierJobUnsupportedRDDChainException
}
}
@ -692,7 +692,7 @@ private[spark] class DAGScheduler(
}
val jobId = nextJobId.getAndIncrement()
if (partitions.size == 0) {
if (partitions.isEmpty) {
val time = clock.getTimeMillis()
listenerBus.post(
SparkListenerJobStart(jobId, time, Seq[StageInfo](), properties))
@ -702,9 +702,9 @@ private[spark] class DAGScheduler(
return new JobWaiter[U](this, jobId, 0, resultHandler)
}
assert(partitions.size > 0)
assert(partitions.nonEmpty)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
val waiter = new JobWaiter[U](this, jobId, partitions.size, resultHandler)
eventProcessLoop.post(JobSubmitted(
jobId, rdd, func2, partitions.toArray, callSite, waiter,
SerializationUtils.clone(properties)))
@ -767,9 +767,8 @@ private[spark] class DAGScheduler(
callSite: CallSite,
timeout: Long,
properties: Properties): PartialResult[R] = {
val partitions = (0 until rdd.partitions.length).toArray
val jobId = nextJobId.getAndIncrement()
if (partitions.isEmpty) {
if (rdd.partitions.isEmpty) {
// Return immediately if the job is running 0 tasks
val time = clock.getTimeMillis()
listenerBus.post(SparkListenerJobStart(jobId, time, Seq[StageInfo](), properties))
@ -779,7 +778,8 @@ private[spark] class DAGScheduler(
val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
eventProcessLoop.post(JobSubmitted(
jobId, rdd, func2, partitions, callSite, listener, SerializationUtils.clone(properties)))
jobId, rdd, func2, rdd.partitions.indices.toArray, callSite, listener,
SerializationUtils.clone(properties)))
listener.awaitResult() // Will throw an exception if the job fails
}
@ -812,7 +812,9 @@ private[spark] class DAGScheduler(
// This makes it easier to avoid race conditions between the user code and the map output
// tracker that might result if we told the user the stage had finished, but then they queries
// the map output tracker and some node failures had caused the output statistics to be lost.
val waiter = new JobWaiter(this, jobId, 1, (i: Int, r: MapOutputStatistics) => callback(r))
val waiter = new JobWaiter[MapOutputStatistics](
this, jobId, 1,
(_: Int, r: MapOutputStatistics) => callback(r))
eventProcessLoop.post(MapStageSubmitted(
jobId, dependency, callSite, waiter, SerializationUtils.clone(properties)))
waiter
@ -870,7 +872,7 @@ private[spark] class DAGScheduler(
* the last fetch failure.
*/
private[scheduler] def resubmitFailedStages() {
if (failedStages.size > 0) {
if (failedStages.nonEmpty) {
// Failed stages may be removed by job cancellation, so failed might be empty even if
// the ResubmitFailedStages event has been scheduled.
logInfo("Resubmitting failed stages")
@ -982,9 +984,7 @@ private[spark] class DAGScheduler(
"than the total number of slots in the cluster currently.")
// If jobId doesn't exist in the map, Scala coverts its value null to 0: Int automatically.
val numCheckFailures = barrierJobIdToNumTasksCheckFailures.compute(jobId,
new BiFunction[Int, Int, Int] {
override def apply(key: Int, value: Int): Int = value + 1
})
(_: Int, value: Int) => value + 1)
if (numCheckFailures <= maxFailureNumTasksCheck) {
messageScheduler.schedule(
new Runnable {
@ -1227,7 +1227,7 @@ private[spark] class DAGScheduler(
return
}
if (tasks.size > 0) {
if (tasks.nonEmpty) {
logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " +
s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})")
taskScheduler.submitTasks(new TaskSet(
@ -1942,7 +1942,7 @@ private[spark] class DAGScheduler(
job: ActiveJob,
failureReason: String,
exception: Option[Throwable] = None): Unit = {
val error = new SparkException(failureReason, exception.getOrElse(null))
val error = new SparkException(failureReason, exception.orNull)
var ableToCancelStages = true
// Cancel all independent, running stages.

View file

@ -80,7 +80,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
logDebug("Fetching indirect task result for TID %s".format(tid))
scheduler.handleTaskGettingResult(taskSetManager, tid)
val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
if (!serializedTaskResult.isDefined) {
if (serializedTaskResult.isEmpty) {
/* We won't be able to get the task result if the machine that ran the task failed
* between when the task ended and when we tried to fetch the result, or if the
* block manager had to flush the result. */
@ -128,27 +128,25 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
serializedData: ByteBuffer) {
var reason : TaskFailedReason = UnknownReason
try {
getTaskResultExecutor.execute(new Runnable {
override def run(): Unit = Utils.logUncaughtExceptions {
val loader = Utils.getContextOrSparkClassLoader
try {
if (serializedData != null && serializedData.limit() > 0) {
reason = serializer.get().deserialize[TaskFailedReason](
serializedData, loader)
}
} catch {
case cnd: ClassNotFoundException =>
// Log an error but keep going here -- the task failed, so not catastrophic
// if we can't deserialize the reason.
logError(
"Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
case ex: Exception => // No-op
} finally {
// If there's an error while deserializing the TaskEndReason, this Runnable
// will die. Still tell the scheduler about the task failure, to avoid a hang
// where the scheduler thinks the task is still running.
scheduler.handleFailedTask(taskSetManager, tid, taskState, reason)
getTaskResultExecutor.execute(() => Utils.logUncaughtExceptions {
val loader = Utils.getContextOrSparkClassLoader
try {
if (serializedData != null && serializedData.limit() > 0) {
reason = serializer.get().deserialize[TaskFailedReason](
serializedData, loader)
}
} catch {
case _: ClassNotFoundException =>
// Log an error but keep going here -- the task failed, so not catastrophic
// if we can't deserialize the reason.
logError(
"Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
case _: Exception => // No-op
} finally {
// If there's an error while deserializing the TaskEndReason, this Runnable
// will die. Still tell the scheduler about the task failure, to avoid a hang
// where the scheduler thinks the task is still running.
scheduler.handleFailedTask(taskSetManager, tid, taskState, reason)
}
})
} catch {

View file

@ -53,7 +53,7 @@ import org.apache.spark.util.{AccumulatorV2, SystemClock, ThreadUtils, Utils}
* we are holding a lock on ourselves. This class is called from many threads, notably:
* * The DAGScheduler Event Loop
* * The RPCHandler threads, responding to status updates from Executors
* * Periodic revival of all offers from the CoarseGrainedSchedulerBackend, to accomodate delay
* * Periodic revival of all offers from the CoarseGrainedSchedulerBackend, to accommodate delay
* scheduling
* * task-result-getter threads
*/
@ -194,11 +194,9 @@ private[spark] class TaskSchedulerImpl(
if (!isLocal && conf.get(SPECULATION_ENABLED)) {
logInfo("Starting speculative execution thread")
speculationScheduler.scheduleWithFixedDelay(new Runnable {
override def run(): Unit = Utils.tryOrStopSparkContext(sc) {
checkSpeculatableTasks()
}
}, SPECULATION_INTERVAL_MS, SPECULATION_INTERVAL_MS, TimeUnit.MILLISECONDS)
speculationScheduler.scheduleWithFixedDelay(
() => Utils.tryOrStopSparkContext(sc) { checkSpeculatableTasks() },
SPECULATION_INTERVAL_MS, SPECULATION_INTERVAL_MS, TimeUnit.MILLISECONDS)
}
}
@ -373,7 +371,7 @@ private[spark] class TaskSchedulerImpl(
}
}
}
return launchedTask
launchedTask
}
/**
@ -527,7 +525,7 @@ private[spark] class TaskSchedulerImpl(
// TODO SPARK-24823 Cancel a job that contains barrier stage(s) if the barrier tasks don't get
// launched within a configured time.
if (tasks.size > 0) {
if (tasks.nonEmpty) {
hasLaunchedTask = true
}
return tasks

View file

@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.collection.mutable.{HashMap, HashSet}
import scala.concurrent.Future
import org.apache.hadoop.security.UserGroupInformation
@ -133,10 +133,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
// Periodically revive offers to allow delay scheduling to work
val reviveIntervalMs = conf.get(SCHEDULER_REVIVE_INTERVAL).getOrElse(1000L)
reviveThread.scheduleAtFixedRate(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
Option(self).foreach(_.send(ReviveOffers))
}
reviveThread.scheduleAtFixedRate(() => Utils.tryLogNonFatalError {
Option(self).foreach(_.send(ReviveOffers))
}, 0, reviveIntervalMs, TimeUnit.MILLISECONDS)
}
@ -268,7 +266,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}.toIndexedSeq
scheduler.resourceOffers(workOffers)
}
if (!taskDescs.isEmpty) {
if (taskDescs.nonEmpty) {
launchTasks(taskDescs)
}
}
@ -296,7 +294,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
Seq.empty
}
}
if (!taskDescs.isEmpty) {
if (taskDescs.nonEmpty) {
launchTasks(taskDescs)
}
}
@ -669,7 +667,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
val killExecutors: Boolean => Future[Boolean] =
if (!executorsToKill.isEmpty) {
if (executorsToKill.nonEmpty) {
_ => doKillExecutors(executorsToKill)
} else {
_ => Future.successful(false)

View file

@ -19,7 +19,6 @@ package org.apache.spark.status
import java.util.Date
import java.util.concurrent.ConcurrentHashMap
import java.util.function.Function
import scala.collection.JavaConverters._
import scala.collection.mutable.HashMap
@ -840,11 +839,11 @@ private[spark] class AppStatusListener(
// check if there is a new peak value for any of the executor level memory metrics,
// while reading from the log. SparkListenerStageExecutorMetrics are only processed
// when reading logs.
liveExecutors.get(executorMetrics.execId)
.orElse(deadExecutors.get(executorMetrics.execId)).map { exec =>
if (exec.peakExecutorMetrics.compareAndUpdatePeakValues(executorMetrics.executorMetrics)) {
update(exec, now)
}
liveExecutors.get(executorMetrics.execId).orElse(
deadExecutors.get(executorMetrics.execId)).foreach { exec =>
if (exec.peakExecutorMetrics.compareAndUpdatePeakValues(executorMetrics.executorMetrics)) {
update(exec, now)
}
}
}
@ -1048,9 +1047,7 @@ private[spark] class AppStatusListener(
private def getOrCreateStage(info: StageInfo): LiveStage = {
val stage = liveStages.computeIfAbsent((info.stageId, info.attemptNumber),
new Function[(Int, Int), LiveStage]() {
override def apply(key: (Int, Int)): LiveStage = new LiveStage()
})
(_: (Int, Int)) => new LiveStage())
stage.info = info
stage
}

View file

@ -143,12 +143,10 @@ private[spark] class ExternalSorter[K, V, C](
// user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some
// non-equal keys also have this, so we need to do a later pass to find truly equal keys).
// Note that we ignore this if no aggregator and no ordering are given.
private val keyComparator: Comparator[K] = ordering.getOrElse(new Comparator[K] {
override def compare(a: K, b: K): Int = {
val h1 = if (a == null) 0 else a.hashCode()
val h2 = if (b == null) 0 else b.hashCode()
if (h1 < h2) -1 else if (h1 == h2) 0 else 1
}
private val keyComparator: Comparator[K] = ordering.getOrElse((a: K, b: K) => {
val h1 = if (a == null) 0 else a.hashCode()
val h2 = if (b == null) 0 else b.hashCode()
if (h1 < h2) -1 else if (h1 == h2) 0 else 1
})
private def comparator: Option[Comparator[K]] = {
@ -363,17 +361,15 @@ private[spark] class ExternalSorter[K, V, C](
* Merge-sort a sequence of (K, C) iterators using a given a comparator for the keys.
*/
private def mergeSort(iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K])
: Iterator[Product2[K, C]] =
{
: Iterator[Product2[K, C]] = {
val bufferedIters = iterators.filter(_.hasNext).map(_.buffered)
type Iter = BufferedIterator[Product2[K, C]]
val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
// Use the reverse order because PriorityQueue dequeues the max
override def compare(x: Iter, y: Iter): Int = comparator.compare(y.head._1, x.head._1)
})
// Use the reverse order (compare(y,x)) because PriorityQueue dequeues the max
val heap = new mutable.PriorityQueue[Iter]()(
(x: Iter, y: Iter) => comparator.compare(y.head._1, x.head._1))
heap.enqueue(bufferedIters: _*) // Will contain only the iterators with hasNext = true
new Iterator[Product2[K, C]] {
override def hasNext: Boolean = !heap.isEmpty
override def hasNext: Boolean = heap.nonEmpty
override def next(): Product2[K, C] = {
if (!hasNext) {
@ -400,13 +396,12 @@ private[spark] class ExternalSorter[K, V, C](
mergeCombiners: (C, C) => C,
comparator: Comparator[K],
totalOrder: Boolean)
: Iterator[Product2[K, C]] =
{
: Iterator[Product2[K, C]] = {
if (!totalOrder) {
// We only have a partial ordering, e.g. comparing the keys by hash code, which means that
// multiple distinct keys might be treated as equal by the ordering. To deal with this, we
// need to read all keys considered equal by the ordering at once and compare them.
new Iterator[Iterator[Product2[K, C]]] {
val it = new Iterator[Iterator[Product2[K, C]]] {
val sorted = mergeSort(iterators, comparator).buffered
// Buffers reused across elements to decrease memory allocation
@ -446,7 +441,8 @@ private[spark] class ExternalSorter[K, V, C](
// equal by the partial order; we flatten this below to get a flat iterator of (K, C).
keys.iterator.zip(combiners.iterator)
}
}.flatMap(i => i)
}
it.flatten
} else {
// We have a total ordering, so the objects with the same key are sequential.
new Iterator[Product2[K, C]] {
@ -650,7 +646,7 @@ private[spark] class ExternalSorter[K, V, C](
if (spills.isEmpty) {
// Special case: if we have only in-memory data, we don't need to merge streams, and perhaps
// we don't even need to sort by anything other than partition ID
if (!ordering.isDefined) {
if (ordering.isEmpty) {
// The user hasn't requested sorted keys, so only sort by partition ID, not key
groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None)))
} else {
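
The `PriorityQueue` change in `mergeSort` above works because `scala.math.Ordering` also has a single abstract method (`compare`). A small sketch of the same idea with a made-up min-heap of `Int`:

```scala
import scala.collection.mutable

object PriorityQueueSketch {
  def main(args: Array[String]): Unit = {
    // The lambda is SAM-converted to Ordering[Int]; reversing the comparison
    // turns the default max-heap into a min-heap.
    val minHeap = new mutable.PriorityQueue[Int]()((x: Int, y: Int) => y.compareTo(x))
    minHeap.enqueue(3, 1, 2)
    println(minHeap.dequeue()) // 1
  }
}
```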

View file

@ -68,27 +68,20 @@ private[spark] object WritablePartitionedPairCollection {
/**
* A comparator for (Int, K) pairs that orders them by only their partition ID.
*/
def partitionComparator[K]: Comparator[(Int, K)] = new Comparator[(Int, K)] {
override def compare(a: (Int, K), b: (Int, K)): Int = {
a._1 - b._1
}
}
def partitionComparator[K]: Comparator[(Int, K)] = (a: (Int, K), b: (Int, K)) => a._1 - b._1
/**
* A comparator for (Int, K) pairs that orders them both by their partition ID and a key ordering.
*/
def partitionKeyComparator[K](keyComparator: Comparator[K]): Comparator[(Int, K)] = {
new Comparator[(Int, K)] {
override def compare(a: (Int, K), b: (Int, K)): Int = {
val partitionDiff = a._1 - b._1
if (partitionDiff != 0) {
partitionDiff
} else {
keyComparator.compare(a._2, b._2)
}
def partitionKeyComparator[K](keyComparator: Comparator[K]): Comparator[(Int, K)] =
(a: (Int, K), b: (Int, K)) => {
val partitionDiff = a._1 - b._1
if (partitionDiff != 0) {
partitionDiff
} else {
keyComparator.compare(a._2, b._2)
}
}
}
}
/**

View file

@ -39,7 +39,7 @@ private[spark] class DriverLogger(conf: SparkConf) extends Logging {
private val DEFAULT_LAYOUT = "%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n"
private val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort)
private var localLogFile: String = FileUtils.getFile(
private val localLogFile: String = FileUtils.getFile(
Utils.getLocalDir(conf),
DriverLogger.DRIVER_LOG_DIR,
DriverLogger.DRIVER_LOG_FILE).getAbsolutePath()
@ -163,9 +163,7 @@ private[spark] class DriverLogger(conf: SparkConf) extends Logging {
def closeWriter(): Unit = {
try {
threadpool.execute(new Runnable() {
override def run(): Unit = DfsAsyncWriter.this.close()
})
threadpool.execute(() => DfsAsyncWriter.this.close())
threadpool.shutdown()
threadpool.awaitTermination(1, TimeUnit.MINUTES)
} catch {

View file

@ -164,10 +164,9 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst
test("Thread safeness - SPARK-5425") {
val executor = Executors.newSingleThreadScheduledExecutor()
val sf = executor.scheduleAtFixedRate(new Runnable {
override def run(): Unit =
System.setProperty("spark.5425." + Random.nextInt(), Random.nextInt().toString)
}, 0, 1, TimeUnit.MILLISECONDS)
executor.scheduleAtFixedRate(
() => System.setProperty("spark.5425." + Random.nextInt(), Random.nextInt().toString),
0, 1, TimeUnit.MILLISECONDS)
try {
val t0 = System.nanoTime()

View file

@ -27,7 +27,6 @@ import org.eclipse.jetty.servlet.ServletContextHandler
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.Matchers
import org.scalatest.mockito.MockitoSugar
@ -374,11 +373,9 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar
when(request.getRequestURI()).thenReturn("http://localhost:18080/history/local-123/jobs/job/")
when(request.getQueryString()).thenReturn("id=2")
val resp = mock[HttpServletResponse]
when(resp.encodeRedirectURL(any())).thenAnswer(new Answer[String]() {
override def answer(invocationOnMock: InvocationOnMock): String = {
invocationOnMock.getArguments()(0).asInstanceOf[String]
}
})
when(resp.encodeRedirectURL(any())).thenAnswer { (invocationOnMock: InvocationOnMock) =>
invocationOnMock.getArguments()(0).asInstanceOf[String]
}
filter.doFilter(request, resp, null)
verify(resp).sendRedirect("http://localhost:18080/history/local-123/jobs/job/?id=2")
}
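
The Mockito `Answer` conversions in the test suites follow the same rule. Assuming Mockito 2.x (as used above) is on the classpath, a standalone sketch with a made-up `Greeter` interface:

```scala
import org.mockito.Mockito.{mock, when}
import org.mockito.invocation.InvocationOnMock

object MockitoAnswerSketch {
  trait Greeter {
    def greet(name: String): String
  }

  def main(args: Array[String]): Unit = {
    val greeter = mock(classOf[Greeter])
    // The lambda is SAM-converted to org.mockito.stubbing.Answer, replacing
    // `new Answer[String] { override def answer(invocation: InvocationOnMock) = ... }`.
    when(greeter.greet("spark")).thenAnswer((invocation: InvocationOnMock) =>
      "hello " + invocation.getArguments()(0).asInstanceOf[String])
    println(greeter.greet("spark")) // hello spark
  }
}
```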

View file

@ -33,7 +33,6 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, FSDataInputStream, Path}
import org.apache.hadoop.hdfs.{DFSInputStream, DistributedFileSystem}
import org.apache.hadoop.security.AccessControlException
import org.json4s.jackson.JsonMethods._
import org.mockito.ArgumentMatcher
import org.mockito.ArgumentMatchers.{any, argThat}
import org.mockito.Mockito.{doThrow, mock, spy, verify, when}
import org.scalatest.BeforeAndAfter
@ -49,7 +48,7 @@ import org.apache.spark.io._
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.security.GroupMappingServiceProvider
import org.apache.spark.status.{AppStatusStore, ExecutorSummaryWrapper}
import org.apache.spark.status.AppStatusStore
import org.apache.spark.status.api.v1.{ApplicationAttemptInfo, ApplicationInfo}
import org.apache.spark.util.{Clock, JsonProtocol, ManualClock, Utils}
import org.apache.spark.util.logging.DriverLogger
@ -1122,11 +1121,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
SparkListenerApplicationEnd(5L))
val mockedFs = spy(provider.fs)
doThrow(new AccessControlException("Cannot read accessDenied file")).when(mockedFs).open(
argThat(new ArgumentMatcher[Path]() {
override def matches(path: Path): Boolean = {
path.asInstanceOf[Path].getName.toLowerCase(Locale.ROOT) == "accessdenied"
}
}))
argThat((path: Path) => path.getName.toLowerCase(Locale.ROOT) == "accessdenied"))
val mockedProvider = spy(provider)
when(mockedProvider.fs).thenReturn(mockedFs)
updateAndCheck(mockedProvider) { list =>

View file

@ -24,7 +24,6 @@ import scala.concurrent.duration._
import org.mockito.ArgumentMatchers.{any, anyInt}
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.concurrent.Eventually.{eventually, interval, timeout}
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
@ -57,11 +56,9 @@ class DriverRunnerTest extends SparkFunSuite {
superviseRetry: Boolean) = {
val runner = createDriverRunner()
runner.setSleeper(mock(classOf[Sleeper]))
doAnswer(new Answer[Int] {
def answer(invocation: InvocationOnMock): Int = {
runner.runCommandWithRetry(processBuilder, p => (), supervise = superviseRetry)
}
}).when(runner).prepareAndRunDriver()
doAnswer { (_: InvocationOnMock) =>
runner.runCommandWithRetry(processBuilder, p => (), supervise = superviseRetry)
}.when(runner).prepareAndRunDriver()
runner
}
@ -120,11 +117,9 @@ class DriverRunnerTest extends SparkFunSuite {
runner.setSleeper(sleeper)
val (processBuilder, process) = createProcessBuilderAndProcess()
when(process.waitFor()).thenAnswer(new Answer[Int] {
def answer(invocation: InvocationOnMock): Int = {
runner.kill()
-1
}
when(process.waitFor()).thenAnswer((_: InvocationOnMock) => {
runner.kill()
-1
})
runner.runCommandWithRetry(processBuilder, p => (), supervise = true)
@ -169,11 +164,9 @@ class DriverRunnerTest extends SparkFunSuite {
val (processBuilder, process) = createProcessBuilderAndProcess()
val runner = createTestableDriverRunner(processBuilder, superviseRetry = true)
when(process.waitFor()).thenAnswer(new Answer[Int] {
def answer(invocation: InvocationOnMock): Int = {
runner.kill()
-1
}
when(process.waitFor()).thenAnswer((_: InvocationOnMock) => {
runner.kill()
-1
})
runner.start()

View file

@ -28,7 +28,6 @@ import org.mockito.Answers.RETURNS_SMART_NULLS
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.concurrent.Eventually.{eventually, interval, timeout}
@ -233,11 +232,8 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter {
val conf = new SparkConf().set(config.STORAGE_CLEANUP_FILES_AFTER_EXECUTOR_EXIT, value)
val cleanupCalled = new AtomicBoolean(false)
when(shuffleService.executorRemoved(any[String], any[String])).thenAnswer(new Answer[Unit] {
override def answer(invocations: InvocationOnMock): Unit = {
cleanupCalled.set(true)
}
})
when(shuffleService.executorRemoved(any[String], any[String])).thenAnswer(
(_: InvocationOnMock) => cleanupCalled.set(true))
val externalShuffleServiceSupplier = new Supplier[ExternalShuffleService] {
override def get: ExternalShuffleService = shuffleService
}
@ -269,11 +265,8 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter {
val appId = "app1"
val execId = "exec1"
val cleanupCalled = new AtomicBoolean(false)
when(shuffleService.applicationRemoved(any[String])).thenAnswer(new Answer[Unit] {
override def answer(invocations: InvocationOnMock): Unit = {
cleanupCalled.set(true)
}
})
when(shuffleService.applicationRemoved(any[String])).thenAnswer(
(_: InvocationOnMock) => cleanupCalled.set(true))
val externalShuffleServiceSupplier = new Supplier[ExternalShuffleService] {
override def get: ExternalShuffleService = shuffleService
}
@ -289,8 +282,8 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter {
}
executorDir.setLastModified(System.currentTimeMillis - (1000 * 120))
worker.receive(WorkDirCleanup)
eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
assert(!executorDir.exists() == true)
eventually(timeout(1000.milliseconds), interval(10.milliseconds)) {
assert(!executorDir.exists())
assert(cleanupCalled.get() == dbCleanupEnabled)
}
}

View file

@ -279,13 +279,10 @@ class ExecutorSuite extends SparkFunSuite
val heartbeats = ArrayBuffer[Heartbeat]()
val mockReceiver = mock[RpcEndpointRef]
when(mockReceiver.askSync(any[Heartbeat], any[RpcTimeout])(any))
.thenAnswer(new Answer[HeartbeatResponse] {
override def answer(invocation: InvocationOnMock): HeartbeatResponse = {
val args = invocation.getArguments()
val mock = invocation.getMock
heartbeats += args(0).asInstanceOf[Heartbeat]
HeartbeatResponse(false)
}
.thenAnswer((invocation: InvocationOnMock) => {
val args = invocation.getArguments()
heartbeats += args(0).asInstanceOf[Heartbeat]
HeartbeatResponse(false)
})
val receiverRef = executorClass.getDeclaredField("heartbeatReceiverRef")
receiverRef.setAccessible(true)

View file

@ -84,11 +84,8 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft
*/
protected def makeBadMemoryStore(mm: MemoryManager): MemoryStore = {
val ms = mock(classOf[MemoryStore], RETURNS_SMART_NULLS)
when(ms.evictBlocksToFreeSpace(any(), anyLong(), any())).thenAnswer(new Answer[Long] {
override def answer(invocation: InvocationOnMock): Long = {
throw new RuntimeException("bad memory store!")
}
})
when(ms.evictBlocksToFreeSpace(any(), anyLong(), any())).thenAnswer(
(_: InvocationOnMock) => throw new RuntimeException("bad memory store!"))
mm.setMemoryStore(ms)
ms
}
@ -106,27 +103,24 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft
* records the number of bytes this is called with. This variable is expected to be cleared
* by the test code later through [[assertEvictBlocksToFreeSpaceCalled]].
*/
private def evictBlocksToFreeSpaceAnswer(mm: MemoryManager): Answer[Long] = {
new Answer[Long] {
override def answer(invocation: InvocationOnMock): Long = {
val args = invocation.getArguments
val numBytesToFree = args(1).asInstanceOf[Long]
assert(numBytesToFree > 0)
require(evictBlocksToFreeSpaceCalled.get() === DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED,
"bad test: evictBlocksToFreeSpace() variable was not reset")
evictBlocksToFreeSpaceCalled.set(numBytesToFree)
if (numBytesToFree <= mm.storageMemoryUsed) {
// We can evict enough blocks to fulfill the request for space
mm.releaseStorageMemory(numBytesToFree, mm.tungstenMemoryMode)
evictedBlocks += Tuple2(null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytesToFree, 0L))
numBytesToFree
} else {
// No blocks were evicted because eviction would not free enough space.
0L
}
private def evictBlocksToFreeSpaceAnswer(mm: MemoryManager): Answer[Long] =
(invocation: InvocationOnMock) => {
val args = invocation.getArguments
val numBytesToFree = args(1).asInstanceOf[Long]
assert(numBytesToFree > 0)
require(evictBlocksToFreeSpaceCalled.get() === DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED,
"bad test: evictBlocksToFreeSpace() variable was not reset")
evictBlocksToFreeSpaceCalled.set(numBytesToFree)
if (numBytesToFree <= mm.storageMemoryUsed) {
// We can evict enough blocks to fulfill the request for space
mm.releaseStorageMemory(numBytesToFree, mm.tungstenMemoryMode)
evictedBlocks += Tuple2(null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytesToFree, 0L))
numBytesToFree
} else {
// No blocks were evicted because eviction would not free enough space.
0L
}
}
}
/**
* Assert that [[MemoryStore.evictBlocksToFreeSpace]] is called with the given parameters.

View file

@ -20,7 +20,6 @@ package org.apache.spark.scheduler
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.{never, verify, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.BeforeAndAfterEach
import org.scalatest.mockito.MockitoSugar
@ -480,17 +479,16 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
test("blacklisting kills executors, configured by BLACKLIST_KILL_ENABLED") {
val allocationClientMock = mock[ExecutorAllocationClient]
when(allocationClientMock.killExecutors(any(), any(), any(), any())).thenReturn(Seq("called"))
when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer(new Answer[Boolean] {
when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer { (_: InvocationOnMock) =>
// To avoid a race between blacklisting and killing, it is important that the nodeBlacklist
// is updated before we ask the executor allocation client to kill all the executors
// on a particular host.
override def answer(invocation: InvocationOnMock): Boolean = {
if (blacklist.nodeBlacklist.contains("hostA") == false) {
throw new IllegalStateException("hostA should be on the blacklist")
}
if (blacklist.nodeBlacklist.contains("hostA")) {
true
} else {
throw new IllegalStateException("hostA should be on the blacklist")
}
})
}
blacklist = new BlacklistTracker(listenerBusMock, conf, Some(allocationClientMock), clock)
// Disable auto-kill. Blacklist an executor and make sure killExecutors is not called.
@ -552,17 +550,16 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
test("fetch failure blacklisting kills executors, configured by BLACKLIST_KILL_ENABLED") {
val allocationClientMock = mock[ExecutorAllocationClient]
when(allocationClientMock.killExecutors(any(), any(), any(), any())).thenReturn(Seq("called"))
when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer(new Answer[Boolean] {
when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer { (_: InvocationOnMock) =>
// To avoid a race between blacklisting and killing, it is important that the nodeBlacklist
// is updated before we ask the executor allocation client to kill all the executors
// on a particular host.
override def answer(invocation: InvocationOnMock): Boolean = {
if (blacklist.nodeBlacklist.contains("hostA") == false) {
throw new IllegalStateException("hostA should be on the blacklist")
}
if (blacklist.nodeBlacklist.contains("hostA")) {
true
} else {
throw new IllegalStateException("hostA should be on the blacklist")
}
})
}
conf.set(config.BLACKLIST_FETCH_FAILURE_ENABLED, true)
blacklist = new BlacklistTracker(listenerBusMock, conf, Some(allocationClientMock), clock)

View file

@ -29,7 +29,6 @@ import org.apache.hadoop.mapreduce.TaskType
import org.mockito.ArgumentMatchers.{any, eq => meq}
import org.mockito.Mockito.{doAnswer, spy, times, verify}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.BeforeAndAfter
import org.apache.spark._
@ -98,34 +97,29 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter {
// Use Mockito.spy() to maintain the default infrastructure everywhere else
val mockTaskScheduler = spy(sc.taskScheduler.asInstanceOf[TaskSchedulerImpl])
doAnswer(new Answer[Unit]() {
override def answer(invoke: InvocationOnMock): Unit = {
// Submit the tasks, then force the task scheduler to dequeue the
// speculated task
invoke.callRealMethod()
mockTaskScheduler.backend.reviveOffers()
}
}).when(mockTaskScheduler).submitTasks(any())
doAnswer { (invoke: InvocationOnMock) =>
// Submit the tasks, then force the task scheduler to dequeue the
// speculated task
invoke.callRealMethod()
mockTaskScheduler.backend.reviveOffers()
}.when(mockTaskScheduler).submitTasks(any())
doAnswer(new Answer[TaskSetManager]() {
override def answer(invoke: InvocationOnMock): TaskSetManager = {
val taskSet = invoke.getArguments()(0).asInstanceOf[TaskSet]
new TaskSetManager(mockTaskScheduler, taskSet, 4) {
var hasDequeuedSpeculatedTask = false
override def dequeueSpeculativeTask(
execId: String,
host: String,
locality: TaskLocality.Value): Option[(Int, TaskLocality.Value)] = {
if (!hasDequeuedSpeculatedTask) {
hasDequeuedSpeculatedTask = true
Some((0, TaskLocality.PROCESS_LOCAL))
} else {
None
}
doAnswer { (invoke: InvocationOnMock) =>
val taskSet = invoke.getArguments()(0).asInstanceOf[TaskSet]
new TaskSetManager(mockTaskScheduler, taskSet, 4) {
private var hasDequeuedSpeculatedTask = false
override def dequeueSpeculativeTask(execId: String,
host: String,
locality: TaskLocality.Value): Option[(Int, TaskLocality.Value)] = {
if (hasDequeuedSpeculatedTask) {
None
} else {
hasDequeuedSpeculatedTask = true
Some((0, TaskLocality.PROCESS_LOCAL))
}
}
}
}).when(mockTaskScheduler).createTaskSetManager(any(), any())
}.when(mockTaskScheduler).createTaskSetManager(any(), any())
sc.taskScheduler = mockTaskScheduler
val dagSchedulerWithMockTaskScheduler = new DAGScheduler(sc, mockTaskScheduler)

View file

@ -356,13 +356,9 @@ private[spark] abstract class MockBackend(
assignedTasksWaitingToRun.nonEmpty
}
override def start(): Unit = {
reviveThread.scheduleAtFixedRate(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
reviveOffers()
}
}, 0, reviveIntervalMs, TimeUnit.MILLISECONDS)
}
override def start(): Unit =
reviveThread.scheduleAtFixedRate(() => Utils.tryLogNonFatalError { reviveOffers() },
0, reviveIntervalMs, TimeUnit.MILLISECONDS)
override def stop(): Unit = {
reviveThread.shutdown()

View file

@ -25,14 +25,13 @@ import scala.collection.mutable.ArrayBuffer
import org.mockito.ArgumentMatchers.{any, anyInt, anyString}
import org.mockito.Mockito.{mock, never, spy, times, verify, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.{AccumulatorV2, ManualClock, Utils}
import org.apache.spark.util.{AccumulatorV2, ManualClock}
class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
extends DAGScheduler(sc) {
@ -1190,11 +1189,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
val taskSet = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0)
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = new ManualClock(1))
when(mockDAGScheduler.taskEnded(any(), any(), any(), any(), any())).thenAnswer(
new Answer[Unit] {
override def answer(invocationOnMock: InvocationOnMock): Unit = {
assert(manager.isZombie)
}
})
(invocationOnMock: InvocationOnMock) => assert(manager.isZombie))
val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF)
assert(taskOption.isDefined)
// this would fail, inside our mock dag scheduler, if it calls dagScheduler.taskEnded() too soon
@ -1317,12 +1312,10 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
// Assert the task has been black listed on the executor it was last executed on.
when(taskSetManagerSpy.addPendingTask(anyInt())).thenAnswer(
new Answer[Unit] {
override def answer(invocationOnMock: InvocationOnMock): Unit = {
val task: Int = invocationOnMock.getArgument(0)
assert(taskSetManager.taskSetBlacklistHelperOpt.get.
isExecutorBlacklistedForTask(exec, task))
}
(invocationOnMock: InvocationOnMock) => {
val task: Int = invocationOnMock.getArgument(0)
assert(taskSetManager.taskSetBlacklistHelperOpt.get.
isExecutorBlacklistedForTask(exec, task))
}
)

View file

@ -28,7 +28,6 @@ import org.mockito.Answers.RETURNS_SMART_NULLS
import org.mockito.ArgumentMatchers.{any, anyInt}
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.BeforeAndAfterEach
import org.apache.spark._
@ -69,16 +68,14 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
when(dependency.serializer).thenReturn(new JavaSerializer(conf))
when(taskContext.taskMetrics()).thenReturn(taskMetrics)
when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
doAnswer(new Answer[Void] {
def answer(invocationOnMock: InvocationOnMock): Void = {
val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File]
if (tmp != null) {
outputFile.delete
tmp.renameTo(outputFile)
}
null
doAnswer { (invocationOnMock: InvocationOnMock) =>
val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File]
if (tmp != null) {
outputFile.delete
tmp.renameTo(outputFile)
}
}).when(blockResolver)
null
}.when(blockResolver)
.writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))
when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
when(blockManager.getDiskWriter(
@ -87,37 +84,29 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
any[SerializerInstance],
anyInt(),
any[ShuffleWriteMetrics]
)).thenAnswer(new Answer[DiskBlockObjectWriter] {
override def answer(invocation: InvocationOnMock): DiskBlockObjectWriter = {
val args = invocation.getArguments
val manager = new SerializerManager(new JavaSerializer(conf), conf)
new DiskBlockObjectWriter(
args(1).asInstanceOf[File],
manager,
args(2).asInstanceOf[SerializerInstance],
args(3).asInstanceOf[Int],
syncWrites = false,
args(4).asInstanceOf[ShuffleWriteMetrics],
blockId = args(0).asInstanceOf[BlockId]
)
}
)).thenAnswer((invocation: InvocationOnMock) => {
val args = invocation.getArguments
val manager = new SerializerManager(new JavaSerializer(conf), conf)
new DiskBlockObjectWriter(
args(1).asInstanceOf[File],
manager,
args(2).asInstanceOf[SerializerInstance],
args(3).asInstanceOf[Int],
syncWrites = false,
args(4).asInstanceOf[ShuffleWriteMetrics],
blockId = args(0).asInstanceOf[BlockId]
)
})
when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
new Answer[(TempShuffleBlockId, File)] {
override def answer(invocation: InvocationOnMock): (TempShuffleBlockId, File) = {
val blockId = new TempShuffleBlockId(UUID.randomUUID)
val file = new File(tempDir, blockId.name)
blockIdToFileMap.put(blockId, file)
temporaryFilesCreated += file
(blockId, file)
}
})
when(diskBlockManager.getFile(any[BlockId])).thenAnswer(
new Answer[File] {
override def answer(invocation: InvocationOnMock): File = {
blockIdToFileMap(invocation.getArguments.head.asInstanceOf[BlockId])
}
when(diskBlockManager.createTempShuffleBlock()).thenAnswer((_: InvocationOnMock) => {
val blockId = new TempShuffleBlockId(UUID.randomUUID)
val file = new File(tempDir, blockId.name)
blockIdToFileMap.put(blockId, file)
temporaryFilesCreated += file
(blockId, file)
})
when(diskBlockManager.getFile(any[BlockId])).thenAnswer { (invocation: InvocationOnMock) =>
blockIdToFileMap(invocation.getArguments.head.asInstanceOf[BlockId])
}
}
override def afterEach(): Unit = {

View file

@ -24,7 +24,6 @@ import org.mockito.Answers.RETURNS_SMART_NULLS
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.BeforeAndAfterEach
import org.apache.spark.{SparkConf, SparkFunSuite}
@ -48,11 +47,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
when(diskBlockManager.getFile(any[BlockId])).thenAnswer(
new Answer[File] {
override def answer(invocation: InvocationOnMock): File = {
new File(tempDir, invocation.getArguments.head.toString)
}
})
(invocation: InvocationOnMock) => new File(tempDir, invocation.getArguments.head.toString))
}
override def afterEach(): Unit = {

View file

@ -24,7 +24,6 @@ import scala.reflect.ClassTag
import org.mockito.Mockito
import org.mockito.Mockito.atLeastOnce
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester}
import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl}
@ -59,11 +58,9 @@ class PartiallySerializedBlockSuite
val bbos: ChunkedByteBufferOutputStream = {
val spy = Mockito.spy(new ChunkedByteBufferOutputStream(128, ByteBuffer.allocate))
Mockito.doAnswer(new Answer[ChunkedByteBuffer] {
override def answer(invocationOnMock: InvocationOnMock): ChunkedByteBuffer = {
Mockito.spy(invocationOnMock.callRealMethod().asInstanceOf[ChunkedByteBuffer])
}
}).when(spy).toChunkedByteBuffer
Mockito.doAnswer { (invocationOnMock: InvocationOnMock) =>
Mockito.spy(invocationOnMock.callRealMethod().asInstanceOf[ChunkedByteBuffer])
}.when(spy).toChunkedByteBuffer
spy
}

View file

@ -28,7 +28,6 @@ import scala.concurrent.Future
import org.mockito.ArgumentMatchers.{any, eq => meq}
import org.mockito.Mockito.{mock, times, verify, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.PrivateMethodTester
import org.apache.spark.{SparkFunSuite, TaskContext}
@ -50,9 +49,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
/** Creates a mock [[BlockTransferService]] that returns data from the given map. */
private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = {
val transfer = mock(classOf[BlockTransferService])
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
.thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())).thenAnswer(
(invocation: InvocationOnMock) => {
val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]]
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
@ -63,8 +61,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
listener.onBlockFetchFailure(blockId, new BlockNotFoundException(blockId))
}
}
}
})
})
transfer
}
@ -168,8 +165,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val transfer = mock(classOf[BlockTransferService])
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
.thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
.thenAnswer((invocation: InvocationOnMock) => {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
Future {
// Return the first two blocks, and wait till task completion before returning the 3rd one
@ -181,8 +177,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0)))
}
}
})
})
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator
@ -237,20 +232,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val transfer = mock(classOf[BlockTransferService])
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
.thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
Future {
// Return the first two blocks, and wait till task completion before returning the last
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0)))
sem.acquire()
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0)))
}
}
.thenAnswer((invocation: InvocationOnMock) => {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
Future {
// Return the first two blocks, and wait till task completion before returning the last
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0)))
sem.acquire()
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0)))
}
})
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
@ -298,8 +291,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val transfer = mock(classOf[BlockTransferService])
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
.thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
.thenAnswer((invocation: InvocationOnMock) => {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
Future {
// Return the first block, and then fail.
@ -311,8 +303,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
ShuffleBlockId(0, 2, 0).toString, new BlockNotFoundException("blah"))
sem.release()
}
}
})
})
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator
@ -389,8 +380,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val transfer = mock(classOf[BlockTransferService])
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
.thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
.thenAnswer((invocation: InvocationOnMock) => {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
Future {
// Return the first block, and then fail.
@ -402,8 +392,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer)
sem.release()
}
}
})
})
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator
@ -431,8 +420,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
assert(id1 === ShuffleBlockId(0, 0, 0))
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
.thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
.thenAnswer((invocation: InvocationOnMock) => {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
Future {
// Return the first block, and then fail.
@ -440,8 +428,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
sem.release()
}
}
})
})
// The next block is corrupt local block (the second one is corrupt and retried)
intercept[FetchFailedException] { iterator.next() }
@ -588,8 +575,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val transfer = mock(classOf[BlockTransferService])
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
.thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
.thenAnswer((invocation: InvocationOnMock) => {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
Future {
// Return the first block, and then fail.
@ -601,8 +587,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
ShuffleBlockId(0, 2, 0).toString, mockCorruptBuffer())
sem.release()
}
}
})
})
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator
@ -654,14 +639,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val transfer = mock(classOf[BlockTransferService])
var tempFileManager: DownloadFileManager = null
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
.thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
tempFileManager = invocation.getArguments()(5).asInstanceOf[DownloadFileManager]
Future {
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0)))
}
.thenAnswer((invocation: InvocationOnMock) => {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
tempFileManager = invocation.getArguments()(5).asInstanceOf[DownloadFileManager]
Future {
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0)))
}
})

View file

@ -71,11 +71,9 @@ class ThreadUtilsSuite extends SparkFunSuite {
keepAliveSeconds = 2)
try {
for (_ <- 1 to maxThreadNumber) {
cachedThreadPool.execute(new Runnable {
override def run(): Unit = {
startThreadsLatch.countDown()
latch.await(10, TimeUnit.SECONDS)
}
cachedThreadPool.execute(() => {
startThreadsLatch.countDown()
latch.await(10, TimeUnit.SECONDS)
})
}
startThreadsLatch.await(10, TimeUnit.SECONDS)
@ -84,11 +82,7 @@ class ThreadUtilsSuite extends SparkFunSuite {
// Submit a new task and it should be put into the queue since the thread number reaches the
// limitation
cachedThreadPool.execute(new Runnable {
override def run(): Unit = {
latch.await(10, TimeUnit.SECONDS)
}
})
cachedThreadPool.execute(() => latch.await(10, TimeUnit.SECONDS))
assert(cachedThreadPool.getActiveCount === maxThreadNumber)
assert(cachedThreadPool.getQueue.size === 1)
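
`java.lang.Runnable` is likewise a single-abstract-method interface, which is what lets `execute(new Runnable { ... })` shrink to `execute(() => ...)`. A self-contained sketch of the same idea:

```scala
import java.util.concurrent.{Executors, TimeUnit}

object RunnableAsLambda {
  def main(args: Array[String]): Unit = {
    val pool = Executors.newFixedThreadPool(2)
    // Each () => Unit literal is adapted to java.lang.Runnable by execute().
    (1 to 4).foreach { i =>
      pool.execute(() => println(s"task $i on ${Thread.currentThread().getName}"))
    }
    pool.shutdown()
    pool.awaitTermination(10, TimeUnit.SECONDS)
  }
}
```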

View file

@ -17,8 +17,6 @@
package org.apache.spark.util.collection
import java.util.Comparator
import scala.collection.mutable.HashSet
import org.apache.spark.SparkFunSuite
@ -170,12 +168,10 @@ class AppendOnlyMapSuite extends SparkFunSuite {
case e: IllegalStateException => fail()
}
val it = map.destructiveSortedIterator(new Comparator[String] {
def compare(key1: String, key2: String): Int = {
val x = if (key1 != null) key1.toInt else Int.MinValue
val y = if (key2 != null) key2.toInt else Int.MinValue
x.compareTo(y)
}
val it = map.destructiveSortedIterator((key1: String, key2: String) => {
val x = if (key1 != null) key1.toInt else Int.MinValue
val y = if (key2 != null) key2.toInt else Int.MinValue
x.compareTo(y)
})
// Should be sorted by key
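
`java.util.Comparator` gets the same treatment: wherever a comparator is expected, a two-argument function literal is adapted to it. A small sketch using `Arrays.sort` (illustrative values only):

```scala
import java.util.Arrays

object ComparatorAsLambda {
  def main(args: Array[String]): Unit = {
    val data: Array[Integer] = Array(3, 1, 2)
    // The function literal is adapted to java.util.Comparator[Integer]; this sorts descending.
    Arrays.sort(data, (x: Integer, y: Integer) => y.compareTo(x))
    println(data.mkString(", ")) // 3, 2, 1
  }
}
```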

View file

@ -17,8 +17,6 @@
package org.apache.spark.util.collection
import java.util.Comparator
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
@ -111,14 +109,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
val tmp = new Array[Long](size/2)
val tmpBuf = new LongArray(MemoryBlock.fromLongArray(tmp))
new Sorter(new UnsafeSortDataFormat(tmpBuf)).sort(
buf, 0, size, new Comparator[RecordPointerAndKeyPrefix] {
override def compare(
r1: RecordPointerAndKeyPrefix,
r2: RecordPointerAndKeyPrefix): Int = {
PrefixComparators.LONG.compare(r1.keyPrefix, r2.keyPrefix)
}
})
new Sorter(new UnsafeSortDataFormat(tmpBuf)).sort(buf, 0, size,
(r1: RecordPointerAndKeyPrefix, r2: RecordPointerAndKeyPrefix) =>
PrefixComparators.LONG.compare(r1.keyPrefix, r2.keyPrefix))
}
test("spilling with hash collisions") {
@ -135,7 +128,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
buffer2: ArrayBuffer[String]): ArrayBuffer[String] = buffer1 ++= buffer2
val agg = new Aggregator[String, String, ArrayBuffer[String]](
createCombiner _, mergeValue _, mergeCombiners _)
createCombiner, mergeValue, mergeCombiners)
val sorter = new ExternalSorter[String, String, ArrayBuffer[String]](
context, Some(agg), None, None)

View file

@ -18,7 +18,7 @@
package org.apache.spark.util.collection
import java.lang.{Float => JFloat}
import java.util.{Arrays, Comparator}
import java.util.Arrays
import java.util.concurrent.TimeUnit
import org.apache.spark.SparkFunSuite
@ -219,10 +219,8 @@ class SorterSuite extends SparkFunSuite with Logging {
System.arraycopy(kvTuples, 0, kvTupleArray, 0, numElements)
}
runExperiment("Tuple-sort using Arrays.sort()")({
Arrays.sort(kvTupleArray, new Comparator[AnyRef] {
override def compare(x: AnyRef, y: AnyRef): Int =
x.asInstanceOf[(JFloat, _)]._1.compareTo(y.asInstanceOf[(JFloat, _)]._1)
})
Arrays.sort(kvTupleArray, (x: AnyRef, y: AnyRef) =>
x.asInstanceOf[(JFloat, _)]._1.compareTo(y.asInstanceOf[(JFloat, _)]._1))
}, prepareKvTupleArray)
// Test our Sorter where each element alternates between Float and Integer, non-primitive
@ -245,9 +243,7 @@ class SorterSuite extends SparkFunSuite with Logging {
val sorter = new Sorter(new KVArraySortDataFormat[JFloat, AnyRef])
runExperiment("KV-sort using Sorter")({
sorter.sort(keyValueArray, 0, numElements, new Comparator[JFloat] {
override def compare(x: JFloat, y: JFloat): Int = x.compareTo(y)
})
sorter.sort(keyValueArray, 0, numElements, (x: JFloat, y: JFloat) => x.compareTo(y))
}, prepareKeyValueArray)
}
@ -280,11 +276,9 @@ class SorterSuite extends SparkFunSuite with Logging {
System.arraycopy(intObjects, 0, intObjectArray, 0, numElements)
}
runExperiment("Java Arrays.sort() on non-primitive int array")({
Arrays.sort(intObjectArray, new Comparator[Integer] {
override def compare(x: Integer, y: Integer): Int = x.compareTo(y)
})
}, prepareIntObjectArray)
runExperiment("Java Arrays.sort() on non-primitive int array")(
Arrays.sort(intObjectArray, (x: Integer, y: Integer) => x.compareTo(y)),
prepareIntObjectArray)
val intPrimitiveArray = new Array[Int](numElements)
val prepareIntPrimitiveArray = () => {

View file

@ -69,9 +69,8 @@ class RadixSortSuite extends SparkFunSuite with Logging {
override def sortDescending = false
override def sortSigned = false
override def nullsFirst = true
override def compare(a: Long, b: Long): Int = {
return PrefixComparators.BINARY.compare(a & 0xffffff0000L, b & 0xffffff0000L)
}
override def compare(a: Long, b: Long): Int =
PrefixComparators.BINARY.compare(a & 0xffffff0000L, b & 0xffffff0000L)
},
2, 4, false, false, true))
@ -112,11 +111,9 @@ class RadixSortSuite extends SparkFunSuite with Logging {
private def referenceKeyPrefixSort(buf: LongArray, lo: Long, hi: Long, refCmp: PrefixComparator) {
val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt)))
new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort(
buf, Ints.checkedCast(lo), Ints.checkedCast(hi), new Comparator[RecordPointerAndKeyPrefix] {
override def compare(
r1: RecordPointerAndKeyPrefix,
r2: RecordPointerAndKeyPrefix): Int = refCmp.compare(r1.keyPrefix, r2.keyPrefix)
})
buf, Ints.checkedCast(lo), Ints.checkedCast(hi),
(r1: RecordPointerAndKeyPrefix, r2: RecordPointerAndKeyPrefix) =>
refCmp.compare(r1.keyPrefix, r2.keyPrefix))
}
private def fuzzTest(name: String)(testFn: Long => Unit): Unit = {

View file

@ -18,7 +18,7 @@
package org.apache.spark.sql.kafka010
import java.{util => ju}
import java.util.concurrent.{Executors, ThreadFactory}
import java.util.concurrent.Executors
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
@ -52,16 +52,14 @@ private[kafka010] class KafkaOffsetReader(
/**
* Used to ensure execute fetch operations execute in an UninterruptibleThread
*/
val kafkaReaderThread = Executors.newSingleThreadExecutor(new ThreadFactory {
override def newThread(r: Runnable): Thread = {
val t = new UninterruptibleThread("Kafka Offset Reader") {
override def run(): Unit = {
r.run()
}
val kafkaReaderThread = Executors.newSingleThreadExecutor((r: Runnable) => {
val t = new UninterruptibleThread("Kafka Offset Reader") {
override def run(): Unit = {
r.run()
}
t.setDaemon(true)
t
}
t.setDaemon(true)
t
})
val execContext = ExecutionContext.fromExecutorService(kafkaReaderThread)
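
`ThreadFactory` declares only `newThread(Runnable)`, so the factory above can be written as a `(r: Runnable) => ...` block. A standalone sketch of the same construction:

```scala
import java.util.concurrent.{Executors, TimeUnit}

object ThreadFactoryAsLambda {
  def main(args: Array[String]): Unit = {
    // The block literal is adapted to java.util.concurrent.ThreadFactory.
    val pool = Executors.newSingleThreadExecutor((r: Runnable) => {
      val t = new Thread(r, "offset-reader")
      t.setDaemon(true)
      t
    })
    pool.execute(() => println(s"running on ${Thread.currentThread().getName}"))
    pool.shutdown()
    pool.awaitTermination(10, TimeUnit.SECONDS)
  }
}
```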

View file

@ -361,9 +361,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
override def capabilities(): ju.Set[TableCapability] = Collections.emptySet()
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder {
override def build(): Scan = new KafkaScan(options)
}
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
() => new KafkaScan(options)
override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = {
new WriteBuilder {

View file

@ -53,7 +53,7 @@ class DirectKafkaStreamSuite
.setMaster("local[4]")
.setAppName(this.getClass.getSimpleName)
// Set a timeout of 10 seconds that's going to be used to fetch topics/partitions from kafka.
// Othewise the poll timeout defaults to 2 minutes and causes test cases to run longer.
// Otherwise the poll timeout defaults to 2 minutes and causes test cases to run longer.
.set("spark.streaming.kafka.consumer.poll.ms", "10000")
private var ssc: StreamingContext = _
@ -61,13 +61,13 @@ class DirectKafkaStreamSuite
private var kafkaTestUtils: KafkaTestUtils = _
override def beforeAll {
override def beforeAll() {
super.beforeAll()
kafkaTestUtils = new KafkaTestUtils
kafkaTestUtils.setup()
}
override def afterAll {
override def afterAll() {
try {
if (kafkaTestUtils != null) {
kafkaTestUtils.teardown()
@ -454,13 +454,11 @@ class DirectKafkaStreamSuite
val data = rdd.map(_.value).collect()
collectedData.addAll(Arrays.asList(data: _*))
kafkaStream.asInstanceOf[CanCommitOffsets]
.commitAsync(offsets, new OffsetCommitCallback() {
def onComplete(m: JMap[TopicPartition, OffsetAndMetadata], e: Exception) {
if (null != e) {
logError("commit failed", e)
} else {
committed.putAll(m)
}
.commitAsync(offsets, (m: JMap[TopicPartition, OffsetAndMetadata], e: Exception) => {
if (null != e) {
logError("commit failed", e)
} else {
committed.putAll(m)
}
})
}

View file

@ -27,7 +27,6 @@ import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorC
import org.mockito.ArgumentMatchers._
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester}
import org.scalatest.concurrent.Eventually
import org.scalatest.mockito.MockitoSugar
@ -124,11 +123,9 @@ class KinesisCheckpointerSuite extends TestSuiteBase
test("if checkpointing is going on, wait until finished before removing and checkpointing") {
when(receiverMock.getLatestSeqNumToCheckpoint(shardId))
.thenReturn(someSeqNum).thenReturn(someOtherSeqNum)
when(checkpointerMock.checkpoint(anyString)).thenAnswer(new Answer[Unit] {
override def answer(invocations: InvocationOnMock): Unit = {
clock.waitTillTime(clock.getTimeMillis() + checkpointInterval.milliseconds / 2)
}
})
when(checkpointerMock.checkpoint(anyString)).thenAnswer { (_: InvocationOnMock) =>
clock.waitTillTime(clock.getTimeMillis() + checkpointInterval.milliseconds / 2)
}
kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock)
clock.advance(checkpointInterval.milliseconds)

View file

@ -130,16 +130,14 @@ private[impl] case class EdgeWithLocalIds[@specialized ED](
private[impl] object EdgeWithLocalIds {
implicit def lexicographicOrdering[ED]: Ordering[EdgeWithLocalIds[ED]] =
new Ordering[EdgeWithLocalIds[ED]] {
override def compare(a: EdgeWithLocalIds[ED], b: EdgeWithLocalIds[ED]): Int = {
if (a.srcId == b.srcId) {
if (a.dstId == b.dstId) 0
else if (a.dstId < b.dstId) -1
else 1
} else if (a.srcId < b.srcId) -1
(a: EdgeWithLocalIds[ED], b: EdgeWithLocalIds[ED]) =>
if (a.srcId == b.srcId) {
if (a.dstId == b.dstId) 0
else if (a.dstId < b.dstId) -1
else 1
}
}
else if (a.srcId < b.srcId) -1
else 1
private[graphx] def edgeArraySortDataFormat[ED] = {
new SortDataFormat[EdgeWithLocalIds[ED], Array[EdgeWithLocalIds[ED]]] {

View file

@ -19,7 +19,7 @@ package org.apache.spark.repl
import java.io.File
import java.net.{URI, URL, URLClassLoader}
import java.nio.channels.{FileChannel, ReadableByteChannel}
import java.nio.channels.FileChannel
import java.nio.charset.StandardCharsets
import java.nio.file.{Paths, StandardOpenOption}
import java.util
@ -33,7 +33,6 @@ import com.google.common.io.Files
import org.mockito.ArgumentMatchers.anyString
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.BeforeAndAfterAll
import org.scalatest.mockito.MockitoSugar
@ -191,12 +190,10 @@ class ExecutorClassLoaderSuite
val env = mock[SparkEnv]
val rpcEnv = mock[RpcEnv]
when(env.rpcEnv).thenReturn(rpcEnv)
when(rpcEnv.openChannel(anyString())).thenAnswer(new Answer[ReadableByteChannel]() {
override def answer(invocation: InvocationOnMock): ReadableByteChannel = {
val uri = new URI(invocation.getArguments()(0).asInstanceOf[String])
val path = Paths.get(tempDir1.getAbsolutePath(), uri.getPath().stripPrefix("/"))
FileChannel.open(path, StandardOpenOption.READ)
}
when(rpcEnv.openChannel(anyString())).thenAnswer((invocation: InvocationOnMock) => {
val uri = new URI(invocation.getArguments()(0).asInstanceOf[String])
val path = Paths.get(tempDir1.getAbsolutePath(), uri.getPath().stripPrefix("/"))
FileChannel.open(path, StandardOpenOption.READ)
})
val classLoader = new ExecutorClassLoader(new SparkConf(), env, "spark://localhost:1234",

View file

@ -52,9 +52,7 @@ class SingletonReplSuite extends SparkFunSuite {
Main.sparkSession = null
// Starts a new thread to run the REPL interpreter, so that we won't block.
thread = new Thread(new Runnable {
override def run(): Unit = Main.doMain(Array("-classpath", classpath), interp)
})
thread = new Thread(() => Main.doMain(Array("-classpath", classpath), interp))
thread.setDaemon(true)
thread.start()

View file

@ -47,9 +47,7 @@ private[k8s] class LoggingPodStatusWatcherImpl(
// start timer for periodic logging
private val scheduler =
ThreadUtils.newDaemonSingleThreadScheduledExecutor("logging-pod-status-watcher")
private val logRunnable: Runnable = new Runnable {
override def run() = logShortStatus()
}
private val logRunnable: Runnable = () => logShortStatus()
private var pod = Option.empty[Pod]

View file

@ -68,7 +68,7 @@ private[spark] class ExecutorPodsSnapshotsStoreImpl(subscribersExecutor: Schedul
}
subscribers += newSubscriber
pollingTasks += subscribersExecutor.scheduleWithFixedDelay(
toRunnable(() => callSubscriber(newSubscriber)),
() => callSubscriber(newSubscriber),
0L,
processBatchIntervalMillis,
TimeUnit.MILLISECONDS)
@ -103,10 +103,6 @@ private[spark] class ExecutorPodsSnapshotsStoreImpl(subscribersExecutor: Schedul
}
}
private def toRunnable[T](runnable: () => Unit): Runnable = new Runnable {
override def run(): Unit = runnable()
}
private case class SnapshotsSubscriber(
snapshotsBuffer: BlockingQueue[ExecutorPodsSnapshot],
onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit)

View file

@ -23,7 +23,6 @@ import io.fabric8.kubernetes.api.model.{Container, HasMetadata, PodBuilder, Secr
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.{mock, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.apache.spark.deploy.k8s.SparkPod
@ -38,16 +37,14 @@ object KubernetesFeaturesTestUtils {
when(mockStep.getAdditionalPodSystemProperties())
.thenReturn(Map(stepType -> stepType))
when(mockStep.configurePod(any(classOf[SparkPod])))
.thenAnswer(new Answer[SparkPod]() {
override def answer(invocation: InvocationOnMock): SparkPod = {
val originalPod: SparkPod = invocation.getArgument(0)
val configuredPod = new PodBuilder(originalPod.pod)
.editOrNewMetadata()
.addToLabels(stepType, stepType)
.endMetadata()
.build()
SparkPod(configuredPod, originalPod.container)
}
.thenAnswer((invocation: InvocationOnMock) => {
val originalPod: SparkPod = invocation.getArgument(0)
val configuredPod = new PodBuilder(originalPod.pod)
.editOrNewMetadata()
.addToLabels(stepType, stepType)
.endMetadata()
.build()
SparkPod(configuredPod, originalPod.container)
})
mockStep
}
@ -67,6 +64,6 @@ object KubernetesFeaturesTestUtils {
def filter[T: ClassTag](list: Seq[HasMetadata]): Seq[T] = {
val desired = implicitly[ClassTag[T]].runtimeClass
list.filter(_.getClass() == desired).map(_.asInstanceOf[T]).toSeq
list.filter(_.getClass() == desired).map(_.asInstanceOf[T])
}
}

View file

@ -19,7 +19,7 @@ package org.apache.spark.scheduler.cluster.k8s
import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder}
import io.fabric8.kubernetes.client.KubernetesClient
import io.fabric8.kubernetes.client.dsl.PodResource
import org.mockito.{ArgumentMatcher, Matchers, Mock, MockitoAnnotations}
import org.mockito.{Mock, MockitoAnnotations}
import org.mockito.ArgumentMatchers.{any, eq => meq}
import org.mockito.Mockito.{never, times, verify, when}
import org.mockito.invocation.InvocationOnMock
@ -27,7 +27,7 @@ import org.mockito.stubbing.Answer
import org.scalatest.BeforeAndAfter
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, KubernetesTestConf, SparkPod}
import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, SparkPod}
import org.apache.spark.deploy.k8s.Config._
import org.apache.spark.deploy.k8s.Constants._
import org.apache.spark.deploy.k8s.Fabric8Aliases._
@ -153,12 +153,9 @@ class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter {
verify(podOperations).create(podWithAttachedContainerForId(2))
}
private def executorPodAnswer(): Answer[SparkPod] = {
new Answer[SparkPod] {
override def answer(invocation: InvocationOnMock): SparkPod = {
val k8sConf: KubernetesExecutorConf = invocation.getArgument(0)
executorPodWithId(k8sConf.executorId.toInt)
}
}
private def executorPodAnswer(): Answer[SparkPod] =
(invocation: InvocationOnMock) => {
val k8sConf: KubernetesExecutorConf = invocation.getArgument(0)
executorPodWithId(k8sConf.executorId.toInt)
}
}

View file

@ -26,7 +26,6 @@ import org.mockito.Mockito.{mock, never, times, verify, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.BeforeAndAfter
import scala.collection.JavaConverters._
import scala.collection.mutable
import org.apache.spark.{SparkConf, SparkFunSuite}
@ -125,13 +124,10 @@ class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfte
""".stripMargin
}
private def namedPodsAnswer(): Answer[PodResource[Pod, DoneablePod]] = {
new Answer[PodResource[Pod, DoneablePod]] {
override def answer(invocation: InvocationOnMock): PodResource[Pod, DoneablePod] = {
val podName: String = invocation.getArgument(0)
namedExecutorPods.getOrElseUpdate(
podName, mock(classOf[PodResource[Pod, DoneablePod]]))
}
private def namedPodsAnswer(): Answer[PodResource[Pod, DoneablePod]] =
(invocation: InvocationOnMock) => {
val podName: String = invocation.getArgument(0)
namedExecutorPods.getOrElseUpdate(
podName, mock(classOf[PodResource[Pod, DoneablePod]]))
}
}
}

View file

@ -192,11 +192,9 @@ private[yarn] class YarnAllocator(
* A sequence of pending container requests at the given location that have not yet been
* fulfilled.
*/
private def getPendingAtLocation(location: String): Seq[ContainerRequest] = {
private def getPendingAtLocation(location: String): Seq[ContainerRequest] =
amClient.getMatchingRequests(RM_REQUEST_PRIORITY, location, resource).asScala
.flatMap(_.asScala)
.toSeq
}
/**
* Request as many executors from the ResourceManager as needed to reach the desired total. If
@ -384,7 +382,7 @@ private[yarn] class YarnAllocator(
def stop(): Unit = {
// Forcefully shut down the launcher pool, in case this is being called in the middle of
// container allocation. This will prevent queued executors from being started - and
// potentially interrupt active ExecutorRunnable instaces too.
// potentially interrupt active ExecutorRunnable instances too.
launcherPool.shutdownNow()
}
@ -467,7 +465,7 @@ private[yarn] class YarnAllocator(
remainingAfterOffRackMatches)
}
if (!remainingAfterOffRackMatches.isEmpty) {
if (remainingAfterOffRackMatches.nonEmpty) {
logDebug(s"Releasing ${remainingAfterOffRackMatches.size} unneeded containers that were " +
s"allocated to us")
for (container <- remainingAfterOffRackMatches) {
@ -550,35 +548,33 @@ private[yarn] class YarnAllocator(
if (runningExecutors.size() < targetNumExecutors) {
numExecutorsStarting.incrementAndGet()
if (launchContainers) {
launcherPool.execute(new Runnable {
override def run(): Unit = {
try {
new ExecutorRunnable(
Some(container),
conf,
sparkConf,
driverUrl,
executorId,
executorHostname,
executorMemory,
executorCores,
appAttemptId.getApplicationId.toString,
securityMgr,
localResources
).run()
updateInternalState()
} catch {
case e: Throwable =>
numExecutorsStarting.decrementAndGet()
if (NonFatal(e)) {
logError(s"Failed to launch executor $executorId on container $containerId", e)
// Assigned container should be released immediately
// to avoid unnecessary resource occupation.
amClient.releaseAssignedContainer(containerId)
} else {
throw e
}
}
launcherPool.execute(() => {
try {
new ExecutorRunnable(
Some(container),
conf,
sparkConf,
driverUrl,
executorId,
executorHostname,
executorMemory,
executorCores,
appAttemptId.getApplicationId.toString,
securityMgr,
localResources
).run()
updateInternalState()
} catch {
case e: Throwable =>
numExecutorsStarting.decrementAndGet()
if (NonFatal(e)) {
logError(s"Failed to launch executor $executorId on container $containerId", e)
// Assigned container should be released immediately
// to avoid unnecessary resource occupation.
amClient.releaseAssignedContainer(containerId)
} else {
throw e
}
}
})
} else {
@ -776,7 +772,7 @@ private[yarn] class YarnAllocator(
}
}
(localityMatched.toSeq, localityUnMatched.toSeq, localityFree.toSeq)
(localityMatched, localityUnMatched, localityFree)
}
}

View file

@ -30,7 +30,6 @@ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@ -273,7 +272,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
if (children.length == 0) {
if (children.isEmpty) {
emptyInputGenCode(ev)
} else {
nonEmptyInputGenCode(ctx, ev)
@ -718,17 +717,15 @@ trait ArraySortLike extends ExpectsInputTypes {
case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
}
new Comparator[Any]() {
override def compare(o1: Any, o2: Any): Int = {
if (o1 == null && o2 == null) {
0
} else if (o1 == null) {
nullOrder
} else if (o2 == null) {
-nullOrder
} else {
ordering.compare(o1, o2)
}
(o1: Any, o2: Any) => {
if (o1 == null && o2 == null) {
0
} else if (o1 == null) {
nullOrder
} else if (o2 == null) {
-nullOrder
} else {
ordering.compare(o1, o2)
}
}
}
@ -740,17 +737,15 @@ trait ArraySortLike extends ExpectsInputTypes {
case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
}
new Comparator[Any]() {
override def compare(o1: Any, o2: Any): Int = {
if (o1 == null && o2 == null) {
0
} else if (o1 == null) {
-nullOrder
} else if (o2 == null) {
nullOrder
} else {
ordering.compare(o2, o1)
}
(o1: Any, o2: Any) => {
if (o1 == null && o2 == null) {
0
} else if (o1 == null) {
-nullOrder
} else if (o2 == null) {
nullOrder
} else {
ordering.compare(o2, o1)
}
}
}
@ -769,7 +764,6 @@ trait ArraySortLike extends ExpectsInputTypes {
}
def sortCodegen(ctx: CodegenContext, ev: ExprCode, base: String, order: String): String = {
val arrayData = classOf[ArrayData].getName
val genericArrayData = classOf[GenericArrayData].getName
val unsafeArrayData = classOf[UnsafeArrayData].getName
val array = ctx.freshName("array")
@ -2784,7 +2778,7 @@ case class ArrayRepeat(left: Expression, right: Expression)
} else {
if (count.asInstanceOf[Int] > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw new RuntimeException(s"Unsuccessful try to create array with $count elements " +
s"due to exceeding the array size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
s"due to exceeding the array size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
}
val element = left.eval(input)
new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element))

View file

@ -35,12 +35,8 @@ package object util extends Logging {
val origErr = System.err
val origOut = System.out
try {
System.setErr(new PrintStream(new OutputStream {
def write(b: Int) = {}
}))
System.setOut(new PrintStream(new OutputStream {
def write(b: Int) = {}
}))
System.setErr(new PrintStream((_: Int) => {}))
System.setOut(new PrintStream((_: Int) => {}))
f
} finally {

View file

@ -17,7 +17,6 @@
package org.apache.spark.sql.types
import scala.math.Ordering
import scala.reflect.runtime.universe.typeTag
import org.apache.spark.annotation.Stable
@ -37,11 +36,8 @@ class BinaryType private() extends AtomicType {
@transient private[sql] lazy val tag = typeTag[InternalType]
private[sql] val ordering = new Ordering[InternalType] {
def compare(x: Array[Byte], y: Array[Byte]): Int = {
TypeUtils.compareBinary(x, y)
}
}
private[sql] val ordering =
(x: Array[Byte], y: Array[Byte]) => TypeUtils.compareBinary(x, y)
/**
* The default size of a value of the BinaryType is 100 bytes.

View file

@ -17,7 +17,7 @@
package org.apache.spark.sql.types
import scala.math.{Fractional, Numeric, Ordering}
import scala.math.{Fractional, Numeric}
import scala.math.Numeric.DoubleAsIfIntegral
import scala.reflect.runtime.universe.typeTag
@ -38,9 +38,8 @@ class DoubleType private() extends FractionalType {
@transient private[sql] lazy val tag = typeTag[InternalType]
private[sql] val numeric = implicitly[Numeric[Double]]
private[sql] val fractional = implicitly[Fractional[Double]]
private[sql] val ordering = new Ordering[Double] {
override def compare(x: Double, y: Double): Int = Utils.nanSafeCompareDoubles(x, y)
}
private[sql] val ordering =
(x: Double, y: Double) => Utils.nanSafeCompareDoubles(x, y)
private[sql] val asIntegral = DoubleAsIfIntegral
/**

View file

@ -17,7 +17,7 @@
package org.apache.spark.sql.types
import scala.math.{Fractional, Numeric, Ordering}
import scala.math.{Fractional, Numeric}
import scala.math.Numeric.FloatAsIfIntegral
import scala.reflect.runtime.universe.typeTag
@ -38,9 +38,8 @@ class FloatType private() extends FractionalType {
@transient private[sql] lazy val tag = typeTag[InternalType]
private[sql] val numeric = implicitly[Numeric[Float]]
private[sql] val fractional = implicitly[Fractional[Float]]
private[sql] val ordering = new Ordering[Float] {
override def compare(x: Float, y: Float): Int = Utils.nanSafeCompareFloats(x, y)
}
private[sql] val ordering =
(x: Float, y: Float) => Utils.nanSafeCompareFloats(x, y)
private[sql] val asIntegral = FloatAsIfIntegral
/**

View file

@ -38,11 +38,7 @@ class ExternalCatalogEventSuite extends SparkFunSuite {
f: (ExternalCatalog, Seq[ExternalCatalogEvent] => Unit) => Unit): Unit = test(name) {
val catalog = new ExternalCatalogWithListener(newCatalog)
val recorder = mutable.Buffer.empty[ExternalCatalogEvent]
catalog.addListener(new ExternalCatalogEventListener {
override def onEvent(event: ExternalCatalogEvent): Unit = {
recorder += event
}
})
catalog.addListener((event: ExternalCatalogEvent) => recorder += event)
f(catalog, (expected: Seq[ExternalCatalogEvent]) => {
val actual = recorder.clone()
recorder.clear()
@ -174,9 +170,6 @@ class ExternalCatalogEventSuite extends SparkFunSuite {
className = "",
resources = Seq.empty)
val newIdentifier = functionDefinition.identifier.copy(funcName = "fn4")
val renamedFunctionDefinition = functionDefinition.copy(identifier = newIdentifier)
catalog.createDatabase(dbDefinition, ignoreIfExists = false)
checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil)

View file

@ -112,9 +112,7 @@ object SortPrefixUtils {
val field = schema.head
getPrefixComparator(SortOrder(BoundReference(0, field.dataType, field.nullable), Ascending))
} else {
new PrefixComparator {
override def compare(prefix1: Long, prefix2: Long): Int = 0
}
(_: Long, _: Long) => 0
}
}
@ -164,12 +162,7 @@ object SortPrefixUtils {
}
}
} else {
new UnsafeExternalRowSorter.PrefixComputer {
override def computePrefix(row: InternalRow):
UnsafeExternalRowSorter.PrefixComputer.Prefix = {
emptyPrefix
}
}
_: InternalRow => emptyPrefix
}
}
}

View file

@ -30,7 +30,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Resolver}
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
@ -72,7 +72,7 @@ case class CreateDatabaseCommand(
CatalogDatabase(
databaseName,
comment.getOrElse(""),
path.map(CatalogUtils.stringToURI(_)).getOrElse(catalog.getDefaultDBPath(databaseName)),
path.map(CatalogUtils.stringToURI).getOrElse(catalog.getDefaultDBPath(databaseName)),
props),
ifNotExists)
Seq.empty[Row]
@ -352,9 +352,8 @@ case class AlterTableChangeColumnCommand(
}
// Add the comment to a column, if comment is empty, return the original column.
private def addComment(column: StructField, comment: Option[String]): StructField = {
comment.map(column.withComment(_)).getOrElse(column)
}
private def addComment(column: StructField, comment: Option[String]): StructField =
comment.map(column.withComment).getOrElse(column)
// Compare a [[StructField]] to another, return true if they have the same column
// name(by resolver) and dataType.
@ -584,14 +583,12 @@ case class AlterTableRecoverPartitionsCommand(
// It's very expensive to create a JobConf(ClassUtil.findContainingJar() is slow)
val jobConf = new JobConf(hadoopConf, this.getClass)
val pathFilter = FileInputFormat.getInputPathFilter(jobConf)
new PathFilter {
override def accept(path: Path): Boolean = {
val name = path.getName
if (name != "_SUCCESS" && name != "_temporary" && !name.startsWith(".")) {
pathFilter == null || pathFilter.accept(path)
} else {
false
}
path: Path => {
val name = path.getName
if (name != "_SUCCESS" && name != "_temporary" && !name.startsWith(".")) {
pathFilter == null || pathFilter.accept(path)
} else {
false
}
}
}
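
Hadoop's `PathFilter` declares only `accept(Path)`, which is what allows the anonymous filter to become a lambda here. A minimal sketch, assuming `hadoop-common` is on the classpath; binding the literal to a `PathFilter`-typed val keeps the overloaded `listStatus` call unambiguous:

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path, PathFilter}

object PathFilterAsLambda {
  def main(args: Array[String]): Unit = {
    val fs = FileSystem.getLocal(new Configuration())
    // The (p: Path) => Boolean literal is adapted to PathFilter, whose only method is accept.
    val visibleOnly: PathFilter = (p: Path) => !p.getName.startsWith(".")
    fs.listStatus(new Path("/tmp"), visibleOnly).foreach(status => println(status.getPath))
  }
}
```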

View file

@ -18,7 +18,6 @@
package org.apache.spark.sql.execution.datasources
import java.util.Locale
import java.util.concurrent.Callable
import org.apache.hadoop.fs.Path
@ -222,23 +221,20 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
private def readDataSourceTable(table: CatalogTable): LogicalPlan = {
val qualifiedTableName = QualifiedTableName(table.database, table.identifier.table)
val catalog = sparkSession.sessionState.catalog
catalog.getCachedPlan(qualifiedTableName, new Callable[LogicalPlan]() {
override def call(): LogicalPlan = {
val pathOption = table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_))
val dataSource =
DataSource(
sparkSession,
// In older version(prior to 2.1) of Spark, the table schema can be empty and should be
// inferred at runtime. We should still support it.
userSpecifiedSchema = if (table.schema.isEmpty) None else Some(table.schema),
partitionColumns = table.partitionColumnNames,
bucketSpec = table.bucketSpec,
className = table.provider.get,
options = table.storage.properties ++ pathOption,
catalogTable = Some(table))
LogicalRelation(dataSource.resolveRelation(checkFilesExist = false), table)
}
catalog.getCachedPlan(qualifiedTableName, () => {
val pathOption = table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_))
val dataSource =
DataSource(
sparkSession,
// In older version(prior to 2.1) of Spark, the table schema can be empty and should be
// inferred at runtime. We should still support it.
userSpecifiedSchema = if (table.schema.isEmpty) None else Some(table.schema),
partitionColumns = table.partitionColumnNames,
bucketSpec = table.bucketSpec,
className = table.provider.get,
options = table.storage.properties ++ pathOption,
catalogTable = Some(table))
LogicalRelation(dataSource.resolveRelation(checkFilesExist = false), table)
})
}
@ -484,8 +480,8 @@ object DataSourceStrategy {
// Because we only convert In to InSet in Optimizer when there are more than certain
// items. So it is possible we still get an In expression here that needs to be pushed
// down.
case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) =>
val hSet = list.map(e => e.eval(EmptyRow))
case expressions.In(a: Attribute, list) if list.forall(_.isInstanceOf[Literal]) =>
val hSet = list.map(_.eval(EmptyRow))
val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
Some(sources.In(a.name, hSet.toArray.map(toScala)))

View file

@ -69,7 +69,7 @@ trait CheckpointFileManager {
/** List all the files in a path. */
def list(path: Path): Array[FileStatus] = {
list(path, new PathFilter { override def accept(path: Path): Boolean = true })
list(path, (_: Path) => true)
}
/** Make directory at the give path and all its parent directories as needed. */
@ -103,7 +103,7 @@ object CheckpointFileManager extends Logging {
* @param overwriteIfPossible If true, then the implementations must do a best-effort attempt to
* overwrite the file if it already exists. It should not throw
* any exception if the file exists. However, if false, then the
* implementation must not overwrite if the file alraedy exists and
* implementation must not overwrite if the file already exists and
* must throw `FileAlreadyExistsException` in that case.
*/
def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit
@ -236,14 +236,12 @@ class FileSystemBasedCheckpointFileManager(path: Path, hadoopConf: Configuration
fs.open(path)
}
override def exists(path: Path): Boolean = {
try
return fs.getFileStatus(path) != null
catch {
case e: FileNotFoundException =>
return false
override def exists(path: Path): Boolean =
try {
fs.getFileStatus(path) != null
} catch {
case _: FileNotFoundException => false
}
}
override def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit = {
if (!overwriteIfPossible && fs.exists(dstPath)) {

View file

@ -23,7 +23,7 @@ import java.nio.charset.StandardCharsets.UTF_8
import scala.io.{Source => IOSource}
import scala.reflect.ClassTag
import org.apache.hadoop.fs.{Path, PathFilter}
import org.apache.hadoop.fs.Path
import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization
@ -169,13 +169,13 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag](
*/
private def compact(batchId: Long, logs: Array[T]): Boolean = {
val validBatches = getValidBatchesBeforeCompactionBatch(batchId, compactInterval)
val allLogs = validBatches.map { id =>
val allLogs = validBatches.flatMap { id =>
super.get(id).getOrElse {
throw new IllegalStateException(
s"${batchIdToPath(id)} doesn't exist when compacting batch $batchId " +
s"(compactInterval: $compactInterval)")
}
}.flatten ++ logs
} ++ logs
// Return false as there is another writer.
super.add(batchId, compactLogs(allLogs).toArray)
}
@ -192,13 +192,13 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag](
if (latestId >= 0) {
try {
val logs =
getAllValidBatches(latestId, compactInterval).map { id =>
getAllValidBatches(latestId, compactInterval).flatMap { id =>
super.get(id).getOrElse {
throw new IllegalStateException(
s"${batchIdToPath(id)} doesn't exist " +
s"(latestId: $latestId, compactInterval: $compactInterval)")
}
}.flatten
}
return compactLogs(logs).toArray
} catch {
case e: IOException =>
@ -240,15 +240,13 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag](
s"min compaction batch id to delete = $minCompactionBatchId")
val expiredTime = System.currentTimeMillis() - fileCleanupDelayMs
fileManager.list(metadataPath, new PathFilter {
override def accept(path: Path): Boolean = {
try {
val batchId = getBatchIdFromFileName(path.getName)
batchId < minCompactionBatchId
} catch {
case _: NumberFormatException =>
false
}
fileManager.list(metadataPath, (path: Path) => {
try {
val batchId = getBatchIdFromFileName(path.getName)
batchId < minCompactionBatchId
} catch {
case _: NumberFormatException =>
false
}
}).foreach { f =>
if (f.getModificationTime <= expiredTime) {

View file

@ -89,19 +89,15 @@ class RateStreamTable(
override def capabilities(): util.Set[TableCapability] = Collections.emptySet()
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder {
override def build(): Scan = new Scan {
override def readSchema(): StructType = RateStreamProvider.SCHEMA
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = () => new Scan {
override def readSchema(): StructType = RateStreamProvider.SCHEMA
override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
new RateStreamMicroBatchStream(
rowsPerSecond, rampUpTimeSeconds, numPartitions, options, checkpointLocation)
}
override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream =
new RateStreamMicroBatchStream(
rowsPerSecond, rampUpTimeSeconds, numPartitions, options, checkpointLocation)
override def toContinuousStream(checkpointLocation: String): ContinuousStream = {
new RateStreamContinuousStream(rowsPerSecond, numPartitions)
}
}
override def toContinuousStream(checkpointLocation: String): ContinuousStream =
new RateStreamContinuousStream(rowsPerSecond, numPartitions)
}
}
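
The `() => new ...Stream(...)` rewrites in the DSv2 sources work because `ScanBuilder` exposes a single `build()` method. The same conversion applies to any user-defined trait with exactly one abstract method; a simplified stand-in (these `Scan`/`ScanBuilder` traits are illustrative, not the DSv2 API):

```scala
object SingleAbstractMethodTrait {
  // Illustrative stand-ins for a builder interface with exactly one abstract method.
  trait Scan { def description: String }
  trait ScanBuilder { def build(): Scan }

  // The () => ... literal is adapted to ScanBuilder, so no anonymous class is needed.
  def newScanBuilder(table: String): ScanBuilder = () => new Scan {
    override def description: String = s"scan over $table"
  }

  def main(args: Array[String]): Unit = {
    println(newScanBuilder("rate").build().description) // scan over rate
  }
}
```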

View file

@ -130,27 +130,24 @@ class TextSocketMicroBatchStream(host: String, port: Int, numPartitions: Int)
slices.map(TextSocketInputPartition)
}
override def createReaderFactory(): PartitionReaderFactory = {
new PartitionReaderFactory {
override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
val slice = partition.asInstanceOf[TextSocketInputPartition].slice
new PartitionReader[InternalRow] {
private var currentIdx = -1
override def createReaderFactory(): PartitionReaderFactory =
(partition: InputPartition) => {
val slice = partition.asInstanceOf[TextSocketInputPartition].slice
new PartitionReader[InternalRow] {
private var currentIdx = -1
override def next(): Boolean = {
currentIdx += 1
currentIdx < slice.size
}
override def get(): InternalRow = {
InternalRow(slice(currentIdx)._1, slice(currentIdx)._2)
}
override def close(): Unit = {}
override def next(): Boolean = {
currentIdx += 1
currentIdx < slice.size
}
override def get(): InternalRow = {
InternalRow(slice(currentIdx)._1, slice(currentIdx)._2)
}
override def close(): Unit = {}
}
}
}
override def commit(end: Offset): Unit = synchronized {
val newOffset = LongOffset.convert(end).getOrElse(

View file

@ -81,17 +81,15 @@ class TextSocketTable(host: String, port: Int, numPartitions: Int, includeTimest
override def capabilities(): util.Set[TableCapability] = Collections.emptySet()
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder {
override def build(): Scan = new Scan {
override def readSchema(): StructType = schema()
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = () => new Scan {
override def readSchema(): StructType = schema()
override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
new TextSocketMicroBatchStream(host, port, numPartitions)
}
override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
new TextSocketMicroBatchStream(host, port, numPartitions)
}
override def toContinuousStream(checkpointLocation: String): ContinuousStream = {
new TextSocketContinuousStream(host, port, numPartitions, options)
}
override def toContinuousStream(checkpointLocation: String): ContinuousStream = {
new TextSocketContinuousStream(host, port, numPartitions, options)
}
}
}

View file

@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.ui
import java.util.{Date, NoSuchElementException}
import java.util.concurrent.ConcurrentHashMap
import java.util.function.Function
import scala.collection.JavaConverters._
@ -196,7 +195,7 @@ class SQLAppStatusListener(
// Check the execution again for whether the aggregated metrics data has been calculated.
// This can happen if the UI is requesting this data, and the onExecutionEnd handler is
// running at the same time. The metrics calculated for the UI can be innacurate in that
// running at the same time. The metrics calculated for the UI can be inaccurate in that
// case, since the onExecutionEnd handler will clean up tracked stage metrics.
if (exec.metricsValues != null) {
exec.metricsValues
@ -328,9 +327,7 @@ class SQLAppStatusListener(
private def getOrCreateExecution(executionId: Long): LiveExecutionData = {
liveExecutions.computeIfAbsent(executionId,
new Function[Long, LiveExecutionData]() {
override def apply(key: Long): LiveExecutionData = new LiveExecutionData(executionId)
})
(_: Long) => new LiveExecutionData(executionId))
}
private def update(exec: LiveExecutionData, force: Boolean = false): Unit = {
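
`ConcurrentHashMap.computeIfAbsent` takes a `java.util.function.Function`, another SAM type, so the explicit `new Function[...]` wrapper is unnecessary in 2.12. A self-contained sketch:

```scala
import java.util.concurrent.ConcurrentHashMap

object ComputeIfAbsentWithLambda {
  def main(args: Array[String]): Unit = {
    val cache = new ConcurrentHashMap[String, String]()
    // The (key: String) => ... literal is adapted to java.util.function.Function[String, String].
    val value = cache.computeIfAbsent("execution-1", (key: String) => key.toUpperCase)
    println(value) // EXECUTION-1
  }
}
```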

View file

@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.CacheManager
import org.apache.spark.sql.execution.ui.{SQLAppStatusListener, SQLAppStatusStore, SQLTab}
import org.apache.spark.sql.internal.StaticSQLConf._
import org.apache.spark.status.ElementTrackingStore
import org.apache.spark.util.{MutableURLClassLoader, Utils}
import org.apache.spark.util.Utils
/**
@ -146,11 +146,7 @@ private[sql] class SharedState(
val wrapped = new ExternalCatalogWithListener(externalCatalog)
// Make sure we propagate external catalog events to the spark listener bus
wrapped.addListener(new ExternalCatalogEventListener {
override def onEvent(event: ExternalCatalogEvent): Unit = {
sparkContext.listenerBus.post(event)
}
})
wrapped.addListener((event: ExternalCatalogEvent) => sparkContext.listenerBus.post(event))
wrapped
}

View file

@ -1195,11 +1195,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
GroupedRoutes("a", "c", Seq(Route("a", "c", 2)))
)
implicit def ordering[GroupedRoutes]: Ordering[GroupedRoutes] = new Ordering[GroupedRoutes] {
override def compare(x: GroupedRoutes, y: GroupedRoutes): Int = {
x.toString.compareTo(y.toString)
}
}
implicit def ordering[GroupedRoutes]: Ordering[GroupedRoutes] =
(x: GroupedRoutes, y: GroupedRoutes) => x.toString.compareTo(y.toString)
checkDatasetUnorderly(grped, expected: _*)
}
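
`scala.math.Ordering` qualifies as well: `compare` is its only abstract method, so an implicit ordering can be written as a two-argument function literal. A minimal sketch:

```scala
object OrderingAsLambda {
  // compare is Ordering's single abstract method, so the literal is adapted to Ordering[String].
  implicit val byLength: Ordering[String] =
    (x: String, y: String) => Integer.compare(x.length, y.length)

  def main(args: Array[String]): Unit = {
    println(List("banana", "fig", "kiwi").sorted) // List(fig, kiwi, banana)
  }
}
```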

View file

@ -120,10 +120,7 @@ class QueryExecutionSuite extends SharedSQLContext {
}
test("toString() exception/error handling") {
spark.experimental.extraStrategies = Seq(
new SparkStrategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = Nil
})
spark.experimental.extraStrategies = Seq[SparkStrategy]((_: LogicalPlan) => Nil)
def qe: QueryExecution = new QueryExecution(spark, OneRowRelation())
@ -131,19 +128,13 @@ class QueryExecutionSuite extends SharedSQLContext {
assert(qe.toString.contains("OneRowRelation"))
// Throw an AnalysisException - this should be captured.
spark.experimental.extraStrategies = Seq(
new SparkStrategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] =
throw new AnalysisException("exception")
})
spark.experimental.extraStrategies = Seq[SparkStrategy](
(_: LogicalPlan) => throw new AnalysisException("exception"))
assert(qe.toString.contains("org.apache.spark.sql.AnalysisException"))
// Throw an Error - this should not be captured.
spark.experimental.extraStrategies = Seq(
new SparkStrategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] =
throw new Error("error")
})
spark.experimental.extraStrategies = Seq[SparkStrategy](
(_: LogicalPlan) => throw new Error("error"))
val error = intercept[Error](qe.toString)
assert(error.getMessage.contains("error"))
}

View file

@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.benchmark
import java.util.{Arrays, Comparator}
import java.util.Arrays
import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
import org.apache.spark.unsafe.array.LongArray
@ -40,14 +40,9 @@ object SortBenchmark extends BenchmarkBase {
private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) {
val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt)))
new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort(
buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] {
override def compare(
r1: RecordPointerAndKeyPrefix,
r2: RecordPointerAndKeyPrefix): Int = {
refCmp.compare(r1.keyPrefix, r2.keyPrefix)
}
})
new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort(buf, lo, hi,
(r1: RecordPointerAndKeyPrefix, r2: RecordPointerAndKeyPrefix) =>
refCmp.compare(r1.keyPrefix, r2.keyPrefix))
}
private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = {

View file

@ -354,11 +354,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
val listener = new StreamingQueryListener {
override def onQueryStarted(event: QueryStartedEvent): Unit = {
// Note: this assumes there is only one query active in the `testStream` method.
Thread.currentThread.setUncaughtExceptionHandler(new UncaughtExceptionHandler {
override def uncaughtException(t: Thread, e: Throwable): Unit = {
streamThreadDeathCause = e
}
})
Thread.currentThread.setUncaughtExceptionHandler(
(_: Thread, e: Throwable) => streamThreadDeathCause = e)
}
override def onQueryProgress(event: QueryProgressEvent): Unit = {}
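
`Thread.UncaughtExceptionHandler` is another JDK SAM interface, which is why the handler above shrinks to `(_: Thread, e: Throwable) => ...`. A runnable sketch of the same capture pattern:

```scala
object UncaughtHandlerAsLambda {
  def main(args: Array[String]): Unit = {
    var cause: Throwable = null
    val worker = new Thread(() => throw new IllegalStateException("stream died"))
    // The two-argument literal is adapted to Thread.UncaughtExceptionHandler.
    worker.setUncaughtExceptionHandler((_: Thread, e: Throwable) => cause = e)
    worker.start()
    worker.join()
    println(s"captured: ${cause.getMessage}") // captured: stream died
  }
}
```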

View file

@ -33,7 +33,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
import org.apache.spark.sql.streaming.{ProcessingTime => DeprecatedProcessingTime, _}
import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery, StreamingQueryException, StreamTest}
import org.apache.spark.sql.streaming.Trigger._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@ -104,9 +104,7 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider {
LastOptions.parameters = parameters
LastOptions.partitionColumns = partitionColumns
LastOptions.mockStreamSinkProvider.createSink(spark, parameters, partitionColumns, outputMode)
new Sink {
override def addBatch(batchId: Long, data: DataFrame): Unit = {}
}
(_: Long, _: DataFrame) => {}
}
}

View file

@ -60,11 +60,7 @@ class BlockingSource extends StreamSourceProvider with StreamSinkProvider {
spark: SQLContext,
parameters: Map[String, String],
partitionColumns: Seq[String],
outputMode: OutputMode): Sink = {
new Sink {
override def addBatch(batchId: Long, data: DataFrame): Unit = {}
}
}
outputMode: OutputMode): Sink = (_: Long, _: DataFrame) => {}
}
object BlockingSource {

View file

@ -29,7 +29,7 @@ import org.apache.commons.lang3.StringUtils
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor}
import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils}
import org.apache.hadoop.hive.common.HiveInterruptUtils
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.ql.Driver
import org.apache.hadoop.hive.ql.exec.Utilities
@ -65,16 +65,14 @@ private[hive] object SparkSQLCLIDriver extends Logging {
* a command is being processed by the current thread.
*/
def installSignalHandler() {
HiveInterruptUtils.add(new HiveInterruptCallback {
override def interrupt() {
// Handle remote execution mode
if (SparkSQLEnv.sparkContext != null) {
SparkSQLEnv.sparkContext.cancelAllJobs()
} else {
if (transport != null) {
// Force closing of TCP connection upon session termination
transport.getSocket.close()
}
HiveInterruptUtils.add(() => {
// Handle remote execution mode
if (SparkSQLEnv.sparkContext != null) {
SparkSQLEnv.sparkContext.cancelAllJobs()
} else {
if (transport != null) {
// Force closing of TCP connection upon session termination
transport.getSocket.close()
}
}
})
@ -208,7 +206,7 @@ private[hive] object SparkSQLCLIDriver extends Logging {
reader.setBellEnabled(false)
reader.setExpandEvents(false)
// reader.setDebug(new PrintWriter(new FileWriter("writer.debug", true)))
CliDriver.getCommandCompleter.foreach((e) => reader.addCompleter(e))
CliDriver.getCommandCompleter.foreach(reader.addCompleter)
val historyDirectory = System.getProperty("user.home")

View file

@ -255,9 +255,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
}
}
def numReceivers(): Int = {
receiverInputStreams.size
}
def numReceivers(): Int = receiverInputStreams.length
/** Register a receiver */
private def registerReceiver(
@ -516,14 +514,12 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
context.reply(successful)
case AddBlock(receivedBlockInfo) =>
if (WriteAheadLogUtils.isBatchingEnabled(ssc.conf, isDriver = true)) {
walBatchingThreadPool.execute(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
if (active) {
context.reply(addBlock(receivedBlockInfo))
} else {
context.sendFailure(
new IllegalStateException("ReceiverTracker RpcEndpoint already shut down."))
}
walBatchingThreadPool.execute(() => Utils.tryLogNonFatalError {
if (active) {
context.reply(addBlock(receivedBlockInfo))
} else {
context.sendFailure(
new IllegalStateException("ReceiverTracker RpcEndpoint already shut down."))
}
})
} else {

View file

@ -135,18 +135,16 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp
/** Start the actual log writer on a separate thread. */
private def startBatchedWriterThread(): Thread = {
val thread = new Thread(new Runnable {
override def run(): Unit = {
while (active.get()) {
try {
flushRecords()
} catch {
case NonFatal(e) =>
logWarning("Encountered exception in Batched Writer Thread.", e)
}
val thread = new Thread(() => {
while (active.get()) {
try {
flushRecords()
} catch {
case NonFatal(e) =>
logWarning("Encountered exception in Batched Writer Thread.", e)
}
logInfo("BatchedWriteAheadLog Writer thread exiting.")
}
logInfo("BatchedWriteAheadLog Writer thread exiting.")
}, "BatchedWriteAheadLog Writer")
thread.setDaemon(true)
thread.start()
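
For reference, the `new Thread(() => ..., "BatchedWriteAheadLog Writer")` form relies on the `Thread(Runnable, String)` constructor plus SAM conversion. A stripped-down sketch of the same daemon writer-thread shape (the queue and names are illustrative):

```scala
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.atomic.AtomicBoolean

object DaemonWriterThread {
  def main(args: Array[String]): Unit = {
    val active = new AtomicBoolean(true)
    val queue = new LinkedBlockingQueue[String]()
    // The block literal is adapted to Runnable by the Thread(Runnable, String) constructor.
    val writer = new Thread(() => {
      while (active.get()) {
        val record = queue.take()
        println(s"flushed: $record")
      }
    }, "batched-writer")
    writer.setDaemon(true)
    writer.start()

    queue.put("record-1")
    Thread.sleep(100)   // give the daemon a moment to drain the queue
    active.set(false)   // JVM exit stops the daemon thread regardless
  }
}
```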