[SPARK-24005][CORE] Remove usage of Scala’s parallel collections

## What changes were proposed in this pull request?

In this PR, I propose to replace Scala parallel collections with new `parmap()` methods. The methods use futures to transform a sequential collection by applying a lambda function to each element in parallel. The result of `parmap()` is another regular (sequential) collection.

The proposed `parmap()` methods aim to solve the problem that operations on Scala parallel collections cannot be interrupted. Interruptibility is needed for reliable task preemption.
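
For illustration, here is a minimal usage sketch based on the `parmap` overload added in this PR that creates its own thread pool. The collection, pool prefix, and thread count are made up for the example, and `ThreadUtils` is Spark-internal (`private[spark]`), so this only compiles inside Spark code:

```scala
import org.apache.spark.util.ThreadUtils

// Map over a plain Seq using a dedicated pool of at most 2 threads.
// The result is a regular (sequential) Seq, and the calling thread can be
// interrupted while it waits, unlike with Scala parallel collections.
val lengths: Seq[Int] =
  ThreadUtils.parmap(Seq("a", "bb", "ccc"), prefix = "demo-pool", maxThreads = 2) { s =>
    s.length
  }
```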

## How was this patch tested?

A test was added to `ThreadUtilsSuite`.

Closes #21913 from MaxGekk/par-map.

Authored-by: Maxim Gekk <maxim.gekk@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Maxim Gekk 2018-08-07 17:14:30 +08:00 committed by Wenchen Fan
parent 88e0c7bbd5
commit 131ca146ed
7 changed files with 142 additions and 56 deletions


@@ -20,12 +20,13 @@ package org.apache.spark.rdd
 import java.io.{IOException, ObjectOutputStream}
 
 import scala.collection.mutable.ArrayBuffer
-import scala.collection.parallel.ForkJoinTaskSupport
+import scala.concurrent.ExecutionContext
 import scala.concurrent.forkjoin.ForkJoinPool
 import scala.reflect.ClassTag
 
 import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.util.ThreadUtils.parmap
 import org.apache.spark.util.Utils
 
 /**
@@ -59,8 +60,7 @@ private[spark] class UnionPartition[T: ClassTag](
 }
 
 object UnionRDD {
-  private[spark] lazy val partitionEvalTaskSupport =
-    new ForkJoinTaskSupport(new ForkJoinPool(8))
+  private[spark] lazy val threadPool = new ForkJoinPool(8)
 }
 
 @DeveloperApi
@@ -74,14 +74,13 @@ class UnionRDD[T: ClassTag](
     rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10)
 
   override def getPartitions: Array[Partition] = {
-    val parRDDs = if (isPartitionListingParallel) {
-      val parArray = rdds.par
-      parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
-      parArray
+    val partitionLengths = if (isPartitionListingParallel) {
+      implicit val ec = ExecutionContext.fromExecutor(UnionRDD.threadPool)
+      parmap(rdds)(_.partitions.length)
     } else {
-      rdds
+      rdds.map(_.partitions.length)
     }
-    val array = new Array[Partition](parRDDs.map(_.partitions.length).seq.sum)
+    val array = new Array[Partition](partitionLengths.sum)
     var pos = 0
     for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
       array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)


@@ -19,8 +19,12 @@ package org.apache.spark.util
 import java.util.concurrent._
 
+import scala.collection.TraversableLike
+import scala.collection.generic.CanBuildFrom
+import scala.language.higherKinds
+
 import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder}
-import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor}
+import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor, Future}
 import scala.concurrent.duration.{Duration, FiniteDuration}
 import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread}
 import scala.util.control.NonFatal
@@ -254,4 +258,62 @@
       executor.shutdownNow()
     }
   }
+
+  /**
+   * Transforms the input collection by applying the given function to each element in parallel.
+   * Compared to the map() method of Scala parallel collections, this method can be interrupted
+   * at any time. This is useful for canceling task execution, for example.
+   *
+   * @param in - the input collection which should be transformed in parallel.
+   * @param prefix - the prefix assigned to the underlying thread pool.
+   * @param maxThreads - the maximum number of threads that can be created during execution.
+   * @param f - the lambda function applied to each element of `in`.
+   * @tparam I - the type of elements in the input collection.
+   * @tparam O - the type of elements in the resulting collection.
+   * @return a new collection obtained by applying the lambda function `f` to each element of
+   *         the input collection `in`.
+   */
+  def parmap[I, O, Col[X] <: TraversableLike[X, Col[X]]]
+      (in: Col[I], prefix: String, maxThreads: Int)
+      (f: I => O)
+      (implicit
+        cbf: CanBuildFrom[Col[I], Future[O], Col[Future[O]]], // For in.map
+        cbf2: CanBuildFrom[Col[Future[O]], O, Col[O]] // for Future.sequence
+      ): Col[O] = {
+    val pool = newForkJoinPool(prefix, maxThreads)
+    try {
+      implicit val ec = ExecutionContext.fromExecutor(pool)
+      parmap(in)(f)
+    } finally {
+      pool.shutdownNow()
+    }
+  }
+
+  /**
+   * Transforms the input collection by applying the given function to each element in parallel.
+   * Compared to the map() method of Scala parallel collections, this method can be interrupted
+   * at any time. This is useful for canceling task execution, for example.
+   *
+   * @param in - the input collection which should be transformed in parallel.
+   * @param f - the lambda function applied to each element of `in`.
+   * @param ec - an execution context for parallel application of the given function `f`.
+   * @tparam I - the type of elements in the input collection.
+   * @tparam O - the type of elements in the resulting collection.
+   * @return a new collection obtained by applying the lambda function `f` to each element of
+   *         the input collection `in`.
+   */
+  def parmap[I, O, Col[X] <: TraversableLike[X, Col[X]]]
+      (in: Col[I])
+      (f: I => O)
+      (implicit
+        cbf: CanBuildFrom[Col[I], Future[O], Col[Future[O]]], // For in.map
+        cbf2: CanBuildFrom[Col[Future[O]], O, Col[O]], // for Future.sequence
+        ec: ExecutionContext
+      ): Col[O] = {
+    val futures = in.map(x => Future(f(x)))
+    val futureSeq = Future.sequence(futures)
+    awaitResult(futureSeq, Duration.Inf)
+  }
 }
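
As a side note on why the futures-based approach is interruptible: the caller blocks in `awaitResult`, which (like `Await.result`) throws `InterruptedException` when the waiting thread is interrupted, whereas a `.par` operation joins on its ForkJoinPool uninterruptibly. The following is a standalone, simplified sketch of that pattern using only the Scala standard library; it is not part of this patch and omits the `CanBuildFrom` machinery that lets the real `parmap` preserve the input collection type:

```scala
import java.util.concurrent.Executors

import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object ParmapSketch {
  // Start one future per element and block on the combined result.
  // Await.result responds to Thread.interrupt() with InterruptedException,
  // which is what makes this pattern cancellable.
  def parmapSketch[I, O](in: Seq[I])(f: I => O)(implicit ec: ExecutionContext): Seq[O] = {
    val futures = in.map(x => Future(f(x)))
    Await.result(Future.sequence(futures), Duration.Inf)
  }

  def main(args: Array[String]): Unit = {
    val pool = Executors.newFixedThreadPool(4)
    implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(pool)

    println(parmapSketch(1 to 10)(_ * 2)) // prints the elements doubled, in input order
    pool.shutdown()
  }
}
```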


@@ -133,4 +133,37 @@ class ThreadUtilsSuite extends SparkFunSuite {
       "stack trace contains unexpected references to ThreadUtils"
     )
   }
+
+  test("parmap should be interruptible") {
+    val t = new Thread() {
+      setDaemon(true)
+
+      override def run() {
+        try {
+          // "par" is uninterruptible. The following will keep running even if the thread is
+          // interrupted. We should prefer to use "ThreadUtils.parmap".
+          //
+          // (1 to 10).par.flatMap { i =>
+          //   Thread.sleep(100000)
+          //   1 to i
+          // }
+          //
+          ThreadUtils.parmap(1 to 10, "test", 2) { i =>
+            Thread.sleep(100000)
+            1 to i
+          }.flatten
+        } catch {
+          case _: InterruptedException => // expected
+        }
+      }
+    }
+    t.start()
+    eventually(timeout(10.seconds)) {
+      assert(t.isAlive)
+    }
+    t.interrupt()
+    eventually(timeout(10.seconds)) {
+      assert(!t.isAlive)
+    }
+  }
 }


@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command
 import java.util.Locale
 
 import scala.collection.{GenMap, GenSeq}
-import scala.collection.parallel.ForkJoinTaskSupport
+import scala.concurrent.ExecutionContext
 import scala.util.control.NonFatal
 
 import org.apache.hadoop.conf.Configuration
@@ -29,7 +29,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
 import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Resolver}
+import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
@@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter
 import org.apache.spark.sql.internal.HiveSerDe
 import org.apache.spark.sql.types._
 import org.apache.spark.util.{SerializableConfiguration, ThreadUtils}
+import org.apache.spark.util.ThreadUtils.parmap
 
 // Note: The definition of these commands are based on the ones described in
 // https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
@@ -621,8 +622,9 @@ case class AlterTableRecoverPartitionsCommand(
     val evalPool = ThreadUtils.newForkJoinPool("AlterTableRecoverPartitionsCommand", 8)
     val partitionSpecsAndLocs: Seq[(TablePartitionSpec, Path)] =
       try {
+        implicit val ec = ExecutionContext.fromExecutor(evalPool)
         scanPartitions(spark, fs, pathFilter, root, Map(), table.partitionColumnNames, threshold,
-          spark.sessionState.conf.resolver, new ForkJoinTaskSupport(evalPool)).seq
+          spark.sessionState.conf.resolver)
       } finally {
         evalPool.shutdown()
       }
@@ -654,23 +656,13 @@ case class AlterTableRecoverPartitionsCommand(
       spec: TablePartitionSpec,
       partitionNames: Seq[String],
       threshold: Int,
-      resolver: Resolver,
-      evalTaskSupport: ForkJoinTaskSupport): GenSeq[(TablePartitionSpec, Path)] = {
+      resolver: Resolver)(implicit ec: ExecutionContext): Seq[(TablePartitionSpec, Path)] = {
     if (partitionNames.isEmpty) {
       return Seq(spec -> path)
     }
 
-    val statuses = fs.listStatus(path, filter)
-    val statusPar: GenSeq[FileStatus] =
-      if (partitionNames.length > 1 && statuses.length > threshold || partitionNames.length > 2) {
-        // parallelize the list of partitions here, then we can have better parallelism later.
-        val parArray = statuses.par
-        parArray.tasksupport = evalTaskSupport
-        parArray
-      } else {
-        statuses
-      }
-    statusPar.flatMap { st =>
+    val statuses = fs.listStatus(path, filter).toSeq
+    def handleStatus(st: FileStatus): Seq[(TablePartitionSpec, Path)] = {
       val name = st.getPath.getName
       if (st.isDirectory && name.contains("=")) {
         val ps = name.split("=", 2)
@@ -679,7 +671,7 @@ case class AlterTableRecoverPartitionsCommand(
         val value = ExternalCatalogUtils.unescapePathName(ps(1))
         if (resolver(columnName, partitionNames.head)) {
           scanPartitions(spark, fs, filter, st.getPath, spec ++ Map(partitionNames.head -> value),
-            partitionNames.drop(1), threshold, resolver, evalTaskSupport)
+            partitionNames.drop(1), threshold, resolver)
         } else {
           logWarning(
             s"expected partition column ${partitionNames.head}, but got ${ps(0)}, ignoring it")
@@ -690,6 +682,14 @@ case class AlterTableRecoverPartitionsCommand(
         Seq.empty
       }
     }
+
+    val result = if (partitionNames.length > 1 &&
+        statuses.length > threshold || partitionNames.length > 2) {
+      parmap(statuses)(handleStatus _)
+    } else {
+      statuses.map(handleStatus)
+    }
+    result.flatten
   }
 
   private def gatherPartitionStats(


@@ -22,7 +22,6 @@ import java.net.URI
 import scala.collection.JavaConverters._
 import scala.collection.mutable
-import scala.collection.parallel.ForkJoinTaskSupport
 import scala.util.{Failure, Try}
 
 import org.apache.hadoop.conf.Configuration
@@ -532,30 +531,23 @@
       conf: Configuration,
       partFiles: Seq[FileStatus],
       ignoreCorruptFiles: Boolean): Seq[Footer] = {
-    val parFiles = partFiles.par
-    val pool = ThreadUtils.newForkJoinPool("readingParquetFooters", 8)
-    parFiles.tasksupport = new ForkJoinTaskSupport(pool)
-    try {
-      parFiles.flatMap { currentFile =>
-        try {
-          // Skips row group information since we only need the schema.
-          // ParquetFileReader.readFooter throws RuntimeException, instead of IOException,
-          // when it can't read the footer.
-          Some(new Footer(currentFile.getPath(),
-            ParquetFileReader.readFooter(
-              conf, currentFile, SKIP_ROW_GROUPS)))
-        } catch { case e: RuntimeException =>
-          if (ignoreCorruptFiles) {
-            logWarning(s"Skipped the footer in the corrupted file: $currentFile", e)
-            None
-          } else {
-            throw new IOException(s"Could not read footer for file: $currentFile", e)
-          }
-        }
-      }.seq
-    } finally {
-      pool.shutdown()
-    }
+    ThreadUtils.parmap(partFiles, "readingParquetFooters", 8) { currentFile =>
+      try {
+        // Skips row group information since we only need the schema.
+        // ParquetFileReader.readFooter throws RuntimeException, instead of IOException,
+        // when it can't read the footer.
+        Some(new Footer(currentFile.getPath(),
+          ParquetFileReader.readFooter(
+            conf, currentFile, SKIP_ROW_GROUPS)))
+      } catch { case e: RuntimeException =>
+        if (ignoreCorruptFiles) {
+          logWarning(s"Skipped the footer in the corrupted file: $currentFile", e)
+          None
+        } else {
+          throw new IOException(s"Could not read footer for file: $currentFile", e)
+        }
+      }
+    }.flatten
   }
 
   /**


@@ -51,9 +51,9 @@ class ParquetFileFormatSuite extends QueryTest with ParquetTest with SharedSQLCo
     }
 
     testReadFooters(true)
-    val exception = intercept[java.io.IOException] {
+    val exception = intercept[SparkException] {
       testReadFooters(false)
-    }
+    }.getCause
     assert(exception.getMessage().contains("Could not read footer for file"))
   }
 }


@@ -312,10 +312,10 @@ private[streaming] object FileBasedWriteAheadLog {
       handler: I => Iterator[O]): Iterator[O] = {
     val taskSupport = new ExecutionContextTaskSupport(executionContext)
     val groupSize = taskSupport.parallelismLevel.max(8)
+    implicit val ec = executionContext
     source.grouped(groupSize).flatMap { group =>
-      val parallelCollection = group.par
-      parallelCollection.tasksupport = taskSupport
-      parallelCollection.map(handler)
+      ThreadUtils.parmap(group)(handler)
     }.flatten
   }
 }
} }