[SPARK-24005][CORE] Remove usage of Scala’s parallel collection
## What changes were proposed in this pull request? In the PR, I propose to replace Scala parallel collections by new methods `parmap()`. The methods use futures to transform a sequential collection by applying a lambda function to each element in parallel. The result of `parmap` is another regular (sequential) collection. The proposed `parmap` method aims to solve the problem of impossibility to interrupt parallel Scala collection. This possibility is needed for reliable task preemption. ## How was this patch tested? A test was added to `ThreadUtilsSuite` Closes #21913 from MaxGekk/par-map. Authored-by: Maxim Gekk <maxim.gekk@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
88e0c7bbd5
commit
131ca146ed
|
@ -20,12 +20,13 @@ package org.apache.spark.rdd
|
||||||
import java.io.{IOException, ObjectOutputStream}
|
import java.io.{IOException, ObjectOutputStream}
|
||||||
|
|
||||||
import scala.collection.mutable.ArrayBuffer
|
import scala.collection.mutable.ArrayBuffer
|
||||||
import scala.collection.parallel.ForkJoinTaskSupport
|
import scala.concurrent.ExecutionContext
|
||||||
import scala.concurrent.forkjoin.ForkJoinPool
|
import scala.concurrent.forkjoin.ForkJoinPool
|
||||||
import scala.reflect.ClassTag
|
import scala.reflect.ClassTag
|
||||||
|
|
||||||
import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
|
import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
|
||||||
import org.apache.spark.annotation.DeveloperApi
|
import org.apache.spark.annotation.DeveloperApi
|
||||||
|
import org.apache.spark.util.ThreadUtils.parmap
|
||||||
import org.apache.spark.util.Utils
|
import org.apache.spark.util.Utils
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -59,8 +60,7 @@ private[spark] class UnionPartition[T: ClassTag](
|
||||||
}
|
}
|
||||||
|
|
||||||
object UnionRDD {
|
object UnionRDD {
|
||||||
private[spark] lazy val partitionEvalTaskSupport =
|
private[spark] lazy val threadPool = new ForkJoinPool(8)
|
||||||
new ForkJoinTaskSupport(new ForkJoinPool(8))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@DeveloperApi
|
@DeveloperApi
|
||||||
|
@ -74,14 +74,13 @@ class UnionRDD[T: ClassTag](
|
||||||
rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10)
|
rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10)
|
||||||
|
|
||||||
override def getPartitions: Array[Partition] = {
|
override def getPartitions: Array[Partition] = {
|
||||||
val parRDDs = if (isPartitionListingParallel) {
|
val partitionLengths = if (isPartitionListingParallel) {
|
||||||
val parArray = rdds.par
|
implicit val ec = ExecutionContext.fromExecutor(UnionRDD.threadPool)
|
||||||
parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
|
parmap(rdds)(_.partitions.length)
|
||||||
parArray
|
|
||||||
} else {
|
} else {
|
||||||
rdds
|
rdds.map(_.partitions.length)
|
||||||
}
|
}
|
||||||
val array = new Array[Partition](parRDDs.map(_.partitions.length).seq.sum)
|
val array = new Array[Partition](partitionLengths.sum)
|
||||||
var pos = 0
|
var pos = 0
|
||||||
for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
|
for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
|
||||||
array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
|
array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
|
||||||
|
|
|
@ -19,8 +19,12 @@ package org.apache.spark.util
|
||||||
|
|
||||||
import java.util.concurrent._
|
import java.util.concurrent._
|
||||||
|
|
||||||
|
import scala.collection.TraversableLike
|
||||||
|
import scala.collection.generic.CanBuildFrom
|
||||||
|
import scala.language.higherKinds
|
||||||
|
|
||||||
import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder}
|
import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder}
|
||||||
import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor}
|
import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor, Future}
|
||||||
import scala.concurrent.duration.{Duration, FiniteDuration}
|
import scala.concurrent.duration.{Duration, FiniteDuration}
|
||||||
import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread}
|
import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread}
|
||||||
import scala.util.control.NonFatal
|
import scala.util.control.NonFatal
|
||||||
|
@ -254,4 +258,62 @@ private[spark] object ThreadUtils {
|
||||||
executor.shutdownNow()
|
executor.shutdownNow()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Transforms input collection by applying the given function to each element in parallel fashion.
|
||||||
|
* Comparing to the map() method of Scala parallel collections, this method can be interrupted
|
||||||
|
* at any time. This is useful on canceling of task execution, for example.
|
||||||
|
*
|
||||||
|
* @param in - the input collection which should be transformed in parallel.
|
||||||
|
* @param prefix - the prefix assigned to the underlying thread pool.
|
||||||
|
* @param maxThreads - maximum number of thread can be created during execution.
|
||||||
|
* @param f - the lambda function will be applied to each element of `in`.
|
||||||
|
* @tparam I - the type of elements in the input collection.
|
||||||
|
* @tparam O - the type of elements in resulted collection.
|
||||||
|
* @return new collection in which each element was given from the input collection `in` by
|
||||||
|
* applying the lambda function `f`.
|
||||||
|
*/
|
||||||
|
def parmap[I, O, Col[X] <: TraversableLike[X, Col[X]]]
|
||||||
|
(in: Col[I], prefix: String, maxThreads: Int)
|
||||||
|
(f: I => O)
|
||||||
|
(implicit
|
||||||
|
cbf: CanBuildFrom[Col[I], Future[O], Col[Future[O]]], // For in.map
|
||||||
|
cbf2: CanBuildFrom[Col[Future[O]], O, Col[O]] // for Future.sequence
|
||||||
|
): Col[O] = {
|
||||||
|
val pool = newForkJoinPool(prefix, maxThreads)
|
||||||
|
try {
|
||||||
|
implicit val ec = ExecutionContext.fromExecutor(pool)
|
||||||
|
|
||||||
|
parmap(in)(f)
|
||||||
|
} finally {
|
||||||
|
pool.shutdownNow()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Transforms input collection by applying the given function to each element in parallel fashion.
|
||||||
|
* Comparing to the map() method of Scala parallel collections, this method can be interrupted
|
||||||
|
* at any time. This is useful on canceling of task execution, for example.
|
||||||
|
*
|
||||||
|
* @param in - the input collection which should be transformed in parallel.
|
||||||
|
* @param f - the lambda function will be applied to each element of `in`.
|
||||||
|
* @param ec - an execution context for parallel applying of the given function `f`.
|
||||||
|
* @tparam I - the type of elements in the input collection.
|
||||||
|
* @tparam O - the type of elements in resulted collection.
|
||||||
|
* @return new collection in which each element was given from the input collection `in` by
|
||||||
|
* applying the lambda function `f`.
|
||||||
|
*/
|
||||||
|
def parmap[I, O, Col[X] <: TraversableLike[X, Col[X]]]
|
||||||
|
(in: Col[I])
|
||||||
|
(f: I => O)
|
||||||
|
(implicit
|
||||||
|
cbf: CanBuildFrom[Col[I], Future[O], Col[Future[O]]], // For in.map
|
||||||
|
cbf2: CanBuildFrom[Col[Future[O]], O, Col[O]], // for Future.sequence
|
||||||
|
ec: ExecutionContext
|
||||||
|
): Col[O] = {
|
||||||
|
val futures = in.map(x => Future(f(x)))
|
||||||
|
val futureSeq = Future.sequence(futures)
|
||||||
|
|
||||||
|
awaitResult(futureSeq, Duration.Inf)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -133,4 +133,37 @@ class ThreadUtilsSuite extends SparkFunSuite {
|
||||||
"stack trace contains unexpected references to ThreadUtils"
|
"stack trace contains unexpected references to ThreadUtils"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("parmap should be interruptible") {
|
||||||
|
val t = new Thread() {
|
||||||
|
setDaemon(true)
|
||||||
|
|
||||||
|
override def run() {
|
||||||
|
try {
|
||||||
|
// "par" is uninterruptible. The following will keep running even if the thread is
|
||||||
|
// interrupted. We should prefer to use "ThreadUtils.parmap".
|
||||||
|
//
|
||||||
|
// (1 to 10).par.flatMap { i =>
|
||||||
|
// Thread.sleep(100000)
|
||||||
|
// 1 to i
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
ThreadUtils.parmap(1 to 10, "test", 2) { i =>
|
||||||
|
Thread.sleep(100000)
|
||||||
|
1 to i
|
||||||
|
}.flatten
|
||||||
|
} catch {
|
||||||
|
case _: InterruptedException => // excepted
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.start()
|
||||||
|
eventually(timeout(10.seconds)) {
|
||||||
|
assert(t.isAlive)
|
||||||
|
}
|
||||||
|
t.interrupt()
|
||||||
|
eventually(timeout(10.seconds)) {
|
||||||
|
assert(!t.isAlive)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command
|
||||||
import java.util.Locale
|
import java.util.Locale
|
||||||
|
|
||||||
import scala.collection.{GenMap, GenSeq}
|
import scala.collection.{GenMap, GenSeq}
|
||||||
import scala.collection.parallel.ForkJoinTaskSupport
|
import scala.concurrent.ExecutionContext
|
||||||
import scala.util.control.NonFatal
|
import scala.util.control.NonFatal
|
||||||
|
|
||||||
import org.apache.hadoop.conf.Configuration
|
import org.apache.hadoop.conf.Configuration
|
||||||
|
@ -29,7 +29,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
|
||||||
|
|
||||||
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
|
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
|
||||||
import org.apache.spark.sql.catalyst.TableIdentifier
|
import org.apache.spark.sql.catalyst.TableIdentifier
|
||||||
import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Resolver}
|
import org.apache.spark.sql.catalyst.analysis.Resolver
|
||||||
import org.apache.spark.sql.catalyst.catalog._
|
import org.apache.spark.sql.catalyst.catalog._
|
||||||
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
|
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
|
||||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
|
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
|
||||||
|
@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter
|
||||||
import org.apache.spark.sql.internal.HiveSerDe
|
import org.apache.spark.sql.internal.HiveSerDe
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
import org.apache.spark.util.{SerializableConfiguration, ThreadUtils}
|
import org.apache.spark.util.{SerializableConfiguration, ThreadUtils}
|
||||||
|
import org.apache.spark.util.ThreadUtils.parmap
|
||||||
|
|
||||||
// Note: The definition of these commands are based on the ones described in
|
// Note: The definition of these commands are based on the ones described in
|
||||||
// https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
|
// https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
|
||||||
|
@ -621,8 +622,9 @@ case class AlterTableRecoverPartitionsCommand(
|
||||||
val evalPool = ThreadUtils.newForkJoinPool("AlterTableRecoverPartitionsCommand", 8)
|
val evalPool = ThreadUtils.newForkJoinPool("AlterTableRecoverPartitionsCommand", 8)
|
||||||
val partitionSpecsAndLocs: Seq[(TablePartitionSpec, Path)] =
|
val partitionSpecsAndLocs: Seq[(TablePartitionSpec, Path)] =
|
||||||
try {
|
try {
|
||||||
|
implicit val ec = ExecutionContext.fromExecutor(evalPool)
|
||||||
scanPartitions(spark, fs, pathFilter, root, Map(), table.partitionColumnNames, threshold,
|
scanPartitions(spark, fs, pathFilter, root, Map(), table.partitionColumnNames, threshold,
|
||||||
spark.sessionState.conf.resolver, new ForkJoinTaskSupport(evalPool)).seq
|
spark.sessionState.conf.resolver)
|
||||||
} finally {
|
} finally {
|
||||||
evalPool.shutdown()
|
evalPool.shutdown()
|
||||||
}
|
}
|
||||||
|
@ -654,23 +656,13 @@ case class AlterTableRecoverPartitionsCommand(
|
||||||
spec: TablePartitionSpec,
|
spec: TablePartitionSpec,
|
||||||
partitionNames: Seq[String],
|
partitionNames: Seq[String],
|
||||||
threshold: Int,
|
threshold: Int,
|
||||||
resolver: Resolver,
|
resolver: Resolver)(implicit ec: ExecutionContext): Seq[(TablePartitionSpec, Path)] = {
|
||||||
evalTaskSupport: ForkJoinTaskSupport): GenSeq[(TablePartitionSpec, Path)] = {
|
|
||||||
if (partitionNames.isEmpty) {
|
if (partitionNames.isEmpty) {
|
||||||
return Seq(spec -> path)
|
return Seq(spec -> path)
|
||||||
}
|
}
|
||||||
|
|
||||||
val statuses = fs.listStatus(path, filter)
|
val statuses = fs.listStatus(path, filter).toSeq
|
||||||
val statusPar: GenSeq[FileStatus] =
|
def handleStatus(st: FileStatus): Seq[(TablePartitionSpec, Path)] = {
|
||||||
if (partitionNames.length > 1 && statuses.length > threshold || partitionNames.length > 2) {
|
|
||||||
// parallelize the list of partitions here, then we can have better parallelism later.
|
|
||||||
val parArray = statuses.par
|
|
||||||
parArray.tasksupport = evalTaskSupport
|
|
||||||
parArray
|
|
||||||
} else {
|
|
||||||
statuses
|
|
||||||
}
|
|
||||||
statusPar.flatMap { st =>
|
|
||||||
val name = st.getPath.getName
|
val name = st.getPath.getName
|
||||||
if (st.isDirectory && name.contains("=")) {
|
if (st.isDirectory && name.contains("=")) {
|
||||||
val ps = name.split("=", 2)
|
val ps = name.split("=", 2)
|
||||||
|
@ -679,7 +671,7 @@ case class AlterTableRecoverPartitionsCommand(
|
||||||
val value = ExternalCatalogUtils.unescapePathName(ps(1))
|
val value = ExternalCatalogUtils.unescapePathName(ps(1))
|
||||||
if (resolver(columnName, partitionNames.head)) {
|
if (resolver(columnName, partitionNames.head)) {
|
||||||
scanPartitions(spark, fs, filter, st.getPath, spec ++ Map(partitionNames.head -> value),
|
scanPartitions(spark, fs, filter, st.getPath, spec ++ Map(partitionNames.head -> value),
|
||||||
partitionNames.drop(1), threshold, resolver, evalTaskSupport)
|
partitionNames.drop(1), threshold, resolver)
|
||||||
} else {
|
} else {
|
||||||
logWarning(
|
logWarning(
|
||||||
s"expected partition column ${partitionNames.head}, but got ${ps(0)}, ignoring it")
|
s"expected partition column ${partitionNames.head}, but got ${ps(0)}, ignoring it")
|
||||||
|
@ -690,6 +682,14 @@ case class AlterTableRecoverPartitionsCommand(
|
||||||
Seq.empty
|
Seq.empty
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
val result = if (partitionNames.length > 1 &&
|
||||||
|
statuses.length > threshold || partitionNames.length > 2) {
|
||||||
|
parmap(statuses)(handleStatus _)
|
||||||
|
} else {
|
||||||
|
statuses.map(handleStatus)
|
||||||
|
}
|
||||||
|
|
||||||
|
result.flatten
|
||||||
}
|
}
|
||||||
|
|
||||||
private def gatherPartitionStats(
|
private def gatherPartitionStats(
|
||||||
|
|
|
@ -22,7 +22,6 @@ import java.net.URI
|
||||||
|
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
import scala.collection.mutable
|
import scala.collection.mutable
|
||||||
import scala.collection.parallel.ForkJoinTaskSupport
|
|
||||||
import scala.util.{Failure, Try}
|
import scala.util.{Failure, Try}
|
||||||
|
|
||||||
import org.apache.hadoop.conf.Configuration
|
import org.apache.hadoop.conf.Configuration
|
||||||
|
@ -532,30 +531,23 @@ object ParquetFileFormat extends Logging {
|
||||||
conf: Configuration,
|
conf: Configuration,
|
||||||
partFiles: Seq[FileStatus],
|
partFiles: Seq[FileStatus],
|
||||||
ignoreCorruptFiles: Boolean): Seq[Footer] = {
|
ignoreCorruptFiles: Boolean): Seq[Footer] = {
|
||||||
val parFiles = partFiles.par
|
ThreadUtils.parmap(partFiles, "readingParquetFooters", 8) { currentFile =>
|
||||||
val pool = ThreadUtils.newForkJoinPool("readingParquetFooters", 8)
|
try {
|
||||||
parFiles.tasksupport = new ForkJoinTaskSupport(pool)
|
// Skips row group information since we only need the schema.
|
||||||
try {
|
// ParquetFileReader.readFooter throws RuntimeException, instead of IOException,
|
||||||
parFiles.flatMap { currentFile =>
|
// when it can't read the footer.
|
||||||
try {
|
Some(new Footer(currentFile.getPath(),
|
||||||
// Skips row group information since we only need the schema.
|
ParquetFileReader.readFooter(
|
||||||
// ParquetFileReader.readFooter throws RuntimeException, instead of IOException,
|
conf, currentFile, SKIP_ROW_GROUPS)))
|
||||||
// when it can't read the footer.
|
} catch { case e: RuntimeException =>
|
||||||
Some(new Footer(currentFile.getPath(),
|
if (ignoreCorruptFiles) {
|
||||||
ParquetFileReader.readFooter(
|
logWarning(s"Skipped the footer in the corrupted file: $currentFile", e)
|
||||||
conf, currentFile, SKIP_ROW_GROUPS)))
|
None
|
||||||
} catch { case e: RuntimeException =>
|
} else {
|
||||||
if (ignoreCorruptFiles) {
|
throw new IOException(s"Could not read footer for file: $currentFile", e)
|
||||||
logWarning(s"Skipped the footer in the corrupted file: $currentFile", e)
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
throw new IOException(s"Could not read footer for file: $currentFile", e)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}.seq
|
}
|
||||||
} finally {
|
}.flatten
|
||||||
pool.shutdown()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -51,9 +51,9 @@ class ParquetFileFormatSuite extends QueryTest with ParquetTest with SharedSQLCo
|
||||||
}
|
}
|
||||||
|
|
||||||
testReadFooters(true)
|
testReadFooters(true)
|
||||||
val exception = intercept[java.io.IOException] {
|
val exception = intercept[SparkException] {
|
||||||
testReadFooters(false)
|
testReadFooters(false)
|
||||||
}
|
}.getCause
|
||||||
assert(exception.getMessage().contains("Could not read footer for file"))
|
assert(exception.getMessage().contains("Could not read footer for file"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -312,10 +312,10 @@ private[streaming] object FileBasedWriteAheadLog {
|
||||||
handler: I => Iterator[O]): Iterator[O] = {
|
handler: I => Iterator[O]): Iterator[O] = {
|
||||||
val taskSupport = new ExecutionContextTaskSupport(executionContext)
|
val taskSupport = new ExecutionContextTaskSupport(executionContext)
|
||||||
val groupSize = taskSupport.parallelismLevel.max(8)
|
val groupSize = taskSupport.parallelismLevel.max(8)
|
||||||
|
implicit val ec = executionContext
|
||||||
|
|
||||||
source.grouped(groupSize).flatMap { group =>
|
source.grouped(groupSize).flatMap { group =>
|
||||||
val parallelCollection = group.par
|
ThreadUtils.parmap(group)(handler)
|
||||||
parallelCollection.tasksupport = taskSupport
|
|
||||||
parallelCollection.map(handler)
|
|
||||||
}.flatten
|
}.flatten
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue