[SPARK-27323][CORE][SQL][STREAMING] Use Single-Abstract-Method support in Scala 2.12 to simplify code
## What changes were proposed in this pull request? Use Single Abstract Method syntax where possible (and minor related cleanup). Comments below. No logic should change here. ## How was this patch tested? Existing tests. Closes #24241 from srowen/SPARK-27323. Authored-by: Sean Owen <sean.owen@databricks.com> Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
This commit is contained in:
parent
d575a453db
commit
d4420b455a
|
@ -19,7 +19,7 @@ package org.apache.spark
|
|||
|
||||
import java.util.{Timer, TimerTask}
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.function.{Consumer, Function}
|
||||
import java.util.function.Consumer
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
|
@ -202,10 +202,8 @@ private[spark] class BarrierCoordinator(
|
|||
case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) =>
|
||||
// Get or init the ContextBarrierState correspond to the stage attempt.
|
||||
val barrierId = ContextBarrierId(stageId, stageAttemptId)
|
||||
states.computeIfAbsent(barrierId, new Function[ContextBarrierId, ContextBarrierState] {
|
||||
override def apply(key: ContextBarrierId): ContextBarrierState =
|
||||
new ContextBarrierState(key, numTasks)
|
||||
})
|
||||
states.computeIfAbsent(barrierId,
|
||||
(key: ContextBarrierId) => new ContextBarrierState(key, numTasks))
|
||||
val barrierState = states.get(barrierId)
|
||||
|
||||
barrierState.handleRequest(context, request)
|
||||
|
|
|
@ -123,9 +123,8 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
|
|||
cleaningThread.setDaemon(true)
|
||||
cleaningThread.setName("Spark Context Cleaner")
|
||||
cleaningThread.start()
|
||||
periodicGCService.scheduleAtFixedRate(new Runnable {
|
||||
override def run(): Unit = System.gc()
|
||||
}, periodicGCInterval, periodicGCInterval, TimeUnit.SECONDS)
|
||||
periodicGCService.scheduleAtFixedRate(() => System.gc(),
|
||||
periodicGCInterval, periodicGCInterval, TimeUnit.SECONDS)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -98,11 +98,9 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock)
|
|||
private val killExecutorThread = ThreadUtils.newDaemonSingleThreadExecutor("kill-executor-thread")
|
||||
|
||||
override def onStart(): Unit = {
|
||||
timeoutCheckingTask = eventLoopThread.scheduleAtFixedRate(new Runnable {
|
||||
override def run(): Unit = Utils.tryLogNonFatalError {
|
||||
Option(self).foreach(_.ask[Boolean](ExpireDeadHosts))
|
||||
}
|
||||
}, 0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS)
|
||||
timeoutCheckingTask = eventLoopThread.scheduleAtFixedRate(
|
||||
() => Utils.tryLogNonFatalError { Option(self).foreach(_.ask[Boolean](ExpireDeadHosts)) },
|
||||
0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS)
|
||||
}
|
||||
|
||||
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
|
||||
|
|
|
@ -62,9 +62,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria
|
|||
|
||||
@transient private lazy val reader: ConfigReader = {
|
||||
val _reader = new ConfigReader(new SparkConfigProvider(settings))
|
||||
_reader.bindEnv(new ConfigProvider {
|
||||
override def get(key: String): Option[String] = Option(getenv(key))
|
||||
})
|
||||
_reader.bindEnv((key: String) => Option(getenv(key)))
|
||||
_reader
|
||||
}
|
||||
|
||||
|
@ -392,7 +390,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria
|
|||
|
||||
/** Get an optional value, applying variable substitution. */
|
||||
private[spark] def getWithSubstitution(key: String): Option[String] = {
|
||||
getOption(key).map(reader.substitute(_))
|
||||
getOption(key).map(reader.substitute)
|
||||
}
|
||||
|
||||
/** Get all parameters as a list of pairs */
|
||||
|
|
|
@ -60,11 +60,7 @@ object PythonRunner {
|
|||
.javaAddress(localhost)
|
||||
.callbackClient(py4j.GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
|
||||
.build()
|
||||
val thread = new Thread(new Runnable() {
|
||||
override def run(): Unit = Utils.logUncaughtExceptions {
|
||||
gatewayServer.start()
|
||||
}
|
||||
})
|
||||
val thread = new Thread(() => Utils.logUncaughtExceptions { gatewayServer.start() })
|
||||
thread.setName("py4j-gateway-init")
|
||||
thread.setDaemon(true)
|
||||
thread.start()
|
||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.deploy
|
|||
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, File, IOException}
|
||||
import java.security.PrivilegedExceptionAction
|
||||
import java.text.DateFormat
|
||||
import java.util.{Arrays, Comparator, Date, Locale}
|
||||
import java.util.{Arrays, Date, Locale}
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.immutable.Map
|
||||
|
@ -270,11 +270,8 @@ private[spark] class SparkHadoopUtil extends Logging {
|
|||
name.startsWith(prefix) && !name.endsWith(exclusionSuffix)
|
||||
}
|
||||
})
|
||||
Arrays.sort(fileStatuses, new Comparator[FileStatus] {
|
||||
override def compare(o1: FileStatus, o2: FileStatus): Int = {
|
||||
Longs.compare(o1.getModificationTime, o2.getModificationTime)
|
||||
}
|
||||
})
|
||||
Arrays.sort(fileStatuses, (o1: FileStatus, o2: FileStatus) =>
|
||||
Longs.compare(o1.getModificationTime, o2.getModificationTime))
|
||||
fileStatuses
|
||||
} catch {
|
||||
case NonFatal(e) =>
|
||||
|
@ -465,7 +462,7 @@ private[spark] object SparkHadoopUtil {
|
|||
// scalastyle:on line.size.limit
|
||||
def createNonECFile(fs: FileSystem, path: Path): FSDataOutputStream = {
|
||||
try {
|
||||
// Use reflection as this uses apis only avialable in hadoop 3
|
||||
// Use reflection as this uses APIs only available in Hadoop 3
|
||||
val builderMethod = fs.getClass().getMethod("createFile", classOf[Path])
|
||||
// the builder api does not resolve relative paths, nor does it create parent dirs, while
|
||||
// the old api does.
|
||||
|
|
|
@ -186,13 +186,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
|
|||
* Return a runnable that performs the given operation on the event logs.
|
||||
* This operation is expected to be executed periodically.
|
||||
*/
|
||||
private def getRunner(operateFun: () => Unit): Runnable = {
|
||||
new Runnable() {
|
||||
override def run(): Unit = Utils.tryOrExit {
|
||||
operateFun()
|
||||
}
|
||||
}
|
||||
}
|
||||
private def getRunner(operateFun: () => Unit): Runnable =
|
||||
() => Utils.tryOrExit { operateFun() }
|
||||
|
||||
/**
|
||||
* Fixed size thread pool to fetch and parse log files.
|
||||
|
@ -221,29 +216,25 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
|
|||
// Cannot probe anything while the FS is in safe mode, so spawn a new thread that will wait
|
||||
// for the FS to leave safe mode before enabling polling. This allows the main history server
|
||||
// UI to be shown (so that the user can see the HDFS status).
|
||||
val initThread = new Thread(new Runnable() {
|
||||
override def run(): Unit = {
|
||||
try {
|
||||
while (isFsInSafeMode()) {
|
||||
logInfo("HDFS is still in safe mode. Waiting...")
|
||||
val deadline = clock.getTimeMillis() +
|
||||
TimeUnit.SECONDS.toMillis(SAFEMODE_CHECK_INTERVAL_S)
|
||||
clock.waitTillTime(deadline)
|
||||
}
|
||||
startPolling()
|
||||
} catch {
|
||||
case _: InterruptedException =>
|
||||
val initThread = new Thread(() => {
|
||||
try {
|
||||
while (isFsInSafeMode()) {
|
||||
logInfo("HDFS is still in safe mode. Waiting...")
|
||||
val deadline = clock.getTimeMillis() +
|
||||
TimeUnit.SECONDS.toMillis(SAFEMODE_CHECK_INTERVAL_S)
|
||||
clock.waitTillTime(deadline)
|
||||
}
|
||||
startPolling()
|
||||
} catch {
|
||||
case _: InterruptedException =>
|
||||
}
|
||||
})
|
||||
initThread.setDaemon(true)
|
||||
initThread.setName(s"${getClass().getSimpleName()}-init")
|
||||
initThread.setUncaughtExceptionHandler(errorHandler.getOrElse(
|
||||
new Thread.UncaughtExceptionHandler() {
|
||||
override def uncaughtException(t: Thread, e: Throwable): Unit = {
|
||||
logError("Error initializing FsHistoryProvider.", e)
|
||||
System.exit(1)
|
||||
}
|
||||
(_: Thread, e: Throwable) => {
|
||||
logError("Error initializing FsHistoryProvider.", e)
|
||||
System.exit(1)
|
||||
}))
|
||||
initThread.start()
|
||||
initThread
|
||||
|
@ -517,9 +508,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
|
|||
|
||||
val tasks = updated.flatMap { entry =>
|
||||
try {
|
||||
val task: Future[Unit] = replayExecutor.submit(new Runnable {
|
||||
override def run(): Unit = mergeApplicationListing(entry, newLastScanTime, true)
|
||||
}, Unit)
|
||||
val task: Future[Unit] = replayExecutor.submit(
|
||||
() => mergeApplicationListing(entry, newLastScanTime, true))
|
||||
Some(task -> entry.getPath)
|
||||
} catch {
|
||||
// let the iteration over the updated entries break, since an exception on
|
||||
|
|
|
@ -150,11 +150,9 @@ private[deploy] class Master(
|
|||
logInfo(s"Spark Master is acting as a reverse proxy. Master, Workers and " +
|
||||
s"Applications UIs are available at $masterWebUiUrl")
|
||||
}
|
||||
checkForWorkerTimeOutTask = forwardMessageThread.scheduleAtFixedRate(new Runnable {
|
||||
override def run(): Unit = Utils.tryLogNonFatalError {
|
||||
self.send(CheckForWorkerTimeOut)
|
||||
}
|
||||
}, 0, workerTimeoutMs, TimeUnit.MILLISECONDS)
|
||||
checkForWorkerTimeOutTask = forwardMessageThread.scheduleAtFixedRate(
|
||||
() => Utils.tryLogNonFatalError { self.send(CheckForWorkerTimeOut) },
|
||||
0, workerTimeoutMs, TimeUnit.MILLISECONDS)
|
||||
|
||||
if (restServerEnabled) {
|
||||
val port = conf.get(MASTER_REST_SERVER_PORT)
|
||||
|
|
|
@ -325,11 +325,9 @@ private[deploy] class Worker(
|
|||
if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) {
|
||||
registrationRetryTimer.foreach(_.cancel(true))
|
||||
registrationRetryTimer = Some(
|
||||
forwardMessageScheduler.scheduleAtFixedRate(new Runnable {
|
||||
override def run(): Unit = Utils.tryLogNonFatalError {
|
||||
self.send(ReregisterWithMaster)
|
||||
}
|
||||
}, PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS,
|
||||
forwardMessageScheduler.scheduleAtFixedRate(
|
||||
() => Utils.tryLogNonFatalError { self.send(ReregisterWithMaster) },
|
||||
PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS,
|
||||
PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS,
|
||||
TimeUnit.SECONDS))
|
||||
}
|
||||
|
@ -341,7 +339,7 @@ private[deploy] class Worker(
|
|||
}
|
||||
|
||||
/**
|
||||
* Cancel last registeration retry, or do nothing if no retry
|
||||
* Cancel last registration retry, or do nothing if no retry
|
||||
*/
|
||||
private def cancelLastRegistrationRetry(): Unit = {
|
||||
if (registerMasterFutures != null) {
|
||||
|
@ -361,11 +359,7 @@ private[deploy] class Worker(
|
|||
registerMasterFutures = tryRegisterAllMasters()
|
||||
connectionAttemptCount = 0
|
||||
registrationRetryTimer = Some(forwardMessageScheduler.scheduleAtFixedRate(
|
||||
new Runnable {
|
||||
override def run(): Unit = Utils.tryLogNonFatalError {
|
||||
Option(self).foreach(_.send(ReregisterWithMaster))
|
||||
}
|
||||
},
|
||||
() => Utils.tryLogNonFatalError { Option(self).foreach(_.send(ReregisterWithMaster)) },
|
||||
INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS,
|
||||
INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS,
|
||||
TimeUnit.SECONDS))
|
||||
|
@ -407,19 +401,15 @@ private[deploy] class Worker(
|
|||
}
|
||||
registered = true
|
||||
changeMaster(masterRef, masterWebUiUrl, masterAddress)
|
||||
forwardMessageScheduler.scheduleAtFixedRate(new Runnable {
|
||||
override def run(): Unit = Utils.tryLogNonFatalError {
|
||||
self.send(SendHeartbeat)
|
||||
}
|
||||
}, 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS)
|
||||
forwardMessageScheduler.scheduleAtFixedRate(
|
||||
() => Utils.tryLogNonFatalError { self.send(SendHeartbeat) },
|
||||
0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS)
|
||||
if (CLEANUP_ENABLED) {
|
||||
logInfo(
|
||||
s"Worker cleanup enabled; old application directories will be deleted in: $workDir")
|
||||
forwardMessageScheduler.scheduleAtFixedRate(new Runnable {
|
||||
override def run(): Unit = Utils.tryLogNonFatalError {
|
||||
self.send(WorkDirCleanup)
|
||||
}
|
||||
}, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS)
|
||||
forwardMessageScheduler.scheduleAtFixedRate(
|
||||
() => Utils.tryLogNonFatalError { self.send(WorkDirCleanup) },
|
||||
CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS)
|
||||
}
|
||||
|
||||
val execs = executors.values.map { e =>
|
||||
|
@ -568,7 +558,7 @@ private[deploy] class Worker(
|
|||
}
|
||||
}
|
||||
|
||||
case executorStateChanged @ ExecutorStateChanged(appId, execId, state, message, exitStatus) =>
|
||||
case executorStateChanged: ExecutorStateChanged =>
|
||||
handleExecutorStateChanged(executorStateChanged)
|
||||
|
||||
case KillExecutor(masterUrl, appId, execId) =>
|
||||
|
@ -632,7 +622,7 @@ private[deploy] class Worker(
|
|||
|
||||
override def onDisconnected(remoteAddress: RpcAddress): Unit = {
|
||||
if (master.exists(_.address == remoteAddress) ||
|
||||
masterAddressToConnect.exists(_ == remoteAddress)) {
|
||||
masterAddressToConnect.contains(remoteAddress)) {
|
||||
logInfo(s"$remoteAddress Disassociated !")
|
||||
masterDisconnected()
|
||||
}
|
||||
|
@ -815,7 +805,7 @@ private[deploy] object Worker extends Logging {
|
|||
val systemName = SYSTEM_NAME + workerNumber.map(_.toString).getOrElse("")
|
||||
val securityMgr = new SecurityManager(conf)
|
||||
val rpcEnv = RpcEnv.create(systemName, host, port, conf, securityMgr)
|
||||
val masterAddresses = masterUrls.map(RpcAddress.fromSparkURL(_))
|
||||
val masterAddresses = masterUrls.map(RpcAddress.fromSparkURL)
|
||||
rpcEnv.setupEndpoint(ENDPOINT_NAME, new Worker(rpcEnv, webUiPort, cores, memory,
|
||||
masterAddresses, ENDPOINT_NAME, workDir, conf, securityMgr))
|
||||
rpcEnv
|
||||
|
|
|
@ -89,17 +89,14 @@ private[spark] class Executor(
|
|||
}
|
||||
|
||||
// Start worker thread pool
|
||||
// Use UninterruptibleThread to run tasks so that we can allow running codes without being
|
||||
// interrupted by `Thread.interrupt()`. Some issues, such as KAFKA-1894, HADOOP-10622,
|
||||
// will hang forever if some methods are interrupted.
|
||||
private val threadPool = {
|
||||
val threadFactory = new ThreadFactoryBuilder()
|
||||
.setDaemon(true)
|
||||
.setNameFormat("Executor task launch worker-%d")
|
||||
.setThreadFactory(new ThreadFactory {
|
||||
override def newThread(r: Runnable): Thread =
|
||||
// Use UninterruptibleThread to run tasks so that we can allow running codes without being
|
||||
// interrupted by `Thread.interrupt()`. Some issues, such as KAFKA-1894, HADOOP-10622,
|
||||
// will hang forever if some methods are interrupted.
|
||||
new UninterruptibleThread(r, "unused") // thread name will be set by ThreadFactoryBuilder
|
||||
})
|
||||
.setThreadFactory((r: Runnable) => new UninterruptibleThread(r, "unused"))
|
||||
.build()
|
||||
Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor]
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ private[spark] abstract class LauncherBackend {
|
|||
.map(_.toInt)
|
||||
val secret = conf.getOption(LauncherProtocol.CONF_LAUNCHER_SECRET)
|
||||
.orElse(sys.env.get(LauncherProtocol.ENV_LAUNCHER_SECRET))
|
||||
if (port != None && secret != None) {
|
||||
if (port.isDefined && secret.isDefined) {
|
||||
val s = new Socket(InetAddress.getLoopbackAddress(), port.get)
|
||||
connection = new BackendConnection(s)
|
||||
connection.send(new Hello(secret.get, SPARK_VERSION))
|
||||
|
@ -94,11 +94,8 @@ private[spark] abstract class LauncherBackend {
|
|||
protected def onDisconnected() : Unit = { }
|
||||
|
||||
private def fireStopRequest(): Unit = {
|
||||
val thread = LauncherBackend.threadFactory.newThread(new Runnable() {
|
||||
override def run(): Unit = Utils.tryLogNonFatalError {
|
||||
onStopRequest()
|
||||
}
|
||||
})
|
||||
val thread = LauncherBackend.threadFactory.newThread(
|
||||
() => Utils.tryLogNonFatalError { onStopRequest() })
|
||||
thread.start()
|
||||
}
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit
|
|||
|
||||
import scala.collection.mutable
|
||||
|
||||
import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry}
|
||||
import com.codahale.metrics.{Metric, MetricRegistry}
|
||||
import org.eclipse.jetty.servlet.ServletContextHandler
|
||||
|
||||
import org.apache.spark.{SecurityManager, SparkConf}
|
||||
|
@ -168,9 +168,7 @@ private[spark] class MetricsSystem private (
|
|||
def removeSource(source: Source) {
|
||||
sources -= source
|
||||
val regName = buildRegistryName(source)
|
||||
registry.removeMatching(new MetricFilter {
|
||||
def matches(name: String, metric: Metric): Boolean = name.startsWith(regName)
|
||||
})
|
||||
registry.removeMatching((name: String, _: Metric) => name.startsWith(regName))
|
||||
}
|
||||
|
||||
private def registerSources() {
|
||||
|
|
|
@ -21,7 +21,6 @@ import java.io.NotSerializableException
|
|||
import java.util.Properties
|
||||
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
import java.util.function.BiFunction
|
||||
|
||||
import scala.annotation.tailrec
|
||||
import scala.collection.Map
|
||||
|
@ -370,9 +369,10 @@ private[spark] class DAGScheduler(
|
|||
* 2. An RDD that depends on multiple barrier RDDs (eg. barrierRdd1.zip(barrierRdd2)).
|
||||
*/
|
||||
private def checkBarrierStageWithRDDChainPattern(rdd: RDD[_], numTasksInStage: Int): Unit = {
|
||||
val predicate: RDD[_] => Boolean = (r =>
|
||||
r.getNumPartitions == numTasksInStage && r.dependencies.filter(_.rdd.isBarrier()).size <= 1)
|
||||
if (rdd.isBarrier() && !traverseParentRDDsWithinStage(rdd, predicate)) {
|
||||
if (rdd.isBarrier() &&
|
||||
!traverseParentRDDsWithinStage(rdd, (r: RDD[_]) =>
|
||||
r.getNumPartitions == numTasksInStage &&
|
||||
r.dependencies.count(_.rdd.isBarrier()) <= 1)) {
|
||||
throw new BarrierJobUnsupportedRDDChainException
|
||||
}
|
||||
}
|
||||
|
@ -692,7 +692,7 @@ private[spark] class DAGScheduler(
|
|||
}
|
||||
|
||||
val jobId = nextJobId.getAndIncrement()
|
||||
if (partitions.size == 0) {
|
||||
if (partitions.isEmpty) {
|
||||
val time = clock.getTimeMillis()
|
||||
listenerBus.post(
|
||||
SparkListenerJobStart(jobId, time, Seq[StageInfo](), properties))
|
||||
|
@ -702,9 +702,9 @@ private[spark] class DAGScheduler(
|
|||
return new JobWaiter[U](this, jobId, 0, resultHandler)
|
||||
}
|
||||
|
||||
assert(partitions.size > 0)
|
||||
assert(partitions.nonEmpty)
|
||||
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
|
||||
val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
|
||||
val waiter = new JobWaiter[U](this, jobId, partitions.size, resultHandler)
|
||||
eventProcessLoop.post(JobSubmitted(
|
||||
jobId, rdd, func2, partitions.toArray, callSite, waiter,
|
||||
SerializationUtils.clone(properties)))
|
||||
|
@ -767,9 +767,8 @@ private[spark] class DAGScheduler(
|
|||
callSite: CallSite,
|
||||
timeout: Long,
|
||||
properties: Properties): PartialResult[R] = {
|
||||
val partitions = (0 until rdd.partitions.length).toArray
|
||||
val jobId = nextJobId.getAndIncrement()
|
||||
if (partitions.isEmpty) {
|
||||
if (rdd.partitions.isEmpty) {
|
||||
// Return immediately if the job is running 0 tasks
|
||||
val time = clock.getTimeMillis()
|
||||
listenerBus.post(SparkListenerJobStart(jobId, time, Seq[StageInfo](), properties))
|
||||
|
@ -779,7 +778,8 @@ private[spark] class DAGScheduler(
|
|||
val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
|
||||
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
|
||||
eventProcessLoop.post(JobSubmitted(
|
||||
jobId, rdd, func2, partitions, callSite, listener, SerializationUtils.clone(properties)))
|
||||
jobId, rdd, func2, rdd.partitions.indices.toArray, callSite, listener,
|
||||
SerializationUtils.clone(properties)))
|
||||
listener.awaitResult() // Will throw an exception if the job fails
|
||||
}
|
||||
|
||||
|
@ -812,7 +812,9 @@ private[spark] class DAGScheduler(
|
|||
// This makes it easier to avoid race conditions between the user code and the map output
|
||||
// tracker that might result if we told the user the stage had finished, but then they queries
|
||||
// the map output tracker and some node failures had caused the output statistics to be lost.
|
||||
val waiter = new JobWaiter(this, jobId, 1, (i: Int, r: MapOutputStatistics) => callback(r))
|
||||
val waiter = new JobWaiter[MapOutputStatistics](
|
||||
this, jobId, 1,
|
||||
(_: Int, r: MapOutputStatistics) => callback(r))
|
||||
eventProcessLoop.post(MapStageSubmitted(
|
||||
jobId, dependency, callSite, waiter, SerializationUtils.clone(properties)))
|
||||
waiter
|
||||
|
@ -870,7 +872,7 @@ private[spark] class DAGScheduler(
|
|||
* the last fetch failure.
|
||||
*/
|
||||
private[scheduler] def resubmitFailedStages() {
|
||||
if (failedStages.size > 0) {
|
||||
if (failedStages.nonEmpty) {
|
||||
// Failed stages may be removed by job cancellation, so failed might be empty even if
|
||||
// the ResubmitFailedStages event has been scheduled.
|
||||
logInfo("Resubmitting failed stages")
|
||||
|
@ -982,9 +984,7 @@ private[spark] class DAGScheduler(
|
|||
"than the total number of slots in the cluster currently.")
|
||||
// If jobId doesn't exist in the map, Scala coverts its value null to 0: Int automatically.
|
||||
val numCheckFailures = barrierJobIdToNumTasksCheckFailures.compute(jobId,
|
||||
new BiFunction[Int, Int, Int] {
|
||||
override def apply(key: Int, value: Int): Int = value + 1
|
||||
})
|
||||
(_: Int, value: Int) => value + 1)
|
||||
if (numCheckFailures <= maxFailureNumTasksCheck) {
|
||||
messageScheduler.schedule(
|
||||
new Runnable {
|
||||
|
@ -1227,7 +1227,7 @@ private[spark] class DAGScheduler(
|
|||
return
|
||||
}
|
||||
|
||||
if (tasks.size > 0) {
|
||||
if (tasks.nonEmpty) {
|
||||
logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " +
|
||||
s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})")
|
||||
taskScheduler.submitTasks(new TaskSet(
|
||||
|
@ -1942,7 +1942,7 @@ private[spark] class DAGScheduler(
|
|||
job: ActiveJob,
|
||||
failureReason: String,
|
||||
exception: Option[Throwable] = None): Unit = {
|
||||
val error = new SparkException(failureReason, exception.getOrElse(null))
|
||||
val error = new SparkException(failureReason, exception.orNull)
|
||||
var ableToCancelStages = true
|
||||
|
||||
// Cancel all independent, running stages.
|
||||
|
|
|
@ -80,7 +80,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
|
|||
logDebug("Fetching indirect task result for TID %s".format(tid))
|
||||
scheduler.handleTaskGettingResult(taskSetManager, tid)
|
||||
val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
|
||||
if (!serializedTaskResult.isDefined) {
|
||||
if (serializedTaskResult.isEmpty) {
|
||||
/* We won't be able to get the task result if the machine that ran the task failed
|
||||
* between when the task ended and when we tried to fetch the result, or if the
|
||||
* block manager had to flush the result. */
|
||||
|
@ -128,27 +128,25 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
|
|||
serializedData: ByteBuffer) {
|
||||
var reason : TaskFailedReason = UnknownReason
|
||||
try {
|
||||
getTaskResultExecutor.execute(new Runnable {
|
||||
override def run(): Unit = Utils.logUncaughtExceptions {
|
||||
val loader = Utils.getContextOrSparkClassLoader
|
||||
try {
|
||||
if (serializedData != null && serializedData.limit() > 0) {
|
||||
reason = serializer.get().deserialize[TaskFailedReason](
|
||||
serializedData, loader)
|
||||
}
|
||||
} catch {
|
||||
case cnd: ClassNotFoundException =>
|
||||
// Log an error but keep going here -- the task failed, so not catastrophic
|
||||
// if we can't deserialize the reason.
|
||||
logError(
|
||||
"Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
|
||||
case ex: Exception => // No-op
|
||||
} finally {
|
||||
// If there's an error while deserializing the TaskEndReason, this Runnable
|
||||
// will die. Still tell the scheduler about the task failure, to avoid a hang
|
||||
// where the scheduler thinks the task is still running.
|
||||
scheduler.handleFailedTask(taskSetManager, tid, taskState, reason)
|
||||
getTaskResultExecutor.execute(() => Utils.logUncaughtExceptions {
|
||||
val loader = Utils.getContextOrSparkClassLoader
|
||||
try {
|
||||
if (serializedData != null && serializedData.limit() > 0) {
|
||||
reason = serializer.get().deserialize[TaskFailedReason](
|
||||
serializedData, loader)
|
||||
}
|
||||
} catch {
|
||||
case _: ClassNotFoundException =>
|
||||
// Log an error but keep going here -- the task failed, so not catastrophic
|
||||
// if we can't deserialize the reason.
|
||||
logError(
|
||||
"Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
|
||||
case _: Exception => // No-op
|
||||
} finally {
|
||||
// If there's an error while deserializing the TaskEndReason, this Runnable
|
||||
// will die. Still tell the scheduler about the task failure, to avoid a hang
|
||||
// where the scheduler thinks the task is still running.
|
||||
scheduler.handleFailedTask(taskSetManager, tid, taskState, reason)
|
||||
}
|
||||
})
|
||||
} catch {
|
||||
|
|
|
@ -53,7 +53,7 @@ import org.apache.spark.util.{AccumulatorV2, SystemClock, ThreadUtils, Utils}
|
|||
* we are holding a lock on ourselves. This class is called from many threads, notably:
|
||||
* * The DAGScheduler Event Loop
|
||||
* * The RPCHandler threads, responding to status updates from Executors
|
||||
* * Periodic revival of all offers from the CoarseGrainedSchedulerBackend, to accomodate delay
|
||||
* * Periodic revival of all offers from the CoarseGrainedSchedulerBackend, to accommodate delay
|
||||
* scheduling
|
||||
* * task-result-getter threads
|
||||
*/
|
||||
|
@ -194,11 +194,9 @@ private[spark] class TaskSchedulerImpl(
|
|||
|
||||
if (!isLocal && conf.get(SPECULATION_ENABLED)) {
|
||||
logInfo("Starting speculative execution thread")
|
||||
speculationScheduler.scheduleWithFixedDelay(new Runnable {
|
||||
override def run(): Unit = Utils.tryOrStopSparkContext(sc) {
|
||||
checkSpeculatableTasks()
|
||||
}
|
||||
}, SPECULATION_INTERVAL_MS, SPECULATION_INTERVAL_MS, TimeUnit.MILLISECONDS)
|
||||
speculationScheduler.scheduleWithFixedDelay(
|
||||
() => Utils.tryOrStopSparkContext(sc) { checkSpeculatableTasks() },
|
||||
SPECULATION_INTERVAL_MS, SPECULATION_INTERVAL_MS, TimeUnit.MILLISECONDS)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -373,7 +371,7 @@ private[spark] class TaskSchedulerImpl(
|
|||
}
|
||||
}
|
||||
}
|
||||
return launchedTask
|
||||
launchedTask
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -527,7 +525,7 @@ private[spark] class TaskSchedulerImpl(
|
|||
|
||||
// TODO SPARK-24823 Cancel a job that contains barrier stage(s) if the barrier tasks don't get
|
||||
// launched within a configured time.
|
||||
if (tasks.size > 0) {
|
||||
if (tasks.nonEmpty) {
|
||||
hasLaunchedTask = true
|
||||
}
|
||||
return tasks
|
||||
|
|
|
@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit
|
|||
import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
|
||||
import javax.annotation.concurrent.GuardedBy
|
||||
|
||||
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
|
||||
import scala.collection.mutable.{HashMap, HashSet}
|
||||
import scala.concurrent.Future
|
||||
|
||||
import org.apache.hadoop.security.UserGroupInformation
|
||||
|
@ -133,10 +133,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
|
|||
// Periodically revive offers to allow delay scheduling to work
|
||||
val reviveIntervalMs = conf.get(SCHEDULER_REVIVE_INTERVAL).getOrElse(1000L)
|
||||
|
||||
reviveThread.scheduleAtFixedRate(new Runnable {
|
||||
override def run(): Unit = Utils.tryLogNonFatalError {
|
||||
Option(self).foreach(_.send(ReviveOffers))
|
||||
}
|
||||
reviveThread.scheduleAtFixedRate(() => Utils.tryLogNonFatalError {
|
||||
Option(self).foreach(_.send(ReviveOffers))
|
||||
}, 0, reviveIntervalMs, TimeUnit.MILLISECONDS)
|
||||
}
|
||||
|
||||
|
@ -268,7 +266,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
|
|||
}.toIndexedSeq
|
||||
scheduler.resourceOffers(workOffers)
|
||||
}
|
||||
if (!taskDescs.isEmpty) {
|
||||
if (taskDescs.nonEmpty) {
|
||||
launchTasks(taskDescs)
|
||||
}
|
||||
}
|
||||
|
@ -296,7 +294,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
|
|||
Seq.empty
|
||||
}
|
||||
}
|
||||
if (!taskDescs.isEmpty) {
|
||||
if (taskDescs.nonEmpty) {
|
||||
launchTasks(taskDescs)
|
||||
}
|
||||
}
|
||||
|
@ -669,7 +667,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
|
|||
}
|
||||
|
||||
val killExecutors: Boolean => Future[Boolean] =
|
||||
if (!executorsToKill.isEmpty) {
|
||||
if (executorsToKill.nonEmpty) {
|
||||
_ => doKillExecutors(executorsToKill)
|
||||
} else {
|
||||
_ => Future.successful(false)
|
||||
|
|
|
@ -19,7 +19,6 @@ package org.apache.spark.status
|
|||
|
||||
import java.util.Date
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.function.Function
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable.HashMap
|
||||
|
@ -840,11 +839,11 @@ private[spark] class AppStatusListener(
|
|||
// check if there is a new peak value for any of the executor level memory metrics,
|
||||
// while reading from the log. SparkListenerStageExecutorMetrics are only processed
|
||||
// when reading logs.
|
||||
liveExecutors.get(executorMetrics.execId)
|
||||
.orElse(deadExecutors.get(executorMetrics.execId)).map { exec =>
|
||||
if (exec.peakExecutorMetrics.compareAndUpdatePeakValues(executorMetrics.executorMetrics)) {
|
||||
update(exec, now)
|
||||
}
|
||||
liveExecutors.get(executorMetrics.execId).orElse(
|
||||
deadExecutors.get(executorMetrics.execId)).foreach { exec =>
|
||||
if (exec.peakExecutorMetrics.compareAndUpdatePeakValues(executorMetrics.executorMetrics)) {
|
||||
update(exec, now)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1048,9 +1047,7 @@ private[spark] class AppStatusListener(
|
|||
|
||||
private def getOrCreateStage(info: StageInfo): LiveStage = {
|
||||
val stage = liveStages.computeIfAbsent((info.stageId, info.attemptNumber),
|
||||
new Function[(Int, Int), LiveStage]() {
|
||||
override def apply(key: (Int, Int)): LiveStage = new LiveStage()
|
||||
})
|
||||
(_: (Int, Int)) => new LiveStage())
|
||||
stage.info = info
|
||||
stage
|
||||
}
|
||||
|
|
|
@ -143,12 +143,10 @@ private[spark] class ExternalSorter[K, V, C](
|
|||
// user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some
|
||||
// non-equal keys also have this, so we need to do a later pass to find truly equal keys).
|
||||
// Note that we ignore this if no aggregator and no ordering are given.
|
||||
private val keyComparator: Comparator[K] = ordering.getOrElse(new Comparator[K] {
|
||||
override def compare(a: K, b: K): Int = {
|
||||
val h1 = if (a == null) 0 else a.hashCode()
|
||||
val h2 = if (b == null) 0 else b.hashCode()
|
||||
if (h1 < h2) -1 else if (h1 == h2) 0 else 1
|
||||
}
|
||||
private val keyComparator: Comparator[K] = ordering.getOrElse((a: K, b: K) => {
|
||||
val h1 = if (a == null) 0 else a.hashCode()
|
||||
val h2 = if (b == null) 0 else b.hashCode()
|
||||
if (h1 < h2) -1 else if (h1 == h2) 0 else 1
|
||||
})
|
||||
|
||||
private def comparator: Option[Comparator[K]] = {
|
||||
|
@ -363,17 +361,15 @@ private[spark] class ExternalSorter[K, V, C](
|
|||
* Merge-sort a sequence of (K, C) iterators using a given a comparator for the keys.
|
||||
*/
|
||||
private def mergeSort(iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K])
|
||||
: Iterator[Product2[K, C]] =
|
||||
{
|
||||
: Iterator[Product2[K, C]] = {
|
||||
val bufferedIters = iterators.filter(_.hasNext).map(_.buffered)
|
||||
type Iter = BufferedIterator[Product2[K, C]]
|
||||
val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
|
||||
// Use the reverse order because PriorityQueue dequeues the max
|
||||
override def compare(x: Iter, y: Iter): Int = comparator.compare(y.head._1, x.head._1)
|
||||
})
|
||||
// Use the reverse order (compare(y,x)) because PriorityQueue dequeues the max
|
||||
val heap = new mutable.PriorityQueue[Iter]()(
|
||||
(x: Iter, y: Iter) => comparator.compare(y.head._1, x.head._1))
|
||||
heap.enqueue(bufferedIters: _*) // Will contain only the iterators with hasNext = true
|
||||
new Iterator[Product2[K, C]] {
|
||||
override def hasNext: Boolean = !heap.isEmpty
|
||||
override def hasNext: Boolean = heap.nonEmpty
|
||||
|
||||
override def next(): Product2[K, C] = {
|
||||
if (!hasNext) {
|
||||
|
@ -400,13 +396,12 @@ private[spark] class ExternalSorter[K, V, C](
|
|||
mergeCombiners: (C, C) => C,
|
||||
comparator: Comparator[K],
|
||||
totalOrder: Boolean)
|
||||
: Iterator[Product2[K, C]] =
|
||||
{
|
||||
: Iterator[Product2[K, C]] = {
|
||||
if (!totalOrder) {
|
||||
// We only have a partial ordering, e.g. comparing the keys by hash code, which means that
|
||||
// multiple distinct keys might be treated as equal by the ordering. To deal with this, we
|
||||
// need to read all keys considered equal by the ordering at once and compare them.
|
||||
new Iterator[Iterator[Product2[K, C]]] {
|
||||
val it = new Iterator[Iterator[Product2[K, C]]] {
|
||||
val sorted = mergeSort(iterators, comparator).buffered
|
||||
|
||||
// Buffers reused across elements to decrease memory allocation
|
||||
|
@ -446,7 +441,8 @@ private[spark] class ExternalSorter[K, V, C](
|
|||
// equal by the partial order; we flatten this below to get a flat iterator of (K, C).
|
||||
keys.iterator.zip(combiners.iterator)
|
||||
}
|
||||
}.flatMap(i => i)
|
||||
}
|
||||
it.flatten
|
||||
} else {
|
||||
// We have a total ordering, so the objects with the same key are sequential.
|
||||
new Iterator[Product2[K, C]] {
|
||||
|
@ -650,7 +646,7 @@ private[spark] class ExternalSorter[K, V, C](
|
|||
if (spills.isEmpty) {
|
||||
// Special case: if we have only in-memory data, we don't need to merge streams, and perhaps
|
||||
// we don't even need to sort by anything other than partition ID
|
||||
if (!ordering.isDefined) {
|
||||
if (ordering.isEmpty) {
|
||||
// The user hasn't requested sorted keys, so only sort by partition ID, not key
|
||||
groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None)))
|
||||
} else {
|
||||
|
|
|
@ -68,27 +68,20 @@ private[spark] object WritablePartitionedPairCollection {
|
|||
/**
|
||||
* A comparator for (Int, K) pairs that orders them by only their partition ID.
|
||||
*/
|
||||
def partitionComparator[K]: Comparator[(Int, K)] = new Comparator[(Int, K)] {
|
||||
override def compare(a: (Int, K), b: (Int, K)): Int = {
|
||||
a._1 - b._1
|
||||
}
|
||||
}
|
||||
def partitionComparator[K]: Comparator[(Int, K)] = (a: (Int, K), b: (Int, K)) => a._1 - b._1
|
||||
|
||||
/**
|
||||
* A comparator for (Int, K) pairs that orders them both by their partition ID and a key ordering.
|
||||
*/
|
||||
def partitionKeyComparator[K](keyComparator: Comparator[K]): Comparator[(Int, K)] = {
|
||||
new Comparator[(Int, K)] {
|
||||
override def compare(a: (Int, K), b: (Int, K)): Int = {
|
||||
val partitionDiff = a._1 - b._1
|
||||
if (partitionDiff != 0) {
|
||||
partitionDiff
|
||||
} else {
|
||||
keyComparator.compare(a._2, b._2)
|
||||
}
|
||||
def partitionKeyComparator[K](keyComparator: Comparator[K]): Comparator[(Int, K)] =
|
||||
(a: (Int, K), b: (Int, K)) => {
|
||||
val partitionDiff = a._1 - b._1
|
||||
if (partitionDiff != 0) {
|
||||
partitionDiff
|
||||
} else {
|
||||
keyComparator.compare(a._2, b._2)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -39,7 +39,7 @@ private[spark] class DriverLogger(conf: SparkConf) extends Logging {
|
|||
private val DEFAULT_LAYOUT = "%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n"
|
||||
private val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort)
|
||||
|
||||
private var localLogFile: String = FileUtils.getFile(
|
||||
private val localLogFile: String = FileUtils.getFile(
|
||||
Utils.getLocalDir(conf),
|
||||
DriverLogger.DRIVER_LOG_DIR,
|
||||
DriverLogger.DRIVER_LOG_FILE).getAbsolutePath()
|
||||
|
@ -163,9 +163,7 @@ private[spark] class DriverLogger(conf: SparkConf) extends Logging {
|
|||
|
||||
def closeWriter(): Unit = {
|
||||
try {
|
||||
threadpool.execute(new Runnable() {
|
||||
override def run(): Unit = DfsAsyncWriter.this.close()
|
||||
})
|
||||
threadpool.execute(() => DfsAsyncWriter.this.close())
|
||||
threadpool.shutdown()
|
||||
threadpool.awaitTermination(1, TimeUnit.MINUTES)
|
||||
} catch {
|
||||
|
|
|
@ -164,10 +164,9 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst
|
|||
|
||||
test("Thread safeness - SPARK-5425") {
|
||||
val executor = Executors.newSingleThreadScheduledExecutor()
|
||||
val sf = executor.scheduleAtFixedRate(new Runnable {
|
||||
override def run(): Unit =
|
||||
System.setProperty("spark.5425." + Random.nextInt(), Random.nextInt().toString)
|
||||
}, 0, 1, TimeUnit.MILLISECONDS)
|
||||
executor.scheduleAtFixedRate(
|
||||
() => System.setProperty("spark.5425." + Random.nextInt(), Random.nextInt().toString),
|
||||
0, 1, TimeUnit.MILLISECONDS)
|
||||
|
||||
try {
|
||||
val t0 = System.nanoTime()
|
||||
|
|
|
@ -27,7 +27,6 @@ import org.eclipse.jetty.servlet.ServletContextHandler
|
|||
import org.mockito.ArgumentMatchers.any
|
||||
import org.mockito.Mockito._
|
||||
import org.mockito.invocation.InvocationOnMock
|
||||
import org.mockito.stubbing.Answer
|
||||
import org.scalatest.Matchers
|
||||
import org.scalatest.mockito.MockitoSugar
|
||||
|
||||
|
@ -374,11 +373,9 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar
|
|||
when(request.getRequestURI()).thenReturn("http://localhost:18080/history/local-123/jobs/job/")
|
||||
when(request.getQueryString()).thenReturn("id=2")
|
||||
val resp = mock[HttpServletResponse]
|
||||
when(resp.encodeRedirectURL(any())).thenAnswer(new Answer[String]() {
|
||||
override def answer(invocationOnMock: InvocationOnMock): String = {
|
||||
invocationOnMock.getArguments()(0).asInstanceOf[String]
|
||||
}
|
||||
})
|
||||
when(resp.encodeRedirectURL(any())).thenAnswer { (invocationOnMock: InvocationOnMock) =>
|
||||
invocationOnMock.getArguments()(0).asInstanceOf[String]
|
||||
}
|
||||
filter.doFilter(request, resp, null)
|
||||
verify(resp).sendRedirect("http://localhost:18080/history/local-123/jobs/job/?id=2")
|
||||
}
|
||||
|
|
|
@ -33,7 +33,6 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, FSDataInputStream, Path}
|
|||
import org.apache.hadoop.hdfs.{DFSInputStream, DistributedFileSystem}
|
||||
import org.apache.hadoop.security.AccessControlException
|
||||
import org.json4s.jackson.JsonMethods._
|
||||
import org.mockito.ArgumentMatcher
|
||||
import org.mockito.ArgumentMatchers.{any, argThat}
|
||||
import org.mockito.Mockito.{doThrow, mock, spy, verify, when}
|
||||
import org.scalatest.BeforeAndAfter
|
||||
|
@ -49,7 +48,7 @@ import org.apache.spark.io._
|
|||
import org.apache.spark.scheduler._
|
||||
import org.apache.spark.scheduler.cluster.ExecutorInfo
|
||||
import org.apache.spark.security.GroupMappingServiceProvider
|
||||
import org.apache.spark.status.{AppStatusStore, ExecutorSummaryWrapper}
|
||||
import org.apache.spark.status.AppStatusStore
|
||||
import org.apache.spark.status.api.v1.{ApplicationAttemptInfo, ApplicationInfo}
|
||||
import org.apache.spark.util.{Clock, JsonProtocol, ManualClock, Utils}
|
||||
import org.apache.spark.util.logging.DriverLogger
|
||||
|
@ -1122,11 +1121,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
|
|||
SparkListenerApplicationEnd(5L))
|
||||
val mockedFs = spy(provider.fs)
|
||||
doThrow(new AccessControlException("Cannot read accessDenied file")).when(mockedFs).open(
|
||||
argThat(new ArgumentMatcher[Path]() {
|
||||
override def matches(path: Path): Boolean = {
|
||||
path.asInstanceOf[Path].getName.toLowerCase(Locale.ROOT) == "accessdenied"
|
||||
}
|
||||
}))
|
||||
argThat((path: Path) => path.getName.toLowerCase(Locale.ROOT) == "accessdenied"))
|
||||
val mockedProvider = spy(provider)
|
||||
when(mockedProvider.fs).thenReturn(mockedFs)
|
||||
updateAndCheck(mockedProvider) { list =>
|
||||
|
|
|
@ -24,7 +24,6 @@ import scala.concurrent.duration._
|
|||
import org.mockito.ArgumentMatchers.{any, anyInt}
|
||||
import org.mockito.Mockito._
|
||||
import org.mockito.invocation.InvocationOnMock
|
||||
import org.mockito.stubbing.Answer
|
||||
import org.scalatest.concurrent.Eventually.{eventually, interval, timeout}
|
||||
|
||||
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
|
||||
|
@ -57,11 +56,9 @@ class DriverRunnerTest extends SparkFunSuite {
|
|||
superviseRetry: Boolean) = {
|
||||
val runner = createDriverRunner()
|
||||
runner.setSleeper(mock(classOf[Sleeper]))
|
||||
doAnswer(new Answer[Int] {
|
||||
def answer(invocation: InvocationOnMock): Int = {
|
||||
runner.runCommandWithRetry(processBuilder, p => (), supervise = superviseRetry)
|
||||
}
|
||||
}).when(runner).prepareAndRunDriver()
|
||||
doAnswer { (_: InvocationOnMock) =>
|
||||
runner.runCommandWithRetry(processBuilder, p => (), supervise = superviseRetry)
|
||||
}.when(runner).prepareAndRunDriver()
|
||||
runner
|
||||
}
|
||||
|
||||
|
@ -120,11 +117,9 @@ class DriverRunnerTest extends SparkFunSuite {
|
|||
runner.setSleeper(sleeper)
|
||||
|
||||
val (processBuilder, process) = createProcessBuilderAndProcess()
|
||||
when(process.waitFor()).thenAnswer(new Answer[Int] {
|
||||
def answer(invocation: InvocationOnMock): Int = {
|
||||
runner.kill()
|
||||
-1
|
||||
}
|
||||
when(process.waitFor()).thenAnswer((_: InvocationOnMock) => {
|
||||
runner.kill()
|
||||
-1
|
||||
})
|
||||
|
||||
runner.runCommandWithRetry(processBuilder, p => (), supervise = true)
|
||||
|
@ -169,11 +164,9 @@ class DriverRunnerTest extends SparkFunSuite {
|
|||
val (processBuilder, process) = createProcessBuilderAndProcess()
|
||||
val runner = createTestableDriverRunner(processBuilder, superviseRetry = true)
|
||||
|
||||
when(process.waitFor()).thenAnswer(new Answer[Int] {
|
||||
def answer(invocation: InvocationOnMock): Int = {
|
||||
runner.kill()
|
||||
-1
|
||||
}
|
||||
when(process.waitFor()).thenAnswer((_: InvocationOnMock) => {
|
||||
runner.kill()
|
||||
-1
|
||||
})
|
||||
|
||||
runner.start()
|
||||
|
|
|
@ -28,7 +28,6 @@ import org.mockito.Answers.RETURNS_SMART_NULLS
|
|||
import org.mockito.ArgumentMatchers.any
|
||||
import org.mockito.Mockito._
|
||||
import org.mockito.invocation.InvocationOnMock
|
||||
import org.mockito.stubbing.Answer
|
||||
import org.scalatest.{BeforeAndAfter, Matchers}
|
||||
import org.scalatest.concurrent.Eventually.{eventually, interval, timeout}
|
||||
|
||||
|
@ -233,11 +232,8 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter {
|
|||
val conf = new SparkConf().set(config.STORAGE_CLEANUP_FILES_AFTER_EXECUTOR_EXIT, value)
|
||||
|
||||
val cleanupCalled = new AtomicBoolean(false)
|
||||
when(shuffleService.executorRemoved(any[String], any[String])).thenAnswer(new Answer[Unit] {
|
||||
override def answer(invocations: InvocationOnMock): Unit = {
|
||||
cleanupCalled.set(true)
|
||||
}
|
||||
})
|
||||
when(shuffleService.executorRemoved(any[String], any[String])).thenAnswer(
|
||||
(_: InvocationOnMock) => cleanupCalled.set(true))
|
||||
val externalShuffleServiceSupplier = new Supplier[ExternalShuffleService] {
|
||||
override def get: ExternalShuffleService = shuffleService
|
||||
}
|
||||
|
@ -269,11 +265,8 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter {
|
|||
val appId = "app1"
|
||||
val execId = "exec1"
|
||||
val cleanupCalled = new AtomicBoolean(false)
|
||||
when(shuffleService.applicationRemoved(any[String])).thenAnswer(new Answer[Unit] {
|
||||
override def answer(invocations: InvocationOnMock): Unit = {
|
||||
cleanupCalled.set(true)
|
||||
}
|
||||
})
|
||||
when(shuffleService.applicationRemoved(any[String])).thenAnswer(
|
||||
(_: InvocationOnMock) => cleanupCalled.set(true))
|
||||
val externalShuffleServiceSupplier = new Supplier[ExternalShuffleService] {
|
||||
override def get: ExternalShuffleService = shuffleService
|
||||
}
|
||||
|
@ -289,8 +282,8 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter {
|
|||
}
|
||||
executorDir.setLastModified(System.currentTimeMillis - (1000 * 120))
|
||||
worker.receive(WorkDirCleanup)
|
||||
eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
|
||||
assert(!executorDir.exists() == true)
|
||||
eventually(timeout(1000.milliseconds), interval(10.milliseconds)) {
|
||||
assert(!executorDir.exists())
|
||||
assert(cleanupCalled.get() == dbCleanupEnabled)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -279,13 +279,10 @@ class ExecutorSuite extends SparkFunSuite
|
|||
val heartbeats = ArrayBuffer[Heartbeat]()
|
||||
val mockReceiver = mock[RpcEndpointRef]
|
||||
when(mockReceiver.askSync(any[Heartbeat], any[RpcTimeout])(any))
|
||||
.thenAnswer(new Answer[HeartbeatResponse] {
|
||||
override def answer(invocation: InvocationOnMock): HeartbeatResponse = {
|
||||
val args = invocation.getArguments()
|
||||
val mock = invocation.getMock
|
||||
heartbeats += args(0).asInstanceOf[Heartbeat]
|
||||
HeartbeatResponse(false)
|
||||
}
|
||||
.thenAnswer((invocation: InvocationOnMock) => {
|
||||
val args = invocation.getArguments()
|
||||
heartbeats += args(0).asInstanceOf[Heartbeat]
|
||||
HeartbeatResponse(false)
|
||||
})
|
||||
val receiverRef = executorClass.getDeclaredField("heartbeatReceiverRef")
|
||||
receiverRef.setAccessible(true)
|
||||
|
|
|
@ -84,11 +84,8 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft
|
|||
*/
|
||||
protected def makeBadMemoryStore(mm: MemoryManager): MemoryStore = {
|
||||
val ms = mock(classOf[MemoryStore], RETURNS_SMART_NULLS)
|
||||
when(ms.evictBlocksToFreeSpace(any(), anyLong(), any())).thenAnswer(new Answer[Long] {
|
||||
override def answer(invocation: InvocationOnMock): Long = {
|
||||
throw new RuntimeException("bad memory store!")
|
||||
}
|
||||
})
|
||||
when(ms.evictBlocksToFreeSpace(any(), anyLong(), any())).thenAnswer(
|
||||
(_: InvocationOnMock) => throw new RuntimeException("bad memory store!"))
|
||||
mm.setMemoryStore(ms)
|
||||
ms
|
||||
}
|
||||
|
@ -106,27 +103,24 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft
|
|||
* records the number of bytes this is called with. This variable is expected to be cleared
|
||||
* by the test code later through [[assertEvictBlocksToFreeSpaceCalled]].
|
||||
*/
|
||||
private def evictBlocksToFreeSpaceAnswer(mm: MemoryManager): Answer[Long] = {
|
||||
new Answer[Long] {
|
||||
override def answer(invocation: InvocationOnMock): Long = {
|
||||
val args = invocation.getArguments
|
||||
val numBytesToFree = args(1).asInstanceOf[Long]
|
||||
assert(numBytesToFree > 0)
|
||||
require(evictBlocksToFreeSpaceCalled.get() === DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED,
|
||||
"bad test: evictBlocksToFreeSpace() variable was not reset")
|
||||
evictBlocksToFreeSpaceCalled.set(numBytesToFree)
|
||||
if (numBytesToFree <= mm.storageMemoryUsed) {
|
||||
// We can evict enough blocks to fulfill the request for space
|
||||
mm.releaseStorageMemory(numBytesToFree, mm.tungstenMemoryMode)
|
||||
evictedBlocks += Tuple2(null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytesToFree, 0L))
|
||||
numBytesToFree
|
||||
} else {
|
||||
// No blocks were evicted because eviction would not free enough space.
|
||||
0L
|
||||
}
|
||||
private def evictBlocksToFreeSpaceAnswer(mm: MemoryManager): Answer[Long] =
|
||||
(invocation: InvocationOnMock) => {
|
||||
val args = invocation.getArguments
|
||||
val numBytesToFree = args(1).asInstanceOf[Long]
|
||||
assert(numBytesToFree > 0)
|
||||
require(evictBlocksToFreeSpaceCalled.get() === DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED,
|
||||
"bad test: evictBlocksToFreeSpace() variable was not reset")
|
||||
evictBlocksToFreeSpaceCalled.set(numBytesToFree)
|
||||
if (numBytesToFree <= mm.storageMemoryUsed) {
|
||||
// We can evict enough blocks to fulfill the request for space
|
||||
mm.releaseStorageMemory(numBytesToFree, mm.tungstenMemoryMode)
|
||||
evictedBlocks += Tuple2(null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytesToFree, 0L))
|
||||
numBytesToFree
|
||||
} else {
|
||||
// No blocks were evicted because eviction would not free enough space.
|
||||
0L
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Assert that [[MemoryStore.evictBlocksToFreeSpace]] is called with the given parameters.
|
||||
|
|
|
@ -20,7 +20,6 @@ package org.apache.spark.scheduler
|
|||
import org.mockito.ArgumentMatchers.any
|
||||
import org.mockito.Mockito.{never, verify, when}
|
||||
import org.mockito.invocation.InvocationOnMock
|
||||
import org.mockito.stubbing.Answer
|
||||
import org.scalatest.BeforeAndAfterEach
|
||||
import org.scalatest.mockito.MockitoSugar
|
||||
|
||||
|
@ -480,17 +479,16 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
|
|||
test("blacklisting kills executors, configured by BLACKLIST_KILL_ENABLED") {
|
||||
val allocationClientMock = mock[ExecutorAllocationClient]
|
||||
when(allocationClientMock.killExecutors(any(), any(), any(), any())).thenReturn(Seq("called"))
|
||||
when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer(new Answer[Boolean] {
|
||||
when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer { (_: InvocationOnMock) =>
|
||||
// To avoid a race between blacklisting and killing, it is important that the nodeBlacklist
|
||||
// is updated before we ask the executor allocation client to kill all the executors
|
||||
// on a particular host.
|
||||
override def answer(invocation: InvocationOnMock): Boolean = {
|
||||
if (blacklist.nodeBlacklist.contains("hostA") == false) {
|
||||
throw new IllegalStateException("hostA should be on the blacklist")
|
||||
}
|
||||
if (blacklist.nodeBlacklist.contains("hostA")) {
|
||||
true
|
||||
} else {
|
||||
throw new IllegalStateException("hostA should be on the blacklist")
|
||||
}
|
||||
})
|
||||
}
|
||||
blacklist = new BlacklistTracker(listenerBusMock, conf, Some(allocationClientMock), clock)
|
||||
|
||||
// Disable auto-kill. Blacklist an executor and make sure killExecutors is not called.
|
||||
|
@ -552,17 +550,16 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
|
|||
test("fetch failure blacklisting kills executors, configured by BLACKLIST_KILL_ENABLED") {
|
||||
val allocationClientMock = mock[ExecutorAllocationClient]
|
||||
when(allocationClientMock.killExecutors(any(), any(), any(), any())).thenReturn(Seq("called"))
|
||||
when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer(new Answer[Boolean] {
|
||||
when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer { (_: InvocationOnMock) =>
|
||||
// To avoid a race between blacklisting and killing, it is important that the nodeBlacklist
|
||||
// is updated before we ask the executor allocation client to kill all the executors
|
||||
// on a particular host.
|
||||
override def answer(invocation: InvocationOnMock): Boolean = {
|
||||
if (blacklist.nodeBlacklist.contains("hostA") == false) {
|
||||
throw new IllegalStateException("hostA should be on the blacklist")
|
||||
}
|
||||
if (blacklist.nodeBlacklist.contains("hostA")) {
|
||||
true
|
||||
} else {
|
||||
throw new IllegalStateException("hostA should be on the blacklist")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
conf.set(config.BLACKLIST_FETCH_FAILURE_ENABLED, true)
|
||||
blacklist = new BlacklistTracker(listenerBusMock, conf, Some(allocationClientMock), clock)
|
||||
|
|
|
@ -29,7 +29,6 @@ import org.apache.hadoop.mapreduce.TaskType
|
|||
import org.mockito.ArgumentMatchers.{any, eq => meq}
|
||||
import org.mockito.Mockito.{doAnswer, spy, times, verify}
|
||||
import org.mockito.invocation.InvocationOnMock
|
||||
import org.mockito.stubbing.Answer
|
||||
import org.scalatest.BeforeAndAfter
|
||||
|
||||
import org.apache.spark._
|
||||
|
@ -98,34 +97,29 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter {
|
|||
// Use Mockito.spy() to maintain the default infrastructure everywhere else
|
||||
val mockTaskScheduler = spy(sc.taskScheduler.asInstanceOf[TaskSchedulerImpl])
|
||||
|
||||
doAnswer(new Answer[Unit]() {
|
||||
override def answer(invoke: InvocationOnMock): Unit = {
|
||||
// Submit the tasks, then force the task scheduler to dequeue the
|
||||
// speculated task
|
||||
invoke.callRealMethod()
|
||||
mockTaskScheduler.backend.reviveOffers()
|
||||
}
|
||||
}).when(mockTaskScheduler).submitTasks(any())
|
||||
doAnswer { (invoke: InvocationOnMock) =>
|
||||
// Submit the tasks, then force the task scheduler to dequeue the
|
||||
// speculated task
|
||||
invoke.callRealMethod()
|
||||
mockTaskScheduler.backend.reviveOffers()
|
||||
}.when(mockTaskScheduler).submitTasks(any())
|
||||
|
||||
doAnswer(new Answer[TaskSetManager]() {
|
||||
override def answer(invoke: InvocationOnMock): TaskSetManager = {
|
||||
val taskSet = invoke.getArguments()(0).asInstanceOf[TaskSet]
|
||||
new TaskSetManager(mockTaskScheduler, taskSet, 4) {
|
||||
var hasDequeuedSpeculatedTask = false
|
||||
override def dequeueSpeculativeTask(
|
||||
execId: String,
|
||||
host: String,
|
||||
locality: TaskLocality.Value): Option[(Int, TaskLocality.Value)] = {
|
||||
if (!hasDequeuedSpeculatedTask) {
|
||||
hasDequeuedSpeculatedTask = true
|
||||
Some((0, TaskLocality.PROCESS_LOCAL))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
doAnswer { (invoke: InvocationOnMock) =>
|
||||
val taskSet = invoke.getArguments()(0).asInstanceOf[TaskSet]
|
||||
new TaskSetManager(mockTaskScheduler, taskSet, 4) {
|
||||
private var hasDequeuedSpeculatedTask = false
|
||||
override def dequeueSpeculativeTask(execId: String,
|
||||
host: String,
|
||||
locality: TaskLocality.Value): Option[(Int, TaskLocality.Value)] = {
|
||||
if (hasDequeuedSpeculatedTask) {
|
||||
None
|
||||
} else {
|
||||
hasDequeuedSpeculatedTask = true
|
||||
Some((0, TaskLocality.PROCESS_LOCAL))
|
||||
}
|
||||
}
|
||||
}
|
||||
}).when(mockTaskScheduler).createTaskSetManager(any(), any())
|
||||
}.when(mockTaskScheduler).createTaskSetManager(any(), any())
|
||||
|
||||
sc.taskScheduler = mockTaskScheduler
|
||||
val dagSchedulerWithMockTaskScheduler = new DAGScheduler(sc, mockTaskScheduler)
|
||||
|
|
|
@ -356,13 +356,9 @@ private[spark] abstract class MockBackend(
|
|||
assignedTasksWaitingToRun.nonEmpty
|
||||
}
|
||||
|
||||
override def start(): Unit = {
|
||||
reviveThread.scheduleAtFixedRate(new Runnable {
|
||||
override def run(): Unit = Utils.tryLogNonFatalError {
|
||||
reviveOffers()
|
||||
}
|
||||
}, 0, reviveIntervalMs, TimeUnit.MILLISECONDS)
|
||||
}
|
||||
override def start(): Unit =
|
||||
reviveThread.scheduleAtFixedRate(() => Utils.tryLogNonFatalError { reviveOffers() },
|
||||
0, reviveIntervalMs, TimeUnit.MILLISECONDS)
|
||||
|
||||
override def stop(): Unit = {
|
||||
reviveThread.shutdown()
|
||||
|
|
|
@ -25,14 +25,13 @@ import scala.collection.mutable.ArrayBuffer
|
|||
import org.mockito.ArgumentMatchers.{any, anyInt, anyString}
|
||||
import org.mockito.Mockito.{mock, never, spy, times, verify, when}
|
||||
import org.mockito.invocation.InvocationOnMock
|
||||
import org.mockito.stubbing.Answer
|
||||
|
||||
import org.apache.spark._
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.internal.config
|
||||
import org.apache.spark.serializer.SerializerInstance
|
||||
import org.apache.spark.storage.BlockManagerId
|
||||
import org.apache.spark.util.{AccumulatorV2, ManualClock, Utils}
|
||||
import org.apache.spark.util.{AccumulatorV2, ManualClock}
|
||||
|
||||
class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
|
||||
extends DAGScheduler(sc) {
|
||||
|
@ -1190,11 +1189,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
|
|||
val taskSet = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0)
|
||||
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = new ManualClock(1))
|
||||
when(mockDAGScheduler.taskEnded(any(), any(), any(), any(), any())).thenAnswer(
|
||||
new Answer[Unit] {
|
||||
override def answer(invocationOnMock: InvocationOnMock): Unit = {
|
||||
assert(manager.isZombie)
|
||||
}
|
||||
})
|
||||
(invocationOnMock: InvocationOnMock) => assert(manager.isZombie))
|
||||
val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF)
|
||||
assert(taskOption.isDefined)
|
||||
// this would fail, inside our mock dag scheduler, if it calls dagScheduler.taskEnded() too soon
|
||||
|
@ -1317,12 +1312,10 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
|
|||
|
||||
// Assert the task has been black listed on the executor it was last executed on.
|
||||
when(taskSetManagerSpy.addPendingTask(anyInt())).thenAnswer(
|
||||
new Answer[Unit] {
|
||||
override def answer(invocationOnMock: InvocationOnMock): Unit = {
|
||||
val task: Int = invocationOnMock.getArgument(0)
|
||||
assert(taskSetManager.taskSetBlacklistHelperOpt.get.
|
||||
isExecutorBlacklistedForTask(exec, task))
|
||||
}
|
||||
(invocationOnMock: InvocationOnMock) => {
|
||||
val task: Int = invocationOnMock.getArgument(0)
|
||||
assert(taskSetManager.taskSetBlacklistHelperOpt.get.
|
||||
isExecutorBlacklistedForTask(exec, task))
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
@ -28,7 +28,6 @@ import org.mockito.Answers.RETURNS_SMART_NULLS
|
|||
import org.mockito.ArgumentMatchers.{any, anyInt}
|
||||
import org.mockito.Mockito._
|
||||
import org.mockito.invocation.InvocationOnMock
|
||||
import org.mockito.stubbing.Answer
|
||||
import org.scalatest.BeforeAndAfterEach
|
||||
|
||||
import org.apache.spark._
|
||||
|
@ -69,16 +68,14 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
|
|||
when(dependency.serializer).thenReturn(new JavaSerializer(conf))
|
||||
when(taskContext.taskMetrics()).thenReturn(taskMetrics)
|
||||
when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
|
||||
doAnswer(new Answer[Void] {
|
||||
def answer(invocationOnMock: InvocationOnMock): Void = {
|
||||
val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File]
|
||||
if (tmp != null) {
|
||||
outputFile.delete
|
||||
tmp.renameTo(outputFile)
|
||||
}
|
||||
null
|
||||
doAnswer { (invocationOnMock: InvocationOnMock) =>
|
||||
val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File]
|
||||
if (tmp != null) {
|
||||
outputFile.delete
|
||||
tmp.renameTo(outputFile)
|
||||
}
|
||||
}).when(blockResolver)
|
||||
null
|
||||
}.when(blockResolver)
|
||||
.writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))
|
||||
when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
|
||||
when(blockManager.getDiskWriter(
|
||||
|
@ -87,37 +84,29 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
|
|||
any[SerializerInstance],
|
||||
anyInt(),
|
||||
any[ShuffleWriteMetrics]
|
||||
)).thenAnswer(new Answer[DiskBlockObjectWriter] {
|
||||
override def answer(invocation: InvocationOnMock): DiskBlockObjectWriter = {
|
||||
val args = invocation.getArguments
|
||||
val manager = new SerializerManager(new JavaSerializer(conf), conf)
|
||||
new DiskBlockObjectWriter(
|
||||
args(1).asInstanceOf[File],
|
||||
manager,
|
||||
args(2).asInstanceOf[SerializerInstance],
|
||||
args(3).asInstanceOf[Int],
|
||||
syncWrites = false,
|
||||
args(4).asInstanceOf[ShuffleWriteMetrics],
|
||||
blockId = args(0).asInstanceOf[BlockId]
|
||||
)
|
||||
}
|
||||
)).thenAnswer((invocation: InvocationOnMock) => {
|
||||
val args = invocation.getArguments
|
||||
val manager = new SerializerManager(new JavaSerializer(conf), conf)
|
||||
new DiskBlockObjectWriter(
|
||||
args(1).asInstanceOf[File],
|
||||
manager,
|
||||
args(2).asInstanceOf[SerializerInstance],
|
||||
args(3).asInstanceOf[Int],
|
||||
syncWrites = false,
|
||||
args(4).asInstanceOf[ShuffleWriteMetrics],
|
||||
blockId = args(0).asInstanceOf[BlockId]
|
||||
)
|
||||
})
|
||||
when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
|
||||
new Answer[(TempShuffleBlockId, File)] {
|
||||
override def answer(invocation: InvocationOnMock): (TempShuffleBlockId, File) = {
|
||||
val blockId = new TempShuffleBlockId(UUID.randomUUID)
|
||||
val file = new File(tempDir, blockId.name)
|
||||
blockIdToFileMap.put(blockId, file)
|
||||
temporaryFilesCreated += file
|
||||
(blockId, file)
|
||||
}
|
||||
})
|
||||
when(diskBlockManager.getFile(any[BlockId])).thenAnswer(
|
||||
new Answer[File] {
|
||||
override def answer(invocation: InvocationOnMock): File = {
|
||||
blockIdToFileMap(invocation.getArguments.head.asInstanceOf[BlockId])
|
||||
}
|
||||
when(diskBlockManager.createTempShuffleBlock()).thenAnswer((_: InvocationOnMock) => {
|
||||
val blockId = new TempShuffleBlockId(UUID.randomUUID)
|
||||
val file = new File(tempDir, blockId.name)
|
||||
blockIdToFileMap.put(blockId, file)
|
||||
temporaryFilesCreated += file
|
||||
(blockId, file)
|
||||
})
|
||||
when(diskBlockManager.getFile(any[BlockId])).thenAnswer { (invocation: InvocationOnMock) =>
|
||||
blockIdToFileMap(invocation.getArguments.head.asInstanceOf[BlockId])
|
||||
}
|
||||
}
|
||||
|
||||
override def afterEach(): Unit = {
|
||||
|
|
|
@ -24,7 +24,6 @@ import org.mockito.Answers.RETURNS_SMART_NULLS
|
|||
import org.mockito.ArgumentMatchers.any
|
||||
import org.mockito.Mockito._
|
||||
import org.mockito.invocation.InvocationOnMock
|
||||
import org.mockito.stubbing.Answer
|
||||
import org.scalatest.BeforeAndAfterEach
|
||||
|
||||
import org.apache.spark.{SparkConf, SparkFunSuite}
|
||||
|
@ -48,11 +47,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
|
|||
|
||||
when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
|
||||
when(diskBlockManager.getFile(any[BlockId])).thenAnswer(
|
||||
new Answer[File] {
|
||||
override def answer(invocation: InvocationOnMock): File = {
|
||||
new File(tempDir, invocation.getArguments.head.toString)
|
||||
}
|
||||
})
|
||||
(invocation: InvocationOnMock) => new File(tempDir, invocation.getArguments.head.toString))
|
||||
}
|
||||
|
||||
override def afterEach(): Unit = {
|
||||
|
|
|
@ -24,7 +24,6 @@ import scala.reflect.ClassTag
|
|||
import org.mockito.Mockito
|
||||
import org.mockito.Mockito.atLeastOnce
|
||||
import org.mockito.invocation.InvocationOnMock
|
||||
import org.mockito.stubbing.Answer
|
||||
import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester}
|
||||
|
||||
import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl}
|
||||
|
@ -59,11 +58,9 @@ class PartiallySerializedBlockSuite
|
|||
|
||||
val bbos: ChunkedByteBufferOutputStream = {
|
||||
val spy = Mockito.spy(new ChunkedByteBufferOutputStream(128, ByteBuffer.allocate))
|
||||
Mockito.doAnswer(new Answer[ChunkedByteBuffer] {
|
||||
override def answer(invocationOnMock: InvocationOnMock): ChunkedByteBuffer = {
|
||||
Mockito.spy(invocationOnMock.callRealMethod().asInstanceOf[ChunkedByteBuffer])
|
||||
}
|
||||
}).when(spy).toChunkedByteBuffer
|
||||
Mockito.doAnswer { (invocationOnMock: InvocationOnMock) =>
|
||||
Mockito.spy(invocationOnMock.callRealMethod().asInstanceOf[ChunkedByteBuffer])
|
||||
}.when(spy).toChunkedByteBuffer
|
||||
spy
|
||||
}
|
||||
|
||||
|
|
|
@ -28,7 +28,6 @@ import scala.concurrent.Future
|
|||
import org.mockito.ArgumentMatchers.{any, eq => meq}
|
||||
import org.mockito.Mockito.{mock, times, verify, when}
|
||||
import org.mockito.invocation.InvocationOnMock
|
||||
import org.mockito.stubbing.Answer
|
||||
import org.scalatest.PrivateMethodTester
|
||||
|
||||
import org.apache.spark.{SparkFunSuite, TaskContext}
|
||||
|
@ -50,9 +49,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
|
|||
/** Creates a mock [[BlockTransferService]] that returns data from the given map. */
|
||||
private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = {
|
||||
val transfer = mock(classOf[BlockTransferService])
|
||||
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
|
||||
.thenAnswer(new Answer[Unit] {
|
||||
override def answer(invocation: InvocationOnMock): Unit = {
|
||||
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())).thenAnswer(
|
||||
(invocation: InvocationOnMock) => {
|
||||
val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]]
|
||||
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
|
||||
|
||||
|
@ -63,8 +61,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
|
|||
listener.onBlockFetchFailure(blockId, new BlockNotFoundException(blockId))
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
transfer
|
||||
}
|
||||
|
||||
|
@ -168,8 +165,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
|
|||
|
||||
val transfer = mock(classOf[BlockTransferService])
|
||||
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
|
||||
.thenAnswer(new Answer[Unit] {
|
||||
override def answer(invocation: InvocationOnMock): Unit = {
|
||||
.thenAnswer((invocation: InvocationOnMock) => {
|
||||
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
|
||||
Future {
|
||||
// Return the first two blocks, and wait till task completion before returning the 3rd one
|
||||
|
@ -181,8 +177,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
|
|||
listener.onBlockFetchSuccess(
|
||||
ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0)))
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
|
||||
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator
|
||||
|
@ -237,20 +232,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
|
|||
|
||||
val transfer = mock(classOf[BlockTransferService])
|
||||
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
|
||||
.thenAnswer(new Answer[Unit] {
|
||||
override def answer(invocation: InvocationOnMock): Unit = {
|
||||
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
|
||||
Future {
|
||||
// Return the first two blocks, and wait till task completion before returning the last
|
||||
listener.onBlockFetchSuccess(
|
||||
ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
|
||||
listener.onBlockFetchSuccess(
|
||||
ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0)))
|
||||
sem.acquire()
|
||||
listener.onBlockFetchSuccess(
|
||||
ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0)))
|
||||
}
|
||||
}
|
||||
.thenAnswer((invocation: InvocationOnMock) => {
|
||||
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
|
||||
Future {
|
||||
// Return the first two blocks, and wait till task completion before returning the last
|
||||
listener.onBlockFetchSuccess(
|
||||
ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
|
||||
listener.onBlockFetchSuccess(
|
||||
ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0)))
|
||||
sem.acquire()
|
||||
listener.onBlockFetchSuccess(
|
||||
ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0)))
|
||||
}
|
||||
})
|
||||
|
||||
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
|
||||
|
@ -298,8 +291,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
|
|||
|
||||
val transfer = mock(classOf[BlockTransferService])
|
||||
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
|
||||
.thenAnswer(new Answer[Unit] {
|
||||
override def answer(invocation: InvocationOnMock): Unit = {
|
||||
.thenAnswer((invocation: InvocationOnMock) => {
|
||||
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
|
||||
Future {
|
||||
// Return the first block, and then fail.
|
||||
|
@ -311,8 +303,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
|
|||
ShuffleBlockId(0, 2, 0).toString, new BlockNotFoundException("blah"))
|
||||
sem.release()
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
|
||||
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator
|
||||
|
@ -389,8 +380,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
|
|||
|
||||
val transfer = mock(classOf[BlockTransferService])
|
||||
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
|
||||
.thenAnswer(new Answer[Unit] {
|
||||
override def answer(invocation: InvocationOnMock): Unit = {
|
||||
.thenAnswer((invocation: InvocationOnMock) => {
|
||||
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
|
||||
Future {
|
||||
// Return the first block, and then fail.
|
||||
|
@ -402,8 +392,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
|
|||
ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer)
|
||||
sem.release()
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
|
||||
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator
|
||||
|
@ -431,8 +420,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
|
|||
assert(id1 === ShuffleBlockId(0, 0, 0))
|
||||
|
||||
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
|
||||
.thenAnswer(new Answer[Unit] {
|
||||
override def answer(invocation: InvocationOnMock): Unit = {
|
||||
.thenAnswer((invocation: InvocationOnMock) => {
|
||||
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
|
||||
Future {
|
||||
// Return the first block, and then fail.
|
||||
|
@ -440,8 +428,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
|
|||
ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
|
||||
sem.release()
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
// The next block is corrupt local block (the second one is corrupt and retried)
|
||||
intercept[FetchFailedException] { iterator.next() }
|
||||
|
@ -588,8 +575,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
|
|||
|
||||
val transfer = mock(classOf[BlockTransferService])
|
||||
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
|
||||
.thenAnswer(new Answer[Unit] {
|
||||
override def answer(invocation: InvocationOnMock): Unit = {
|
||||
.thenAnswer((invocation: InvocationOnMock) => {
|
||||
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
|
||||
Future {
|
||||
// Return the first block, and then fail.
|
||||
|
@ -601,8 +587,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
|
|||
ShuffleBlockId(0, 2, 0).toString, mockCorruptBuffer())
|
||||
sem.release()
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
|
||||
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator
|
||||
|
@ -654,14 +639,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
|
|||
val transfer = mock(classOf[BlockTransferService])
|
||||
var tempFileManager: DownloadFileManager = null
|
||||
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
|
||||
.thenAnswer(new Answer[Unit] {
|
||||
override def answer(invocation: InvocationOnMock): Unit = {
|
||||
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
|
||||
tempFileManager = invocation.getArguments()(5).asInstanceOf[DownloadFileManager]
|
||||
Future {
|
||||
listener.onBlockFetchSuccess(
|
||||
ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0)))
|
||||
}
|
||||
.thenAnswer((invocation: InvocationOnMock) => {
|
||||
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
|
||||
tempFileManager = invocation.getArguments()(5).asInstanceOf[DownloadFileManager]
|
||||
Future {
|
||||
listener.onBlockFetchSuccess(
|
||||
ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0)))
|
||||
}
|
||||
})
|
||||
|
||||
|
|
|
@ -71,11 +71,9 @@ class ThreadUtilsSuite extends SparkFunSuite {
|
|||
keepAliveSeconds = 2)
|
||||
try {
|
||||
for (_ <- 1 to maxThreadNumber) {
|
||||
cachedThreadPool.execute(new Runnable {
|
||||
override def run(): Unit = {
|
||||
startThreadsLatch.countDown()
|
||||
latch.await(10, TimeUnit.SECONDS)
|
||||
}
|
||||
cachedThreadPool.execute(() => {
|
||||
startThreadsLatch.countDown()
|
||||
latch.await(10, TimeUnit.SECONDS)
|
||||
})
|
||||
}
|
||||
startThreadsLatch.await(10, TimeUnit.SECONDS)
|
||||
|
@ -84,11 +82,7 @@ class ThreadUtilsSuite extends SparkFunSuite {
|
|||
|
||||
// Submit a new task and it should be put into the queue since the thread number reaches the
|
||||
// limitation
|
||||
cachedThreadPool.execute(new Runnable {
|
||||
override def run(): Unit = {
|
||||
latch.await(10, TimeUnit.SECONDS)
|
||||
}
|
||||
})
|
||||
cachedThreadPool.execute(() => latch.await(10, TimeUnit.SECONDS))
|
||||
|
||||
assert(cachedThreadPool.getActiveCount === maxThreadNumber)
|
||||
assert(cachedThreadPool.getQueue.size === 1)
|
||||
|
|
|
@ -17,8 +17,6 @@
|
|||
|
||||
package org.apache.spark.util.collection
|
||||
|
||||
import java.util.Comparator
|
||||
|
||||
import scala.collection.mutable.HashSet
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
|
@ -170,12 +168,10 @@ class AppendOnlyMapSuite extends SparkFunSuite {
|
|||
case e: IllegalStateException => fail()
|
||||
}
|
||||
|
||||
val it = map.destructiveSortedIterator(new Comparator[String] {
|
||||
def compare(key1: String, key2: String): Int = {
|
||||
val x = if (key1 != null) key1.toInt else Int.MinValue
|
||||
val y = if (key2 != null) key2.toInt else Int.MinValue
|
||||
x.compareTo(y)
|
||||
}
|
||||
val it = map.destructiveSortedIterator((key1: String, key2: String) => {
|
||||
val x = if (key1 != null) key1.toInt else Int.MinValue
|
||||
val y = if (key2 != null) key2.toInt else Int.MinValue
|
||||
x.compareTo(y)
|
||||
})
|
||||
|
||||
// Should be sorted by key
|
||||
|
|
|
@ -17,8 +17,6 @@
|
|||
|
||||
package org.apache.spark.util.collection
|
||||
|
||||
import java.util.Comparator
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.util.Random
|
||||
|
||||
|
@ -111,14 +109,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
|
|||
val tmp = new Array[Long](size/2)
|
||||
val tmpBuf = new LongArray(MemoryBlock.fromLongArray(tmp))
|
||||
|
||||
new Sorter(new UnsafeSortDataFormat(tmpBuf)).sort(
|
||||
buf, 0, size, new Comparator[RecordPointerAndKeyPrefix] {
|
||||
override def compare(
|
||||
r1: RecordPointerAndKeyPrefix,
|
||||
r2: RecordPointerAndKeyPrefix): Int = {
|
||||
PrefixComparators.LONG.compare(r1.keyPrefix, r2.keyPrefix)
|
||||
}
|
||||
})
|
||||
new Sorter(new UnsafeSortDataFormat(tmpBuf)).sort(buf, 0, size,
|
||||
(r1: RecordPointerAndKeyPrefix, r2: RecordPointerAndKeyPrefix) =>
|
||||
PrefixComparators.LONG.compare(r1.keyPrefix, r2.keyPrefix))
|
||||
}
|
||||
|
||||
test("spilling with hash collisions") {
|
||||
|
@ -135,7 +128,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
|
|||
buffer2: ArrayBuffer[String]): ArrayBuffer[String] = buffer1 ++= buffer2
|
||||
|
||||
val agg = new Aggregator[String, String, ArrayBuffer[String]](
|
||||
createCombiner _, mergeValue _, mergeCombiners _)
|
||||
createCombiner, mergeValue, mergeCombiners)
|
||||
|
||||
val sorter = new ExternalSorter[String, String, ArrayBuffer[String]](
|
||||
context, Some(agg), None, None)
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
package org.apache.spark.util.collection
|
||||
|
||||
import java.lang.{Float => JFloat}
|
||||
import java.util.{Arrays, Comparator}
|
||||
import java.util.Arrays
|
||||
import java.util.concurrent.TimeUnit
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
|
@ -219,10 +219,8 @@ class SorterSuite extends SparkFunSuite with Logging {
|
|||
System.arraycopy(kvTuples, 0, kvTupleArray, 0, numElements)
|
||||
}
|
||||
runExperiment("Tuple-sort using Arrays.sort()")({
|
||||
Arrays.sort(kvTupleArray, new Comparator[AnyRef] {
|
||||
override def compare(x: AnyRef, y: AnyRef): Int =
|
||||
x.asInstanceOf[(JFloat, _)]._1.compareTo(y.asInstanceOf[(JFloat, _)]._1)
|
||||
})
|
||||
Arrays.sort(kvTupleArray, (x: AnyRef, y: AnyRef) =>
|
||||
x.asInstanceOf[(JFloat, _)]._1.compareTo(y.asInstanceOf[(JFloat, _)]._1))
|
||||
}, prepareKvTupleArray)
|
||||
|
||||
// Test our Sorter where each element alternates between Float and Integer, non-primitive
|
||||
|
@ -245,9 +243,7 @@ class SorterSuite extends SparkFunSuite with Logging {
|
|||
|
||||
val sorter = new Sorter(new KVArraySortDataFormat[JFloat, AnyRef])
|
||||
runExperiment("KV-sort using Sorter")({
|
||||
sorter.sort(keyValueArray, 0, numElements, new Comparator[JFloat] {
|
||||
override def compare(x: JFloat, y: JFloat): Int = x.compareTo(y)
|
||||
})
|
||||
sorter.sort(keyValueArray, 0, numElements, (x: JFloat, y: JFloat) => x.compareTo(y))
|
||||
}, prepareKeyValueArray)
|
||||
}
|
||||
|
||||
|
@ -280,11 +276,9 @@ class SorterSuite extends SparkFunSuite with Logging {
|
|||
System.arraycopy(intObjects, 0, intObjectArray, 0, numElements)
|
||||
}
|
||||
|
||||
runExperiment("Java Arrays.sort() on non-primitive int array")({
|
||||
Arrays.sort(intObjectArray, new Comparator[Integer] {
|
||||
override def compare(x: Integer, y: Integer): Int = x.compareTo(y)
|
||||
})
|
||||
}, prepareIntObjectArray)
|
||||
runExperiment("Java Arrays.sort() on non-primitive int array")(
|
||||
Arrays.sort(intObjectArray, (x: Integer, y: Integer) => x.compareTo(y)),
|
||||
prepareIntObjectArray)
|
||||
|
||||
val intPrimitiveArray = new Array[Int](numElements)
|
||||
val prepareIntPrimitiveArray = () => {
|
||||
|
|
|
@ -69,9 +69,8 @@ class RadixSortSuite extends SparkFunSuite with Logging {
|
|||
override def sortDescending = false
|
||||
override def sortSigned = false
|
||||
override def nullsFirst = true
|
||||
override def compare(a: Long, b: Long): Int = {
|
||||
return PrefixComparators.BINARY.compare(a & 0xffffff0000L, b & 0xffffff0000L)
|
||||
}
|
||||
override def compare(a: Long, b: Long): Int =
|
||||
PrefixComparators.BINARY.compare(a & 0xffffff0000L, b & 0xffffff0000L)
|
||||
},
|
||||
2, 4, false, false, true))
|
||||
|
||||
|
@ -112,11 +111,9 @@ class RadixSortSuite extends SparkFunSuite with Logging {
|
|||
private def referenceKeyPrefixSort(buf: LongArray, lo: Long, hi: Long, refCmp: PrefixComparator) {
|
||||
val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt)))
|
||||
new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort(
|
||||
buf, Ints.checkedCast(lo), Ints.checkedCast(hi), new Comparator[RecordPointerAndKeyPrefix] {
|
||||
override def compare(
|
||||
r1: RecordPointerAndKeyPrefix,
|
||||
r2: RecordPointerAndKeyPrefix): Int = refCmp.compare(r1.keyPrefix, r2.keyPrefix)
|
||||
})
|
||||
buf, Ints.checkedCast(lo), Ints.checkedCast(hi),
|
||||
(r1: RecordPointerAndKeyPrefix, r2: RecordPointerAndKeyPrefix) =>
|
||||
refCmp.compare(r1.keyPrefix, r2.keyPrefix))
|
||||
}
|
||||
|
||||
private def fuzzTest(name: String)(testFn: Long => Unit): Unit = {
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
package org.apache.spark.sql.kafka010
|
||||
|
||||
import java.{util => ju}
|
||||
import java.util.concurrent.{Executors, ThreadFactory}
|
||||
import java.util.concurrent.Executors
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
@ -52,16 +52,14 @@ private[kafka010] class KafkaOffsetReader(
|
|||
/**
|
||||
* Used to ensure execute fetch operations execute in an UninterruptibleThread
|
||||
*/
|
||||
val kafkaReaderThread = Executors.newSingleThreadExecutor(new ThreadFactory {
|
||||
override def newThread(r: Runnable): Thread = {
|
||||
val t = new UninterruptibleThread("Kafka Offset Reader") {
|
||||
override def run(): Unit = {
|
||||
r.run()
|
||||
}
|
||||
val kafkaReaderThread = Executors.newSingleThreadExecutor((r: Runnable) => {
|
||||
val t = new UninterruptibleThread("Kafka Offset Reader") {
|
||||
override def run(): Unit = {
|
||||
r.run()
|
||||
}
|
||||
t.setDaemon(true)
|
||||
t
|
||||
}
|
||||
t.setDaemon(true)
|
||||
t
|
||||
})
|
||||
val execContext = ExecutionContext.fromExecutorService(kafkaReaderThread)
|
||||
|
||||
|
|
|
@ -361,9 +361,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
|
|||
|
||||
override def capabilities(): ju.Set[TableCapability] = Collections.emptySet()
|
||||
|
||||
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder {
|
||||
override def build(): Scan = new KafkaScan(options)
|
||||
}
|
||||
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
|
||||
() => new KafkaScan(options)
|
||||
|
||||
override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = {
|
||||
new WriteBuilder {
|
||||
|
|
|
@ -53,7 +53,7 @@ class DirectKafkaStreamSuite
|
|||
.setMaster("local[4]")
|
||||
.setAppName(this.getClass.getSimpleName)
|
||||
// Set a timeout of 10 seconds that's going to be used to fetch topics/partitions from kafka.
|
||||
// Othewise the poll timeout defaults to 2 minutes and causes test cases to run longer.
|
||||
// Otherwise the poll timeout defaults to 2 minutes and causes test cases to run longer.
|
||||
.set("spark.streaming.kafka.consumer.poll.ms", "10000")
|
||||
|
||||
private var ssc: StreamingContext = _
|
||||
|
@ -61,13 +61,13 @@ class DirectKafkaStreamSuite
|
|||
|
||||
private var kafkaTestUtils: KafkaTestUtils = _
|
||||
|
||||
override def beforeAll {
|
||||
override def beforeAll() {
|
||||
super.beforeAll()
|
||||
kafkaTestUtils = new KafkaTestUtils
|
||||
kafkaTestUtils.setup()
|
||||
}
|
||||
|
||||
override def afterAll {
|
||||
override def afterAll() {
|
||||
try {
|
||||
if (kafkaTestUtils != null) {
|
||||
kafkaTestUtils.teardown()
|
||||
|
@ -454,13 +454,11 @@ class DirectKafkaStreamSuite
|
|||
val data = rdd.map(_.value).collect()
|
||||
collectedData.addAll(Arrays.asList(data: _*))
|
||||
kafkaStream.asInstanceOf[CanCommitOffsets]
|
||||
.commitAsync(offsets, new OffsetCommitCallback() {
|
||||
def onComplete(m: JMap[TopicPartition, OffsetAndMetadata], e: Exception) {
|
||||
if (null != e) {
|
||||
logError("commit failed", e)
|
||||
} else {
|
||||
committed.putAll(m)
|
||||
}
|
||||
.commitAsync(offsets, (m: JMap[TopicPartition, OffsetAndMetadata], e: Exception) => {
|
||||
if (null != e) {
|
||||
logError("commit failed", e)
|
||||
} else {
|
||||
committed.putAll(m)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -27,7 +27,6 @@ import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorC
|
|||
import org.mockito.ArgumentMatchers._
|
||||
import org.mockito.Mockito._
|
||||
import org.mockito.invocation.InvocationOnMock
|
||||
import org.mockito.stubbing.Answer
|
||||
import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester}
|
||||
import org.scalatest.concurrent.Eventually
|
||||
import org.scalatest.mockito.MockitoSugar
|
||||
|
@ -124,11 +123,9 @@ class KinesisCheckpointerSuite extends TestSuiteBase
|
|||
test("if checkpointing is going on, wait until finished before removing and checkpointing") {
|
||||
when(receiverMock.getLatestSeqNumToCheckpoint(shardId))
|
||||
.thenReturn(someSeqNum).thenReturn(someOtherSeqNum)
|
||||
when(checkpointerMock.checkpoint(anyString)).thenAnswer(new Answer[Unit] {
|
||||
override def answer(invocations: InvocationOnMock): Unit = {
|
||||
clock.waitTillTime(clock.getTimeMillis() + checkpointInterval.milliseconds / 2)
|
||||
}
|
||||
})
|
||||
when(checkpointerMock.checkpoint(anyString)).thenAnswer { (_: InvocationOnMock) =>
|
||||
clock.waitTillTime(clock.getTimeMillis() + checkpointInterval.milliseconds / 2)
|
||||
}
|
||||
|
||||
kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock)
|
||||
clock.advance(checkpointInterval.milliseconds)
|
||||
|
|
|
@ -130,16 +130,14 @@ private[impl] case class EdgeWithLocalIds[@specialized ED](
|
|||
|
||||
private[impl] object EdgeWithLocalIds {
|
||||
implicit def lexicographicOrdering[ED]: Ordering[EdgeWithLocalIds[ED]] =
|
||||
new Ordering[EdgeWithLocalIds[ED]] {
|
||||
override def compare(a: EdgeWithLocalIds[ED], b: EdgeWithLocalIds[ED]): Int = {
|
||||
if (a.srcId == b.srcId) {
|
||||
if (a.dstId == b.dstId) 0
|
||||
else if (a.dstId < b.dstId) -1
|
||||
else 1
|
||||
} else if (a.srcId < b.srcId) -1
|
||||
(a: EdgeWithLocalIds[ED], b: EdgeWithLocalIds[ED]) =>
|
||||
if (a.srcId == b.srcId) {
|
||||
if (a.dstId == b.dstId) 0
|
||||
else if (a.dstId < b.dstId) -1
|
||||
else 1
|
||||
}
|
||||
}
|
||||
else if (a.srcId < b.srcId) -1
|
||||
else 1
|
||||
|
||||
private[graphx] def edgeArraySortDataFormat[ED] = {
|
||||
new SortDataFormat[EdgeWithLocalIds[ED], Array[EdgeWithLocalIds[ED]]] {
|
||||
|
|
|
@ -19,7 +19,7 @@ package org.apache.spark.repl
|
|||
|
||||
import java.io.File
|
||||
import java.net.{URI, URL, URLClassLoader}
|
||||
import java.nio.channels.{FileChannel, ReadableByteChannel}
|
||||
import java.nio.channels.FileChannel
|
||||
import java.nio.charset.StandardCharsets
|
||||
import java.nio.file.{Paths, StandardOpenOption}
|
||||
import java.util
|
||||
|
@ -33,7 +33,6 @@ import com.google.common.io.Files
|
|||
import org.mockito.ArgumentMatchers.anyString
|
||||
import org.mockito.Mockito._
|
||||
import org.mockito.invocation.InvocationOnMock
|
||||
import org.mockito.stubbing.Answer
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
import org.scalatest.mockito.MockitoSugar
|
||||
|
||||
|
@ -191,12 +190,10 @@ class ExecutorClassLoaderSuite
|
|||
val env = mock[SparkEnv]
|
||||
val rpcEnv = mock[RpcEnv]
|
||||
when(env.rpcEnv).thenReturn(rpcEnv)
|
||||
when(rpcEnv.openChannel(anyString())).thenAnswer(new Answer[ReadableByteChannel]() {
|
||||
override def answer(invocation: InvocationOnMock): ReadableByteChannel = {
|
||||
val uri = new URI(invocation.getArguments()(0).asInstanceOf[String])
|
||||
val path = Paths.get(tempDir1.getAbsolutePath(), uri.getPath().stripPrefix("/"))
|
||||
FileChannel.open(path, StandardOpenOption.READ)
|
||||
}
|
||||
when(rpcEnv.openChannel(anyString())).thenAnswer((invocation: InvocationOnMock) => {
|
||||
val uri = new URI(invocation.getArguments()(0).asInstanceOf[String])
|
||||
val path = Paths.get(tempDir1.getAbsolutePath(), uri.getPath().stripPrefix("/"))
|
||||
FileChannel.open(path, StandardOpenOption.READ)
|
||||
})
|
||||
|
||||
val classLoader = new ExecutorClassLoader(new SparkConf(), env, "spark://localhost:1234",
|
||||
|
|
|
@ -52,9 +52,7 @@ class SingletonReplSuite extends SparkFunSuite {
|
|||
Main.sparkSession = null
|
||||
|
||||
// Starts a new thread to run the REPL interpreter, so that we won't block.
|
||||
thread = new Thread(new Runnable {
|
||||
override def run(): Unit = Main.doMain(Array("-classpath", classpath), interp)
|
||||
})
|
||||
thread = new Thread(() => Main.doMain(Array("-classpath", classpath), interp))
|
||||
thread.setDaemon(true)
|
||||
thread.start()
|
||||
|
||||
|
|
|
@ -47,9 +47,7 @@ private[k8s] class LoggingPodStatusWatcherImpl(
|
|||
// start timer for periodic logging
|
||||
private val scheduler =
|
||||
ThreadUtils.newDaemonSingleThreadScheduledExecutor("logging-pod-status-watcher")
|
||||
private val logRunnable: Runnable = new Runnable {
|
||||
override def run() = logShortStatus()
|
||||
}
|
||||
private val logRunnable: Runnable = () => logShortStatus()
|
||||
|
||||
private var pod = Option.empty[Pod]
|
||||
|
||||
|
|
|
@ -68,7 +68,7 @@ private[spark] class ExecutorPodsSnapshotsStoreImpl(subscribersExecutor: Schedul
|
|||
}
|
||||
subscribers += newSubscriber
|
||||
pollingTasks += subscribersExecutor.scheduleWithFixedDelay(
|
||||
toRunnable(() => callSubscriber(newSubscriber)),
|
||||
() => callSubscriber(newSubscriber),
|
||||
0L,
|
||||
processBatchIntervalMillis,
|
||||
TimeUnit.MILLISECONDS)
|
||||
|
@ -103,10 +103,6 @@ private[spark] class ExecutorPodsSnapshotsStoreImpl(subscribersExecutor: Schedul
|
|||
}
|
||||
}
|
||||
|
||||
private def toRunnable[T](runnable: () => Unit): Runnable = new Runnable {
|
||||
override def run(): Unit = runnable()
|
||||
}
|
||||
|
||||
private case class SnapshotsSubscriber(
|
||||
snapshotsBuffer: BlockingQueue[ExecutorPodsSnapshot],
|
||||
onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit)
|
||||
|
|
|
@ -23,7 +23,6 @@ import io.fabric8.kubernetes.api.model.{Container, HasMetadata, PodBuilder, Secr
|
|||
import org.mockito.ArgumentMatchers.any
|
||||
import org.mockito.Mockito.{mock, when}
|
||||
import org.mockito.invocation.InvocationOnMock
|
||||
import org.mockito.stubbing.Answer
|
||||
|
||||
import org.apache.spark.deploy.k8s.SparkPod
|
||||
|
||||
|
@ -38,16 +37,14 @@ object KubernetesFeaturesTestUtils {
|
|||
when(mockStep.getAdditionalPodSystemProperties())
|
||||
.thenReturn(Map(stepType -> stepType))
|
||||
when(mockStep.configurePod(any(classOf[SparkPod])))
|
||||
.thenAnswer(new Answer[SparkPod]() {
|
||||
override def answer(invocation: InvocationOnMock): SparkPod = {
|
||||
val originalPod: SparkPod = invocation.getArgument(0)
|
||||
val configuredPod = new PodBuilder(originalPod.pod)
|
||||
.editOrNewMetadata()
|
||||
.addToLabels(stepType, stepType)
|
||||
.endMetadata()
|
||||
.build()
|
||||
SparkPod(configuredPod, originalPod.container)
|
||||
}
|
||||
.thenAnswer((invocation: InvocationOnMock) => {
|
||||
val originalPod: SparkPod = invocation.getArgument(0)
|
||||
val configuredPod = new PodBuilder(originalPod.pod)
|
||||
.editOrNewMetadata()
|
||||
.addToLabels(stepType, stepType)
|
||||
.endMetadata()
|
||||
.build()
|
||||
SparkPod(configuredPod, originalPod.container)
|
||||
})
|
||||
mockStep
|
||||
}
|
||||
|
@ -67,6 +64,6 @@ object KubernetesFeaturesTestUtils {
|
|||
|
||||
def filter[T: ClassTag](list: Seq[HasMetadata]): Seq[T] = {
|
||||
val desired = implicitly[ClassTag[T]].runtimeClass
|
||||
list.filter(_.getClass() == desired).map(_.asInstanceOf[T]).toSeq
|
||||
list.filter(_.getClass() == desired).map(_.asInstanceOf[T])
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,7 +19,7 @@ package org.apache.spark.scheduler.cluster.k8s
|
|||
import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder}
|
||||
import io.fabric8.kubernetes.client.KubernetesClient
|
||||
import io.fabric8.kubernetes.client.dsl.PodResource
|
||||
import org.mockito.{ArgumentMatcher, Matchers, Mock, MockitoAnnotations}
|
||||
import org.mockito.{Mock, MockitoAnnotations}
|
||||
import org.mockito.ArgumentMatchers.{any, eq => meq}
|
||||
import org.mockito.Mockito.{never, times, verify, when}
|
||||
import org.mockito.invocation.InvocationOnMock
|
||||
|
@ -27,7 +27,7 @@ import org.mockito.stubbing.Answer
|
|||
import org.scalatest.BeforeAndAfter
|
||||
|
||||
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
|
||||
import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, KubernetesTestConf, SparkPod}
|
||||
import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, SparkPod}
|
||||
import org.apache.spark.deploy.k8s.Config._
|
||||
import org.apache.spark.deploy.k8s.Constants._
|
||||
import org.apache.spark.deploy.k8s.Fabric8Aliases._
|
||||
|
@ -153,12 +153,9 @@ class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter {
|
|||
verify(podOperations).create(podWithAttachedContainerForId(2))
|
||||
}
|
||||
|
||||
private def executorPodAnswer(): Answer[SparkPod] = {
|
||||
new Answer[SparkPod] {
|
||||
override def answer(invocation: InvocationOnMock): SparkPod = {
|
||||
val k8sConf: KubernetesExecutorConf = invocation.getArgument(0)
|
||||
executorPodWithId(k8sConf.executorId.toInt)
|
||||
}
|
||||
}
|
||||
private def executorPodAnswer(): Answer[SparkPod] =
|
||||
(invocation: InvocationOnMock) => {
|
||||
val k8sConf: KubernetesExecutorConf = invocation.getArgument(0)
|
||||
executorPodWithId(k8sConf.executorId.toInt)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,7 +26,6 @@ import org.mockito.Mockito.{mock, never, times, verify, when}
|
|||
import org.mockito.invocation.InvocationOnMock
|
||||
import org.mockito.stubbing.Answer
|
||||
import org.scalatest.BeforeAndAfter
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
|
||||
import org.apache.spark.{SparkConf, SparkFunSuite}
|
||||
|
@ -125,13 +124,10 @@ class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfte
|
|||
""".stripMargin
|
||||
}
|
||||
|
||||
private def namedPodsAnswer(): Answer[PodResource[Pod, DoneablePod]] = {
|
||||
new Answer[PodResource[Pod, DoneablePod]] {
|
||||
override def answer(invocation: InvocationOnMock): PodResource[Pod, DoneablePod] = {
|
||||
val podName: String = invocation.getArgument(0)
|
||||
namedExecutorPods.getOrElseUpdate(
|
||||
podName, mock(classOf[PodResource[Pod, DoneablePod]]))
|
||||
}
|
||||
private def namedPodsAnswer(): Answer[PodResource[Pod, DoneablePod]] =
|
||||
(invocation: InvocationOnMock) => {
|
||||
val podName: String = invocation.getArgument(0)
|
||||
namedExecutorPods.getOrElseUpdate(
|
||||
podName, mock(classOf[PodResource[Pod, DoneablePod]]))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -192,11 +192,9 @@ private[yarn] class YarnAllocator(
|
|||
* A sequence of pending container requests at the given location that have not yet been
|
||||
* fulfilled.
|
||||
*/
|
||||
private def getPendingAtLocation(location: String): Seq[ContainerRequest] = {
|
||||
private def getPendingAtLocation(location: String): Seq[ContainerRequest] =
|
||||
amClient.getMatchingRequests(RM_REQUEST_PRIORITY, location, resource).asScala
|
||||
.flatMap(_.asScala)
|
||||
.toSeq
|
||||
}
|
||||
|
||||
/**
|
||||
* Request as many executors from the ResourceManager as needed to reach the desired total. If
|
||||
|
@ -384,7 +382,7 @@ private[yarn] class YarnAllocator(
|
|||
def stop(): Unit = {
|
||||
// Forcefully shut down the launcher pool, in case this is being called in the middle of
|
||||
// container allocation. This will prevent queued executors from being started - and
|
||||
// potentially interrupt active ExecutorRunnable instaces too.
|
||||
// potentially interrupt active ExecutorRunnable instances too.
|
||||
launcherPool.shutdownNow()
|
||||
}
|
||||
|
||||
|
@ -467,7 +465,7 @@ private[yarn] class YarnAllocator(
|
|||
remainingAfterOffRackMatches)
|
||||
}
|
||||
|
||||
if (!remainingAfterOffRackMatches.isEmpty) {
|
||||
if (remainingAfterOffRackMatches.nonEmpty) {
|
||||
logDebug(s"Releasing ${remainingAfterOffRackMatches.size} unneeded containers that were " +
|
||||
s"allocated to us")
|
||||
for (container <- remainingAfterOffRackMatches) {
|
||||
|
@ -550,35 +548,33 @@ private[yarn] class YarnAllocator(
|
|||
if (runningExecutors.size() < targetNumExecutors) {
|
||||
numExecutorsStarting.incrementAndGet()
|
||||
if (launchContainers) {
|
||||
launcherPool.execute(new Runnable {
|
||||
override def run(): Unit = {
|
||||
try {
|
||||
new ExecutorRunnable(
|
||||
Some(container),
|
||||
conf,
|
||||
sparkConf,
|
||||
driverUrl,
|
||||
executorId,
|
||||
executorHostname,
|
||||
executorMemory,
|
||||
executorCores,
|
||||
appAttemptId.getApplicationId.toString,
|
||||
securityMgr,
|
||||
localResources
|
||||
).run()
|
||||
updateInternalState()
|
||||
} catch {
|
||||
case e: Throwable =>
|
||||
numExecutorsStarting.decrementAndGet()
|
||||
if (NonFatal(e)) {
|
||||
logError(s"Failed to launch executor $executorId on container $containerId", e)
|
||||
// Assigned container should be released immediately
|
||||
// to avoid unnecessary resource occupation.
|
||||
amClient.releaseAssignedContainer(containerId)
|
||||
} else {
|
||||
throw e
|
||||
}
|
||||
}
|
||||
launcherPool.execute(() => {
|
||||
try {
|
||||
new ExecutorRunnable(
|
||||
Some(container),
|
||||
conf,
|
||||
sparkConf,
|
||||
driverUrl,
|
||||
executorId,
|
||||
executorHostname,
|
||||
executorMemory,
|
||||
executorCores,
|
||||
appAttemptId.getApplicationId.toString,
|
||||
securityMgr,
|
||||
localResources
|
||||
).run()
|
||||
updateInternalState()
|
||||
} catch {
|
||||
case e: Throwable =>
|
||||
numExecutorsStarting.decrementAndGet()
|
||||
if (NonFatal(e)) {
|
||||
logError(s"Failed to launch executor $executorId on container $containerId", e)
|
||||
// Assigned container should be released immediately
|
||||
// to avoid unnecessary resource occupation.
|
||||
amClient.releaseAssignedContainer(containerId)
|
||||
} else {
|
||||
throw e
|
||||
}
|
||||
}
|
||||
})
|
||||
} else {
|
||||
|
@ -776,7 +772,7 @@ private[yarn] class YarnAllocator(
|
|||
}
|
||||
}
|
||||
|
||||
(localityMatched.toSeq, localityUnMatched.toSeq, localityFree.toSeq)
|
||||
(localityMatched, localityUnMatched, localityFree)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -30,7 +30,6 @@ import org.apache.spark.sql.catalyst.util._
|
|||
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.Platform
|
||||
import org.apache.spark.unsafe.array.ByteArrayMethods
|
||||
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
|
||||
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
|
||||
|
@ -273,7 +272,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI
|
|||
}
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
if (children.length == 0) {
|
||||
if (children.isEmpty) {
|
||||
emptyInputGenCode(ev)
|
||||
} else {
|
||||
nonEmptyInputGenCode(ctx, ev)
|
||||
|
@ -718,17 +717,15 @@ trait ArraySortLike extends ExpectsInputTypes {
|
|||
case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
|
||||
}
|
||||
|
||||
new Comparator[Any]() {
|
||||
override def compare(o1: Any, o2: Any): Int = {
|
||||
if (o1 == null && o2 == null) {
|
||||
0
|
||||
} else if (o1 == null) {
|
||||
nullOrder
|
||||
} else if (o2 == null) {
|
||||
-nullOrder
|
||||
} else {
|
||||
ordering.compare(o1, o2)
|
||||
}
|
||||
(o1: Any, o2: Any) => {
|
||||
if (o1 == null && o2 == null) {
|
||||
0
|
||||
} else if (o1 == null) {
|
||||
nullOrder
|
||||
} else if (o2 == null) {
|
||||
-nullOrder
|
||||
} else {
|
||||
ordering.compare(o1, o2)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -740,17 +737,15 @@ trait ArraySortLike extends ExpectsInputTypes {
|
|||
case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
|
||||
}
|
||||
|
||||
new Comparator[Any]() {
|
||||
override def compare(o1: Any, o2: Any): Int = {
|
||||
if (o1 == null && o2 == null) {
|
||||
0
|
||||
} else if (o1 == null) {
|
||||
-nullOrder
|
||||
} else if (o2 == null) {
|
||||
nullOrder
|
||||
} else {
|
||||
ordering.compare(o2, o1)
|
||||
}
|
||||
(o1: Any, o2: Any) => {
|
||||
if (o1 == null && o2 == null) {
|
||||
0
|
||||
} else if (o1 == null) {
|
||||
-nullOrder
|
||||
} else if (o2 == null) {
|
||||
nullOrder
|
||||
} else {
|
||||
ordering.compare(o2, o1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -769,7 +764,6 @@ trait ArraySortLike extends ExpectsInputTypes {
|
|||
}
|
||||
|
||||
def sortCodegen(ctx: CodegenContext, ev: ExprCode, base: String, order: String): String = {
|
||||
val arrayData = classOf[ArrayData].getName
|
||||
val genericArrayData = classOf[GenericArrayData].getName
|
||||
val unsafeArrayData = classOf[UnsafeArrayData].getName
|
||||
val array = ctx.freshName("array")
|
||||
|
@ -2784,7 +2778,7 @@ case class ArrayRepeat(left: Expression, right: Expression)
|
|||
} else {
|
||||
if (count.asInstanceOf[Int] > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
|
||||
throw new RuntimeException(s"Unsuccessful try to create array with $count elements " +
|
||||
s"due to exceeding the array size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
|
||||
s"due to exceeding the array size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
|
||||
}
|
||||
val element = left.eval(input)
|
||||
new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element))
|
||||
|
|
|
@ -35,12 +35,8 @@ package object util extends Logging {
|
|||
val origErr = System.err
|
||||
val origOut = System.out
|
||||
try {
|
||||
System.setErr(new PrintStream(new OutputStream {
|
||||
def write(b: Int) = {}
|
||||
}))
|
||||
System.setOut(new PrintStream(new OutputStream {
|
||||
def write(b: Int) = {}
|
||||
}))
|
||||
System.setErr(new PrintStream((_: Int) => {}))
|
||||
System.setOut(new PrintStream((_: Int) => {}))
|
||||
|
||||
f
|
||||
} finally {
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
|
||||
package org.apache.spark.sql.types
|
||||
|
||||
import scala.math.Ordering
|
||||
import scala.reflect.runtime.universe.typeTag
|
||||
|
||||
import org.apache.spark.annotation.Stable
|
||||
|
@ -37,11 +36,8 @@ class BinaryType private() extends AtomicType {
|
|||
|
||||
@transient private[sql] lazy val tag = typeTag[InternalType]
|
||||
|
||||
private[sql] val ordering = new Ordering[InternalType] {
|
||||
def compare(x: Array[Byte], y: Array[Byte]): Int = {
|
||||
TypeUtils.compareBinary(x, y)
|
||||
}
|
||||
}
|
||||
private[sql] val ordering =
|
||||
(x: Array[Byte], y: Array[Byte]) => TypeUtils.compareBinary(x, y)
|
||||
|
||||
/**
|
||||
* The default size of a value of the BinaryType is 100 bytes.
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
package org.apache.spark.sql.types
|
||||
|
||||
import scala.math.{Fractional, Numeric, Ordering}
|
||||
import scala.math.{Fractional, Numeric}
|
||||
import scala.math.Numeric.DoubleAsIfIntegral
|
||||
import scala.reflect.runtime.universe.typeTag
|
||||
|
||||
|
@ -38,9 +38,8 @@ class DoubleType private() extends FractionalType {
|
|||
@transient private[sql] lazy val tag = typeTag[InternalType]
|
||||
private[sql] val numeric = implicitly[Numeric[Double]]
|
||||
private[sql] val fractional = implicitly[Fractional[Double]]
|
||||
private[sql] val ordering = new Ordering[Double] {
|
||||
override def compare(x: Double, y: Double): Int = Utils.nanSafeCompareDoubles(x, y)
|
||||
}
|
||||
private[sql] val ordering =
|
||||
(x: Double, y: Double) => Utils.nanSafeCompareDoubles(x, y)
|
||||
private[sql] val asIntegral = DoubleAsIfIntegral
|
||||
|
||||
/**
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
package org.apache.spark.sql.types
|
||||
|
||||
import scala.math.{Fractional, Numeric, Ordering}
|
||||
import scala.math.{Fractional, Numeric}
|
||||
import scala.math.Numeric.FloatAsIfIntegral
|
||||
import scala.reflect.runtime.universe.typeTag
|
||||
|
||||
|
@ -38,9 +38,8 @@ class FloatType private() extends FractionalType {
|
|||
@transient private[sql] lazy val tag = typeTag[InternalType]
|
||||
private[sql] val numeric = implicitly[Numeric[Float]]
|
||||
private[sql] val fractional = implicitly[Fractional[Float]]
|
||||
private[sql] val ordering = new Ordering[Float] {
|
||||
override def compare(x: Float, y: Float): Int = Utils.nanSafeCompareFloats(x, y)
|
||||
}
|
||||
private[sql] val ordering =
|
||||
(x: Float, y: Float) => Utils.nanSafeCompareFloats(x, y)
|
||||
private[sql] val asIntegral = FloatAsIfIntegral
|
||||
|
||||
/**
|
||||
|
|
|
@ -38,11 +38,7 @@ class ExternalCatalogEventSuite extends SparkFunSuite {
|
|||
f: (ExternalCatalog, Seq[ExternalCatalogEvent] => Unit) => Unit): Unit = test(name) {
|
||||
val catalog = new ExternalCatalogWithListener(newCatalog)
|
||||
val recorder = mutable.Buffer.empty[ExternalCatalogEvent]
|
||||
catalog.addListener(new ExternalCatalogEventListener {
|
||||
override def onEvent(event: ExternalCatalogEvent): Unit = {
|
||||
recorder += event
|
||||
}
|
||||
})
|
||||
catalog.addListener((event: ExternalCatalogEvent) => recorder += event)
|
||||
f(catalog, (expected: Seq[ExternalCatalogEvent]) => {
|
||||
val actual = recorder.clone()
|
||||
recorder.clear()
|
||||
|
@ -174,9 +170,6 @@ class ExternalCatalogEventSuite extends SparkFunSuite {
|
|||
className = "",
|
||||
resources = Seq.empty)
|
||||
|
||||
val newIdentifier = functionDefinition.identifier.copy(funcName = "fn4")
|
||||
val renamedFunctionDefinition = functionDefinition.copy(identifier = newIdentifier)
|
||||
|
||||
catalog.createDatabase(dbDefinition, ignoreIfExists = false)
|
||||
checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil)
|
||||
|
||||
|
|
|
@ -112,9 +112,7 @@ object SortPrefixUtils {
|
|||
val field = schema.head
|
||||
getPrefixComparator(SortOrder(BoundReference(0, field.dataType, field.nullable), Ascending))
|
||||
} else {
|
||||
new PrefixComparator {
|
||||
override def compare(prefix1: Long, prefix2: Long): Int = 0
|
||||
}
|
||||
(_: Long, _: Long) => 0
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -164,12 +162,7 @@ object SortPrefixUtils {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
new UnsafeExternalRowSorter.PrefixComputer {
|
||||
override def computePrefix(row: InternalRow):
|
||||
UnsafeExternalRowSorter.PrefixComputer.Prefix = {
|
||||
emptyPrefix
|
||||
}
|
||||
}
|
||||
_: InternalRow => emptyPrefix
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
|
|||
|
||||
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
|
||||
import org.apache.spark.sql.catalyst.TableIdentifier
|
||||
import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Resolver}
|
||||
import org.apache.spark.sql.catalyst.analysis.Resolver
|
||||
import org.apache.spark.sql.catalyst.catalog._
|
||||
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
|
||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
|
||||
|
@ -72,7 +72,7 @@ case class CreateDatabaseCommand(
|
|||
CatalogDatabase(
|
||||
databaseName,
|
||||
comment.getOrElse(""),
|
||||
path.map(CatalogUtils.stringToURI(_)).getOrElse(catalog.getDefaultDBPath(databaseName)),
|
||||
path.map(CatalogUtils.stringToURI).getOrElse(catalog.getDefaultDBPath(databaseName)),
|
||||
props),
|
||||
ifNotExists)
|
||||
Seq.empty[Row]
|
||||
|
@ -352,9 +352,8 @@ case class AlterTableChangeColumnCommand(
|
|||
}
|
||||
|
||||
// Add the comment to a column, if comment is empty, return the original column.
|
||||
private def addComment(column: StructField, comment: Option[String]): StructField = {
|
||||
comment.map(column.withComment(_)).getOrElse(column)
|
||||
}
|
||||
private def addComment(column: StructField, comment: Option[String]): StructField =
|
||||
comment.map(column.withComment).getOrElse(column)
|
||||
|
||||
// Compare a [[StructField]] to another, return true if they have the same column
|
||||
// name(by resolver) and dataType.
|
||||
|
@ -584,14 +583,12 @@ case class AlterTableRecoverPartitionsCommand(
|
|||
// It's very expensive to create a JobConf(ClassUtil.findContainingJar() is slow)
|
||||
val jobConf = new JobConf(hadoopConf, this.getClass)
|
||||
val pathFilter = FileInputFormat.getInputPathFilter(jobConf)
|
||||
new PathFilter {
|
||||
override def accept(path: Path): Boolean = {
|
||||
val name = path.getName
|
||||
if (name != "_SUCCESS" && name != "_temporary" && !name.startsWith(".")) {
|
||||
pathFilter == null || pathFilter.accept(path)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
path: Path => {
|
||||
val name = path.getName
|
||||
if (name != "_SUCCESS" && name != "_temporary" && !name.startsWith(".")) {
|
||||
pathFilter == null || pathFilter.accept(path)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
package org.apache.spark.sql.execution.datasources
|
||||
|
||||
import java.util.Locale
|
||||
import java.util.concurrent.Callable
|
||||
|
||||
import org.apache.hadoop.fs.Path
|
||||
|
||||
|
@ -222,23 +221,20 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
|
|||
private def readDataSourceTable(table: CatalogTable): LogicalPlan = {
|
||||
val qualifiedTableName = QualifiedTableName(table.database, table.identifier.table)
|
||||
val catalog = sparkSession.sessionState.catalog
|
||||
catalog.getCachedPlan(qualifiedTableName, new Callable[LogicalPlan]() {
|
||||
override def call(): LogicalPlan = {
|
||||
val pathOption = table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_))
|
||||
val dataSource =
|
||||
DataSource(
|
||||
sparkSession,
|
||||
// In older version(prior to 2.1) of Spark, the table schema can be empty and should be
|
||||
// inferred at runtime. We should still support it.
|
||||
userSpecifiedSchema = if (table.schema.isEmpty) None else Some(table.schema),
|
||||
partitionColumns = table.partitionColumnNames,
|
||||
bucketSpec = table.bucketSpec,
|
||||
className = table.provider.get,
|
||||
options = table.storage.properties ++ pathOption,
|
||||
catalogTable = Some(table))
|
||||
|
||||
LogicalRelation(dataSource.resolveRelation(checkFilesExist = false), table)
|
||||
}
|
||||
catalog.getCachedPlan(qualifiedTableName, () => {
|
||||
val pathOption = table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_))
|
||||
val dataSource =
|
||||
DataSource(
|
||||
sparkSession,
|
||||
// In older version(prior to 2.1) of Spark, the table schema can be empty and should be
|
||||
// inferred at runtime. We should still support it.
|
||||
userSpecifiedSchema = if (table.schema.isEmpty) None else Some(table.schema),
|
||||
partitionColumns = table.partitionColumnNames,
|
||||
bucketSpec = table.bucketSpec,
|
||||
className = table.provider.get,
|
||||
options = table.storage.properties ++ pathOption,
|
||||
catalogTable = Some(table))
|
||||
LogicalRelation(dataSource.resolveRelation(checkFilesExist = false), table)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -484,8 +480,8 @@ object DataSourceStrategy {
|
|||
// Because we only convert In to InSet in Optimizer when there are more than certain
|
||||
// items. So it is possible we still get an In expression here that needs to be pushed
|
||||
// down.
|
||||
case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) =>
|
||||
val hSet = list.map(e => e.eval(EmptyRow))
|
||||
case expressions.In(a: Attribute, list) if list.forall(_.isInstanceOf[Literal]) =>
|
||||
val hSet = list.map(_.eval(EmptyRow))
|
||||
val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
|
||||
Some(sources.In(a.name, hSet.toArray.map(toScala)))
|
||||
|
||||
|
|
|
@ -69,7 +69,7 @@ trait CheckpointFileManager {
|
|||
|
||||
/** List all the files in a path. */
|
||||
def list(path: Path): Array[FileStatus] = {
|
||||
list(path, new PathFilter { override def accept(path: Path): Boolean = true })
|
||||
list(path, (_: Path) => true)
|
||||
}
|
||||
|
||||
/** Make directory at the give path and all its parent directories as needed. */
|
||||
|
@ -103,7 +103,7 @@ object CheckpointFileManager extends Logging {
|
|||
* @param overwriteIfPossible If true, then the implementations must do a best-effort attempt to
|
||||
* overwrite the file if it already exists. It should not throw
|
||||
* any exception if the file exists. However, if false, then the
|
||||
* implementation must not overwrite if the file alraedy exists and
|
||||
* implementation must not overwrite if the file already exists and
|
||||
* must throw `FileAlreadyExistsException` in that case.
|
||||
*/
|
||||
def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit
|
||||
|
@ -236,14 +236,12 @@ class FileSystemBasedCheckpointFileManager(path: Path, hadoopConf: Configuration
|
|||
fs.open(path)
|
||||
}
|
||||
|
||||
override def exists(path: Path): Boolean = {
|
||||
try
|
||||
return fs.getFileStatus(path) != null
|
||||
catch {
|
||||
case e: FileNotFoundException =>
|
||||
return false
|
||||
override def exists(path: Path): Boolean =
|
||||
try {
|
||||
fs.getFileStatus(path) != null
|
||||
} catch {
|
||||
case _: FileNotFoundException => false
|
||||
}
|
||||
}
|
||||
|
||||
override def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit = {
|
||||
if (!overwriteIfPossible && fs.exists(dstPath)) {
|
||||
|
|
|
@ -23,7 +23,7 @@ import java.nio.charset.StandardCharsets.UTF_8
|
|||
import scala.io.{Source => IOSource}
|
||||
import scala.reflect.ClassTag
|
||||
|
||||
import org.apache.hadoop.fs.{Path, PathFilter}
|
||||
import org.apache.hadoop.fs.Path
|
||||
import org.json4s.NoTypeHints
|
||||
import org.json4s.jackson.Serialization
|
||||
|
||||
|
@ -169,13 +169,13 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag](
|
|||
*/
|
||||
private def compact(batchId: Long, logs: Array[T]): Boolean = {
|
||||
val validBatches = getValidBatchesBeforeCompactionBatch(batchId, compactInterval)
|
||||
val allLogs = validBatches.map { id =>
|
||||
val allLogs = validBatches.flatMap { id =>
|
||||
super.get(id).getOrElse {
|
||||
throw new IllegalStateException(
|
||||
s"${batchIdToPath(id)} doesn't exist when compacting batch $batchId " +
|
||||
s"(compactInterval: $compactInterval)")
|
||||
}
|
||||
}.flatten ++ logs
|
||||
} ++ logs
|
||||
// Return false as there is another writer.
|
||||
super.add(batchId, compactLogs(allLogs).toArray)
|
||||
}
|
||||
|
@ -192,13 +192,13 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag](
|
|||
if (latestId >= 0) {
|
||||
try {
|
||||
val logs =
|
||||
getAllValidBatches(latestId, compactInterval).map { id =>
|
||||
getAllValidBatches(latestId, compactInterval).flatMap { id =>
|
||||
super.get(id).getOrElse {
|
||||
throw new IllegalStateException(
|
||||
s"${batchIdToPath(id)} doesn't exist " +
|
||||
s"(latestId: $latestId, compactInterval: $compactInterval)")
|
||||
}
|
||||
}.flatten
|
||||
}
|
||||
return compactLogs(logs).toArray
|
||||
} catch {
|
||||
case e: IOException =>
|
||||
|
@ -240,15 +240,13 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag](
|
|||
s"min compaction batch id to delete = $minCompactionBatchId")
|
||||
|
||||
val expiredTime = System.currentTimeMillis() - fileCleanupDelayMs
|
||||
fileManager.list(metadataPath, new PathFilter {
|
||||
override def accept(path: Path): Boolean = {
|
||||
try {
|
||||
val batchId = getBatchIdFromFileName(path.getName)
|
||||
batchId < minCompactionBatchId
|
||||
} catch {
|
||||
case _: NumberFormatException =>
|
||||
false
|
||||
}
|
||||
fileManager.list(metadataPath, (path: Path) => {
|
||||
try {
|
||||
val batchId = getBatchIdFromFileName(path.getName)
|
||||
batchId < minCompactionBatchId
|
||||
} catch {
|
||||
case _: NumberFormatException =>
|
||||
false
|
||||
}
|
||||
}).foreach { f =>
|
||||
if (f.getModificationTime <= expiredTime) {
|
||||
|
|
|
@ -89,19 +89,15 @@ class RateStreamTable(
|
|||
|
||||
override def capabilities(): util.Set[TableCapability] = Collections.emptySet()
|
||||
|
||||
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder {
|
||||
override def build(): Scan = new Scan {
|
||||
override def readSchema(): StructType = RateStreamProvider.SCHEMA
|
||||
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = () => new Scan {
|
||||
override def readSchema(): StructType = RateStreamProvider.SCHEMA
|
||||
|
||||
override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
|
||||
new RateStreamMicroBatchStream(
|
||||
rowsPerSecond, rampUpTimeSeconds, numPartitions, options, checkpointLocation)
|
||||
}
|
||||
override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream =
|
||||
new RateStreamMicroBatchStream(
|
||||
rowsPerSecond, rampUpTimeSeconds, numPartitions, options, checkpointLocation)
|
||||
|
||||
override def toContinuousStream(checkpointLocation: String): ContinuousStream = {
|
||||
new RateStreamContinuousStream(rowsPerSecond, numPartitions)
|
||||
}
|
||||
}
|
||||
override def toContinuousStream(checkpointLocation: String): ContinuousStream =
|
||||
new RateStreamContinuousStream(rowsPerSecond, numPartitions)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -130,27 +130,24 @@ class TextSocketMicroBatchStream(host: String, port: Int, numPartitions: Int)
|
|||
slices.map(TextSocketInputPartition)
|
||||
}
|
||||
|
||||
override def createReaderFactory(): PartitionReaderFactory = {
|
||||
new PartitionReaderFactory {
|
||||
override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
|
||||
val slice = partition.asInstanceOf[TextSocketInputPartition].slice
|
||||
new PartitionReader[InternalRow] {
|
||||
private var currentIdx = -1
|
||||
override def createReaderFactory(): PartitionReaderFactory =
|
||||
(partition: InputPartition) => {
|
||||
val slice = partition.asInstanceOf[TextSocketInputPartition].slice
|
||||
new PartitionReader[InternalRow] {
|
||||
private var currentIdx = -1
|
||||
|
||||
override def next(): Boolean = {
|
||||
currentIdx += 1
|
||||
currentIdx < slice.size
|
||||
}
|
||||
|
||||
override def get(): InternalRow = {
|
||||
InternalRow(slice(currentIdx)._1, slice(currentIdx)._2)
|
||||
}
|
||||
|
||||
override def close(): Unit = {}
|
||||
override def next(): Boolean = {
|
||||
currentIdx += 1
|
||||
currentIdx < slice.size
|
||||
}
|
||||
|
||||
override def get(): InternalRow = {
|
||||
InternalRow(slice(currentIdx)._1, slice(currentIdx)._2)
|
||||
}
|
||||
|
||||
override def close(): Unit = {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def commit(end: Offset): Unit = synchronized {
|
||||
val newOffset = LongOffset.convert(end).getOrElse(
|
||||
|
|
|
@ -81,17 +81,15 @@ class TextSocketTable(host: String, port: Int, numPartitions: Int, includeTimest
|
|||
|
||||
override def capabilities(): util.Set[TableCapability] = Collections.emptySet()
|
||||
|
||||
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder {
|
||||
override def build(): Scan = new Scan {
|
||||
override def readSchema(): StructType = schema()
|
||||
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = () => new Scan {
|
||||
override def readSchema(): StructType = schema()
|
||||
|
||||
override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
|
||||
new TextSocketMicroBatchStream(host, port, numPartitions)
|
||||
}
|
||||
override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
|
||||
new TextSocketMicroBatchStream(host, port, numPartitions)
|
||||
}
|
||||
|
||||
override def toContinuousStream(checkpointLocation: String): ContinuousStream = {
|
||||
new TextSocketContinuousStream(host, port, numPartitions, options)
|
||||
}
|
||||
override def toContinuousStream(checkpointLocation: String): ContinuousStream = {
|
||||
new TextSocketContinuousStream(host, port, numPartitions, options)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.ui
|
|||
|
||||
import java.util.{Date, NoSuchElementException}
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.function.Function
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
|
@ -196,7 +195,7 @@ class SQLAppStatusListener(
|
|||
|
||||
// Check the execution again for whether the aggregated metrics data has been calculated.
|
||||
// This can happen if the UI is requesting this data, and the onExecutionEnd handler is
|
||||
// running at the same time. The metrics calculated for the UI can be innacurate in that
|
||||
// running at the same time. The metrics calculated for the UI can be inaccurate in that
|
||||
// case, since the onExecutionEnd handler will clean up tracked stage metrics.
|
||||
if (exec.metricsValues != null) {
|
||||
exec.metricsValues
|
||||
|
@ -328,9 +327,7 @@ class SQLAppStatusListener(
|
|||
|
||||
private def getOrCreateExecution(executionId: Long): LiveExecutionData = {
|
||||
liveExecutions.computeIfAbsent(executionId,
|
||||
new Function[Long, LiveExecutionData]() {
|
||||
override def apply(key: Long): LiveExecutionData = new LiveExecutionData(executionId)
|
||||
})
|
||||
(_: Long) => new LiveExecutionData(executionId))
|
||||
}
|
||||
|
||||
private def update(exec: LiveExecutionData, force: Boolean = false): Unit = {
|
||||
|
|
|
@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.CacheManager
|
|||
import org.apache.spark.sql.execution.ui.{SQLAppStatusListener, SQLAppStatusStore, SQLTab}
|
||||
import org.apache.spark.sql.internal.StaticSQLConf._
|
||||
import org.apache.spark.status.ElementTrackingStore
|
||||
import org.apache.spark.util.{MutableURLClassLoader, Utils}
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
|
||||
/**
|
||||
|
@ -146,11 +146,7 @@ private[sql] class SharedState(
|
|||
val wrapped = new ExternalCatalogWithListener(externalCatalog)
|
||||
|
||||
// Make sure we propagate external catalog events to the spark listener bus
|
||||
wrapped.addListener(new ExternalCatalogEventListener {
|
||||
override def onEvent(event: ExternalCatalogEvent): Unit = {
|
||||
sparkContext.listenerBus.post(event)
|
||||
}
|
||||
})
|
||||
wrapped.addListener((event: ExternalCatalogEvent) => sparkContext.listenerBus.post(event))
|
||||
|
||||
wrapped
|
||||
}
|
||||
|
|
|
@ -1195,11 +1195,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
|
|||
GroupedRoutes("a", "c", Seq(Route("a", "c", 2)))
|
||||
)
|
||||
|
||||
implicit def ordering[GroupedRoutes]: Ordering[GroupedRoutes] = new Ordering[GroupedRoutes] {
|
||||
override def compare(x: GroupedRoutes, y: GroupedRoutes): Int = {
|
||||
x.toString.compareTo(y.toString)
|
||||
}
|
||||
}
|
||||
implicit def ordering[GroupedRoutes]: Ordering[GroupedRoutes] =
|
||||
(x: GroupedRoutes, y: GroupedRoutes) => x.toString.compareTo(y.toString)
|
||||
|
||||
checkDatasetUnorderly(grped, expected: _*)
|
||||
}
|
||||
|
|
|
@ -120,10 +120,7 @@ class QueryExecutionSuite extends SharedSQLContext {
|
|||
}
|
||||
|
||||
test("toString() exception/error handling") {
|
||||
spark.experimental.extraStrategies = Seq(
|
||||
new SparkStrategy {
|
||||
override def apply(plan: LogicalPlan): Seq[SparkPlan] = Nil
|
||||
})
|
||||
spark.experimental.extraStrategies = Seq[SparkStrategy]((_: LogicalPlan) => Nil)
|
||||
|
||||
def qe: QueryExecution = new QueryExecution(spark, OneRowRelation())
|
||||
|
||||
|
@ -131,19 +128,13 @@ class QueryExecutionSuite extends SharedSQLContext {
|
|||
assert(qe.toString.contains("OneRowRelation"))
|
||||
|
||||
// Throw an AnalysisException - this should be captured.
|
||||
spark.experimental.extraStrategies = Seq(
|
||||
new SparkStrategy {
|
||||
override def apply(plan: LogicalPlan): Seq[SparkPlan] =
|
||||
throw new AnalysisException("exception")
|
||||
})
|
||||
spark.experimental.extraStrategies = Seq[SparkStrategy](
|
||||
(_: LogicalPlan) => throw new AnalysisException("exception"))
|
||||
assert(qe.toString.contains("org.apache.spark.sql.AnalysisException"))
|
||||
|
||||
// Throw an Error - this should not be captured.
|
||||
spark.experimental.extraStrategies = Seq(
|
||||
new SparkStrategy {
|
||||
override def apply(plan: LogicalPlan): Seq[SparkPlan] =
|
||||
throw new Error("error")
|
||||
})
|
||||
spark.experimental.extraStrategies = Seq[SparkStrategy](
|
||||
(_: LogicalPlan) => throw new Error("error"))
|
||||
val error = intercept[Error](qe.toString)
|
||||
assert(error.getMessage.contains("error"))
|
||||
}
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
package org.apache.spark.sql.execution.benchmark
|
||||
|
||||
import java.util.{Arrays, Comparator}
|
||||
import java.util.Arrays
|
||||
|
||||
import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
|
||||
import org.apache.spark.unsafe.array.LongArray
|
||||
|
@ -40,14 +40,9 @@ object SortBenchmark extends BenchmarkBase {
|
|||
|
||||
private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) {
|
||||
val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt)))
|
||||
new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort(
|
||||
buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] {
|
||||
override def compare(
|
||||
r1: RecordPointerAndKeyPrefix,
|
||||
r2: RecordPointerAndKeyPrefix): Int = {
|
||||
refCmp.compare(r1.keyPrefix, r2.keyPrefix)
|
||||
}
|
||||
})
|
||||
new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort(buf, lo, hi,
|
||||
(r1: RecordPointerAndKeyPrefix, r2: RecordPointerAndKeyPrefix) =>
|
||||
refCmp.compare(r1.keyPrefix, r2.keyPrefix))
|
||||
}
|
||||
|
||||
private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = {
|
||||
|
|
|
@ -354,11 +354,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
|
|||
val listener = new StreamingQueryListener {
|
||||
override def onQueryStarted(event: QueryStartedEvent): Unit = {
|
||||
// Note: this assumes there is only one query active in the `testStream` method.
|
||||
Thread.currentThread.setUncaughtExceptionHandler(new UncaughtExceptionHandler {
|
||||
override def uncaughtException(t: Thread, e: Throwable): Unit = {
|
||||
streamThreadDeathCause = e
|
||||
}
|
||||
})
|
||||
Thread.currentThread.setUncaughtExceptionHandler(
|
||||
(_: Thread, e: Throwable) => streamThreadDeathCause = e)
|
||||
}
|
||||
|
||||
override def onQueryProgress(event: QueryProgressEvent): Unit = {}
|
||||
|
|
|
@ -33,7 +33,7 @@ import org.apache.spark.sql._
|
|||
import org.apache.spark.sql.execution.streaming._
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
|
||||
import org.apache.spark.sql.streaming.{ProcessingTime => DeprecatedProcessingTime, _}
|
||||
import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery, StreamingQueryException, StreamTest}
|
||||
import org.apache.spark.sql.streaming.Trigger._
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.util.Utils
|
||||
|
@ -104,9 +104,7 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider {
|
|||
LastOptions.parameters = parameters
|
||||
LastOptions.partitionColumns = partitionColumns
|
||||
LastOptions.mockStreamSinkProvider.createSink(spark, parameters, partitionColumns, outputMode)
|
||||
new Sink {
|
||||
override def addBatch(batchId: Long, data: DataFrame): Unit = {}
|
||||
}
|
||||
(_: Long, _: DataFrame) => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -60,11 +60,7 @@ class BlockingSource extends StreamSourceProvider with StreamSinkProvider {
|
|||
spark: SQLContext,
|
||||
parameters: Map[String, String],
|
||||
partitionColumns: Seq[String],
|
||||
outputMode: OutputMode): Sink = {
|
||||
new Sink {
|
||||
override def addBatch(batchId: Long, data: DataFrame): Unit = {}
|
||||
}
|
||||
}
|
||||
outputMode: OutputMode): Sink = (_: Long, _: DataFrame) => {}
|
||||
}
|
||||
|
||||
object BlockingSource {
|
||||
|
|
|
@ -29,7 +29,7 @@ import org.apache.commons.lang3.StringUtils
|
|||
import org.apache.commons.logging.LogFactory
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor}
|
||||
import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils}
|
||||
import org.apache.hadoop.hive.common.HiveInterruptUtils
|
||||
import org.apache.hadoop.hive.conf.HiveConf
|
||||
import org.apache.hadoop.hive.ql.Driver
|
||||
import org.apache.hadoop.hive.ql.exec.Utilities
|
||||
|
@ -65,16 +65,14 @@ private[hive] object SparkSQLCLIDriver extends Logging {
|
|||
* a command is being processed by the current thread.
|
||||
*/
|
||||
def installSignalHandler() {
|
||||
HiveInterruptUtils.add(new HiveInterruptCallback {
|
||||
override def interrupt() {
|
||||
// Handle remote execution mode
|
||||
if (SparkSQLEnv.sparkContext != null) {
|
||||
SparkSQLEnv.sparkContext.cancelAllJobs()
|
||||
} else {
|
||||
if (transport != null) {
|
||||
// Force closing of TCP connection upon session termination
|
||||
transport.getSocket.close()
|
||||
}
|
||||
HiveInterruptUtils.add(() => {
|
||||
// Handle remote execution mode
|
||||
if (SparkSQLEnv.sparkContext != null) {
|
||||
SparkSQLEnv.sparkContext.cancelAllJobs()
|
||||
} else {
|
||||
if (transport != null) {
|
||||
// Force closing of TCP connection upon session termination
|
||||
transport.getSocket.close()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
@ -208,7 +206,7 @@ private[hive] object SparkSQLCLIDriver extends Logging {
|
|||
reader.setBellEnabled(false)
|
||||
reader.setExpandEvents(false)
|
||||
// reader.setDebug(new PrintWriter(new FileWriter("writer.debug", true)))
|
||||
CliDriver.getCommandCompleter.foreach((e) => reader.addCompleter(e))
|
||||
CliDriver.getCommandCompleter.foreach(reader.addCompleter)
|
||||
|
||||
val historyDirectory = System.getProperty("user.home")
|
||||
|
||||
|
|
|
@ -255,9 +255,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
|
|||
}
|
||||
}
|
||||
|
||||
def numReceivers(): Int = {
|
||||
receiverInputStreams.size
|
||||
}
|
||||
def numReceivers(): Int = receiverInputStreams.length
|
||||
|
||||
/** Register a receiver */
|
||||
private def registerReceiver(
|
||||
|
@ -516,14 +514,12 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
|
|||
context.reply(successful)
|
||||
case AddBlock(receivedBlockInfo) =>
|
||||
if (WriteAheadLogUtils.isBatchingEnabled(ssc.conf, isDriver = true)) {
|
||||
walBatchingThreadPool.execute(new Runnable {
|
||||
override def run(): Unit = Utils.tryLogNonFatalError {
|
||||
if (active) {
|
||||
context.reply(addBlock(receivedBlockInfo))
|
||||
} else {
|
||||
context.sendFailure(
|
||||
new IllegalStateException("ReceiverTracker RpcEndpoint already shut down."))
|
||||
}
|
||||
walBatchingThreadPool.execute(() => Utils.tryLogNonFatalError {
|
||||
if (active) {
|
||||
context.reply(addBlock(receivedBlockInfo))
|
||||
} else {
|
||||
context.sendFailure(
|
||||
new IllegalStateException("ReceiverTracker RpcEndpoint already shut down."))
|
||||
}
|
||||
})
|
||||
} else {
|
||||
|
|
|
@ -135,18 +135,16 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp
|
|||
|
||||
/** Start the actual log writer on a separate thread. */
|
||||
private def startBatchedWriterThread(): Thread = {
|
||||
val thread = new Thread(new Runnable {
|
||||
override def run(): Unit = {
|
||||
while (active.get()) {
|
||||
try {
|
||||
flushRecords()
|
||||
} catch {
|
||||
case NonFatal(e) =>
|
||||
logWarning("Encountered exception in Batched Writer Thread.", e)
|
||||
}
|
||||
val thread = new Thread(() => {
|
||||
while (active.get()) {
|
||||
try {
|
||||
flushRecords()
|
||||
} catch {
|
||||
case NonFatal(e) =>
|
||||
logWarning("Encountered exception in Batched Writer Thread.", e)
|
||||
}
|
||||
logInfo("BatchedWriteAheadLog Writer thread exiting.")
|
||||
}
|
||||
logInfo("BatchedWriteAheadLog Writer thread exiting.")
|
||||
}, "BatchedWriteAheadLog Writer")
|
||||
thread.setDaemon(true)
|
||||
thread.start()
|
||||
|
|
Loading…
Reference in a new issue