[SPARK-24005][CORE] Remove usage of Scala’s parallel collection
## What changes were proposed in this pull request? In the PR, I propose to replace Scala parallel collections by new methods `parmap()`. The methods use futures to transform a sequential collection by applying a lambda function to each element in parallel. The result of `parmap` is another regular (sequential) collection. The proposed `parmap` method aims to solve the problem of impossibility to interrupt parallel Scala collection. This possibility is needed for reliable task preemption. ## How was this patch tested? A test was added to `ThreadUtilsSuite` Closes #21913 from MaxGekk/par-map. Authored-by: Maxim Gekk <maxim.gekk@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
88e0c7bbd5
commit
131ca146ed
|
@ -20,12 +20,13 @@ package org.apache.spark.rdd
|
|||
import java.io.{IOException, ObjectOutputStream}
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.collection.parallel.ForkJoinTaskSupport
|
||||
import scala.concurrent.ExecutionContext
|
||||
import scala.concurrent.forkjoin.ForkJoinPool
|
||||
import scala.reflect.ClassTag
|
||||
|
||||
import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
|
||||
import org.apache.spark.annotation.DeveloperApi
|
||||
import org.apache.spark.util.ThreadUtils.parmap
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
/**
|
||||
|
@ -59,8 +60,7 @@ private[spark] class UnionPartition[T: ClassTag](
|
|||
}
|
||||
|
||||
object UnionRDD {
|
||||
private[spark] lazy val partitionEvalTaskSupport =
|
||||
new ForkJoinTaskSupport(new ForkJoinPool(8))
|
||||
private[spark] lazy val threadPool = new ForkJoinPool(8)
|
||||
}
|
||||
|
||||
@DeveloperApi
|
||||
|
@ -74,14 +74,13 @@ class UnionRDD[T: ClassTag](
|
|||
rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10)
|
||||
|
||||
override def getPartitions: Array[Partition] = {
|
||||
val parRDDs = if (isPartitionListingParallel) {
|
||||
val parArray = rdds.par
|
||||
parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
|
||||
parArray
|
||||
val partitionLengths = if (isPartitionListingParallel) {
|
||||
implicit val ec = ExecutionContext.fromExecutor(UnionRDD.threadPool)
|
||||
parmap(rdds)(_.partitions.length)
|
||||
} else {
|
||||
rdds
|
||||
rdds.map(_.partitions.length)
|
||||
}
|
||||
val array = new Array[Partition](parRDDs.map(_.partitions.length).seq.sum)
|
||||
val array = new Array[Partition](partitionLengths.sum)
|
||||
var pos = 0
|
||||
for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
|
||||
array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
|
||||
|
|
|
@ -19,8 +19,12 @@ package org.apache.spark.util
|
|||
|
||||
import java.util.concurrent._
|
||||
|
||||
import scala.collection.TraversableLike
|
||||
import scala.collection.generic.CanBuildFrom
|
||||
import scala.language.higherKinds
|
||||
|
||||
import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder}
|
||||
import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor}
|
||||
import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor, Future}
|
||||
import scala.concurrent.duration.{Duration, FiniteDuration}
|
||||
import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread}
|
||||
import scala.util.control.NonFatal
|
||||
|
@ -254,4 +258,62 @@ private[spark] object ThreadUtils {
|
|||
executor.shutdownNow()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Transforms input collection by applying the given function to each element in parallel fashion.
|
||||
* Comparing to the map() method of Scala parallel collections, this method can be interrupted
|
||||
* at any time. This is useful on canceling of task execution, for example.
|
||||
*
|
||||
* @param in - the input collection which should be transformed in parallel.
|
||||
* @param prefix - the prefix assigned to the underlying thread pool.
|
||||
* @param maxThreads - maximum number of thread can be created during execution.
|
||||
* @param f - the lambda function will be applied to each element of `in`.
|
||||
* @tparam I - the type of elements in the input collection.
|
||||
* @tparam O - the type of elements in resulted collection.
|
||||
* @return new collection in which each element was given from the input collection `in` by
|
||||
* applying the lambda function `f`.
|
||||
*/
|
||||
def parmap[I, O, Col[X] <: TraversableLike[X, Col[X]]]
|
||||
(in: Col[I], prefix: String, maxThreads: Int)
|
||||
(f: I => O)
|
||||
(implicit
|
||||
cbf: CanBuildFrom[Col[I], Future[O], Col[Future[O]]], // For in.map
|
||||
cbf2: CanBuildFrom[Col[Future[O]], O, Col[O]] // for Future.sequence
|
||||
): Col[O] = {
|
||||
val pool = newForkJoinPool(prefix, maxThreads)
|
||||
try {
|
||||
implicit val ec = ExecutionContext.fromExecutor(pool)
|
||||
|
||||
parmap(in)(f)
|
||||
} finally {
|
||||
pool.shutdownNow()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Transforms input collection by applying the given function to each element in parallel fashion.
|
||||
* Comparing to the map() method of Scala parallel collections, this method can be interrupted
|
||||
* at any time. This is useful on canceling of task execution, for example.
|
||||
*
|
||||
* @param in - the input collection which should be transformed in parallel.
|
||||
* @param f - the lambda function will be applied to each element of `in`.
|
||||
* @param ec - an execution context for parallel applying of the given function `f`.
|
||||
* @tparam I - the type of elements in the input collection.
|
||||
* @tparam O - the type of elements in resulted collection.
|
||||
* @return new collection in which each element was given from the input collection `in` by
|
||||
* applying the lambda function `f`.
|
||||
*/
|
||||
def parmap[I, O, Col[X] <: TraversableLike[X, Col[X]]]
|
||||
(in: Col[I])
|
||||
(f: I => O)
|
||||
(implicit
|
||||
cbf: CanBuildFrom[Col[I], Future[O], Col[Future[O]]], // For in.map
|
||||
cbf2: CanBuildFrom[Col[Future[O]], O, Col[O]], // for Future.sequence
|
||||
ec: ExecutionContext
|
||||
): Col[O] = {
|
||||
val futures = in.map(x => Future(f(x)))
|
||||
val futureSeq = Future.sequence(futures)
|
||||
|
||||
awaitResult(futureSeq, Duration.Inf)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -133,4 +133,37 @@ class ThreadUtilsSuite extends SparkFunSuite {
|
|||
"stack trace contains unexpected references to ThreadUtils"
|
||||
)
|
||||
}
|
||||
|
||||
test("parmap should be interruptible") {
|
||||
val t = new Thread() {
|
||||
setDaemon(true)
|
||||
|
||||
override def run() {
|
||||
try {
|
||||
// "par" is uninterruptible. The following will keep running even if the thread is
|
||||
// interrupted. We should prefer to use "ThreadUtils.parmap".
|
||||
//
|
||||
// (1 to 10).par.flatMap { i =>
|
||||
// Thread.sleep(100000)
|
||||
// 1 to i
|
||||
// }
|
||||
//
|
||||
ThreadUtils.parmap(1 to 10, "test", 2) { i =>
|
||||
Thread.sleep(100000)
|
||||
1 to i
|
||||
}.flatten
|
||||
} catch {
|
||||
case _: InterruptedException => // excepted
|
||||
}
|
||||
}
|
||||
}
|
||||
t.start()
|
||||
eventually(timeout(10.seconds)) {
|
||||
assert(t.isAlive)
|
||||
}
|
||||
t.interrupt()
|
||||
eventually(timeout(10.seconds)) {
|
||||
assert(!t.isAlive)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command
|
|||
import java.util.Locale
|
||||
|
||||
import scala.collection.{GenMap, GenSeq}
|
||||
import scala.collection.parallel.ForkJoinTaskSupport
|
||||
import scala.concurrent.ExecutionContext
|
||||
import scala.util.control.NonFatal
|
||||
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
|
@ -29,7 +29,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
|
|||
|
||||
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
|
||||
import org.apache.spark.sql.catalyst.TableIdentifier
|
||||
import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Resolver}
|
||||
import org.apache.spark.sql.catalyst.analysis.Resolver
|
||||
import org.apache.spark.sql.catalyst.catalog._
|
||||
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
|
||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
|
||||
|
@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter
|
|||
import org.apache.spark.sql.internal.HiveSerDe
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.util.{SerializableConfiguration, ThreadUtils}
|
||||
import org.apache.spark.util.ThreadUtils.parmap
|
||||
|
||||
// Note: The definition of these commands are based on the ones described in
|
||||
// https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
|
||||
|
@ -621,8 +622,9 @@ case class AlterTableRecoverPartitionsCommand(
|
|||
val evalPool = ThreadUtils.newForkJoinPool("AlterTableRecoverPartitionsCommand", 8)
|
||||
val partitionSpecsAndLocs: Seq[(TablePartitionSpec, Path)] =
|
||||
try {
|
||||
implicit val ec = ExecutionContext.fromExecutor(evalPool)
|
||||
scanPartitions(spark, fs, pathFilter, root, Map(), table.partitionColumnNames, threshold,
|
||||
spark.sessionState.conf.resolver, new ForkJoinTaskSupport(evalPool)).seq
|
||||
spark.sessionState.conf.resolver)
|
||||
} finally {
|
||||
evalPool.shutdown()
|
||||
}
|
||||
|
@ -654,23 +656,13 @@ case class AlterTableRecoverPartitionsCommand(
|
|||
spec: TablePartitionSpec,
|
||||
partitionNames: Seq[String],
|
||||
threshold: Int,
|
||||
resolver: Resolver,
|
||||
evalTaskSupport: ForkJoinTaskSupport): GenSeq[(TablePartitionSpec, Path)] = {
|
||||
resolver: Resolver)(implicit ec: ExecutionContext): Seq[(TablePartitionSpec, Path)] = {
|
||||
if (partitionNames.isEmpty) {
|
||||
return Seq(spec -> path)
|
||||
}
|
||||
|
||||
val statuses = fs.listStatus(path, filter)
|
||||
val statusPar: GenSeq[FileStatus] =
|
||||
if (partitionNames.length > 1 && statuses.length > threshold || partitionNames.length > 2) {
|
||||
// parallelize the list of partitions here, then we can have better parallelism later.
|
||||
val parArray = statuses.par
|
||||
parArray.tasksupport = evalTaskSupport
|
||||
parArray
|
||||
} else {
|
||||
statuses
|
||||
}
|
||||
statusPar.flatMap { st =>
|
||||
val statuses = fs.listStatus(path, filter).toSeq
|
||||
def handleStatus(st: FileStatus): Seq[(TablePartitionSpec, Path)] = {
|
||||
val name = st.getPath.getName
|
||||
if (st.isDirectory && name.contains("=")) {
|
||||
val ps = name.split("=", 2)
|
||||
|
@ -679,7 +671,7 @@ case class AlterTableRecoverPartitionsCommand(
|
|||
val value = ExternalCatalogUtils.unescapePathName(ps(1))
|
||||
if (resolver(columnName, partitionNames.head)) {
|
||||
scanPartitions(spark, fs, filter, st.getPath, spec ++ Map(partitionNames.head -> value),
|
||||
partitionNames.drop(1), threshold, resolver, evalTaskSupport)
|
||||
partitionNames.drop(1), threshold, resolver)
|
||||
} else {
|
||||
logWarning(
|
||||
s"expected partition column ${partitionNames.head}, but got ${ps(0)}, ignoring it")
|
||||
|
@ -690,6 +682,14 @@ case class AlterTableRecoverPartitionsCommand(
|
|||
Seq.empty
|
||||
}
|
||||
}
|
||||
val result = if (partitionNames.length > 1 &&
|
||||
statuses.length > threshold || partitionNames.length > 2) {
|
||||
parmap(statuses)(handleStatus _)
|
||||
} else {
|
||||
statuses.map(handleStatus)
|
||||
}
|
||||
|
||||
result.flatten
|
||||
}
|
||||
|
||||
private def gatherPartitionStats(
|
||||
|
|
|
@ -22,7 +22,6 @@ import java.net.URI
|
|||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
import scala.collection.parallel.ForkJoinTaskSupport
|
||||
import scala.util.{Failure, Try}
|
||||
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
|
@ -532,30 +531,23 @@ object ParquetFileFormat extends Logging {
|
|||
conf: Configuration,
|
||||
partFiles: Seq[FileStatus],
|
||||
ignoreCorruptFiles: Boolean): Seq[Footer] = {
|
||||
val parFiles = partFiles.par
|
||||
val pool = ThreadUtils.newForkJoinPool("readingParquetFooters", 8)
|
||||
parFiles.tasksupport = new ForkJoinTaskSupport(pool)
|
||||
try {
|
||||
parFiles.flatMap { currentFile =>
|
||||
try {
|
||||
// Skips row group information since we only need the schema.
|
||||
// ParquetFileReader.readFooter throws RuntimeException, instead of IOException,
|
||||
// when it can't read the footer.
|
||||
Some(new Footer(currentFile.getPath(),
|
||||
ParquetFileReader.readFooter(
|
||||
conf, currentFile, SKIP_ROW_GROUPS)))
|
||||
} catch { case e: RuntimeException =>
|
||||
if (ignoreCorruptFiles) {
|
||||
logWarning(s"Skipped the footer in the corrupted file: $currentFile", e)
|
||||
None
|
||||
} else {
|
||||
throw new IOException(s"Could not read footer for file: $currentFile", e)
|
||||
}
|
||||
ThreadUtils.parmap(partFiles, "readingParquetFooters", 8) { currentFile =>
|
||||
try {
|
||||
// Skips row group information since we only need the schema.
|
||||
// ParquetFileReader.readFooter throws RuntimeException, instead of IOException,
|
||||
// when it can't read the footer.
|
||||
Some(new Footer(currentFile.getPath(),
|
||||
ParquetFileReader.readFooter(
|
||||
conf, currentFile, SKIP_ROW_GROUPS)))
|
||||
} catch { case e: RuntimeException =>
|
||||
if (ignoreCorruptFiles) {
|
||||
logWarning(s"Skipped the footer in the corrupted file: $currentFile", e)
|
||||
None
|
||||
} else {
|
||||
throw new IOException(s"Could not read footer for file: $currentFile", e)
|
||||
}
|
||||
}.seq
|
||||
} finally {
|
||||
pool.shutdown()
|
||||
}
|
||||
}
|
||||
}.flatten
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -51,9 +51,9 @@ class ParquetFileFormatSuite extends QueryTest with ParquetTest with SharedSQLCo
|
|||
}
|
||||
|
||||
testReadFooters(true)
|
||||
val exception = intercept[java.io.IOException] {
|
||||
val exception = intercept[SparkException] {
|
||||
testReadFooters(false)
|
||||
}
|
||||
}.getCause
|
||||
assert(exception.getMessage().contains("Could not read footer for file"))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -312,10 +312,10 @@ private[streaming] object FileBasedWriteAheadLog {
|
|||
handler: I => Iterator[O]): Iterator[O] = {
|
||||
val taskSupport = new ExecutionContextTaskSupport(executionContext)
|
||||
val groupSize = taskSupport.parallelismLevel.max(8)
|
||||
implicit val ec = executionContext
|
||||
|
||||
source.grouped(groupSize).flatMap { group =>
|
||||
val parallelCollection = group.par
|
||||
parallelCollection.tasksupport = taskSupport
|
||||
parallelCollection.map(handler)
|
||||
ThreadUtils.parmap(group)(handler)
|
||||
}.flatten
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue