[SPARK-10708] Consolidate sort shuffle implementations

There's a lot of duplication between SortShuffleManager and UnsafeShuffleManager. Now that UnsafeShuffleManager supports large records, the two provide the same set of functionality, so I think we should replace SortShuffleManager's serialized shuffle implementation with UnsafeShuffleManager's and merge the two managers together.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #8829 from JoshRosen/consolidate-sort-shuffle-implementations.
Josh Rosen 2015-10-22 09:46:30 -07:00
parent 94e2064fa1
commit f6d06adf05
30 changed files with 456 additions and 1317 deletions


@ -21,21 +21,30 @@ import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import javax.annotation.Nullable;
import scala.None$;
import scala.Option;
import scala.Product2;
import scala.Tuple2;
import scala.collection.Iterator;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.Closeables;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.*;
import org.apache.spark.util.Utils;
@ -62,7 +71,7 @@ import org.apache.spark.util.Utils;
* <p>
* There have been proposals to completely remove this code path; see SPARK-6026 for details.
*/
final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<K, V> {
final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class);
@ -72,31 +81,52 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
private final BlockManager blockManager;
private final Partitioner partitioner;
private final ShuffleWriteMetrics writeMetrics;
private final int shuffleId;
private final int mapId;
private final Serializer serializer;
private final IndexShuffleBlockResolver shuffleBlockResolver;
/** Array of file writers, one for each partition */
private DiskBlockObjectWriter[] partitionWriters;
@Nullable private MapStatus mapStatus;
private long[] partitionLengths;
/**
* Are we in the process of stopping? Because map tasks can call stop() with success = true
* and then call stop() with success = false if they get an exception, we want to make sure
* we don't try deleting files, etc twice.
*/
private boolean stopping = false;
public BypassMergeSortShuffleWriter(
SparkConf conf,
BlockManager blockManager,
Partitioner partitioner,
ShuffleWriteMetrics writeMetrics,
Serializer serializer) {
IndexShuffleBlockResolver shuffleBlockResolver,
BypassMergeSortShuffleHandle<K, V> handle,
int mapId,
TaskContext taskContext,
SparkConf conf) {
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
this.numPartitions = partitioner.numPartitions();
this.blockManager = blockManager;
this.partitioner = partitioner;
this.writeMetrics = writeMetrics;
this.serializer = serializer;
final ShuffleDependency<K, V, V> dep = handle.dependency();
this.mapId = mapId;
this.shuffleId = dep.shuffleId();
this.partitioner = dep.partitioner();
this.numPartitions = partitioner.numPartitions();
this.writeMetrics = new ShuffleWriteMetrics();
taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
this.serializer = Serializer.getSerializer(dep.serializer());
this.shuffleBlockResolver = shuffleBlockResolver;
}
@Override
public void insertAll(Iterator<Product2<K, V>> records) throws IOException {
public void write(Iterator<Product2<K, V>> records) throws IOException {
assert (partitionWriters == null);
if (!records.hasNext()) {
partitionLengths = new long[numPartitions];
shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
return;
}
final SerializerInstance serInstance = serializer.newInstance();
@ -124,13 +154,24 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
for (DiskBlockObjectWriter writer : partitionWriters) {
writer.commitAndClose();
}
partitionLengths =
writePartitionedFile(shuffleBlockResolver.getDataFile(shuffleId, mapId));
shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}
@Override
public long[] writePartitionedFile(
BlockId blockId,
TaskContext context,
File outputFile) throws IOException {
@VisibleForTesting
long[] getPartitionLengths() {
return partitionLengths;
}
/**
* Concatenate all of the per-partition files into a single combined file.
*
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker).
*/
private long[] writePartitionedFile(File outputFile) throws IOException {
// Track location of the partition starts in the output file
final long[] lengths = new long[numPartitions];
if (partitionWriters == null) {
@ -165,18 +206,33 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
}
@Override
public void stop() throws IOException {
if (partitionWriters != null) {
try {
for (DiskBlockObjectWriter writer : partitionWriters) {
// This method explicitly does _not_ throw exceptions:
File file = writer.revertPartialWritesAndClose();
if (!file.delete()) {
logger.error("Error while deleting file {}", file.getAbsolutePath());
public Option<MapStatus> stop(boolean success) {
if (stopping) {
return None$.empty();
} else {
stopping = true;
if (success) {
if (mapStatus == null) {
throw new IllegalStateException("Cannot call stop(true) without having called write()");
}
return Option.apply(mapStatus);
} else {
// The map task failed, so delete our output data.
if (partitionWriters != null) {
try {
for (DiskBlockObjectWriter writer : partitionWriters) {
// This method explicitly does _not_ throw exceptions:
File file = writer.revertPartialWritesAndClose();
if (!file.delete()) {
logger.error("Error while deleting file {}", file.getAbsolutePath());
}
}
} finally {
partitionWriters = null;
}
}
} finally {
partitionWriters = null;
shuffleBlockResolver.removeDataByMap(shuffleId, mapId);
return None$.empty();
}
}
}


@ -15,7 +15,7 @@
* limitations under the License.
*/
package org.apache.spark.shuffle.unsafe;
package org.apache.spark.shuffle.sort;
/**
* Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer.

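The javadoc above describes the core trick of the serialized shuffle path: each record is represented by a single 8-byte word whose upper 24 bits hold the partition id and whose lower 40 bits hold a record pointer. A minimal standalone sketch of that layout, using hypothetical names rather than the actual PackedRecordPointer (whose 40 pointer bits must also identify the memory page a record lives in, hence the MAXIMUM_PAGE_SIZE_BYTES bound referenced elsewhere in this diff):

// Hypothetical sketch of the 8-byte layout: [ 24-bit partition id | 40-bit record pointer ]
object PackedPointerSketch {
  val PartitionBits = 24
  val PointerBits = 40
  val MaxPartitionId: Int = (1 << PartitionBits) - 1  // 16777215
  val PointerMask: Long = (1L << PointerBits) - 1

  def pack(recordPointer: Long, partitionId: Int): Long = {
    require(partitionId >= 0 && partitionId <= MaxPartitionId, "partition id must fit in 24 bits")
    (partitionId.toLong << PointerBits) | (recordPointer & PointerMask)
  }

  def partitionId(packed: Long): Int = (packed >>> PointerBits).toInt

  def recordPointer(packed: Long): Long = packed & PointerMask
}

Keeping the partition id in the high bits is what lets the in-memory sorter order records by partition while treating the pointer bits as opaque payload.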

@ -15,7 +15,7 @@
* limitations under the License.
*/
package org.apache.spark.shuffle.unsafe;
package org.apache.spark.shuffle.sort;
import javax.annotation.Nullable;
import java.io.File;
@ -48,7 +48,7 @@ import org.apache.spark.util.Utils;
* <p>
* Incoming records are appended to data pages. When all records have been inserted (or when the
* current thread's shuffle memory limit is reached), the in-memory records are sorted according to
* their partition ids (using a {@link UnsafeShuffleInMemorySorter}). The sorted records are then
* their partition ids (using a {@link ShuffleInMemorySorter}). The sorted records are then
* written to a single output file (or multiple files, if we've spilled). The format of the output
* files is the same as the format of the final output file written by
* {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are
@ -59,9 +59,9 @@ import org.apache.spark.util.Utils;
* spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a
* specialized merge procedure that avoids extra serialization/deserialization.
*/
final class UnsafeShuffleExternalSorter {
final class ShuffleExternalSorter {
private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class);
private final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class);
@VisibleForTesting
static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
@ -76,6 +76,10 @@ final class UnsafeShuffleExternalSorter {
private final BlockManager blockManager;
private final TaskContext taskContext;
private final ShuffleWriteMetrics writeMetrics;
private long numRecordsInsertedSinceLastSpill = 0;
/** Force this sorter to spill when there are this many elements in memory. For testing only */
private final long numElementsForSpillThreshold;
/** The buffer size to use when writing spills using DiskBlockObjectWriter */
private final int fileBufferSizeBytes;
@ -94,12 +98,12 @@ final class UnsafeShuffleExternalSorter {
private long peakMemoryUsedBytes;
// These variables are reset after spilling:
@Nullable private UnsafeShuffleInMemorySorter inMemSorter;
@Nullable private ShuffleInMemorySorter inMemSorter;
@Nullable private MemoryBlock currentPage = null;
private long currentPagePosition = -1;
private long freeSpaceInCurrentPage = 0;
public UnsafeShuffleExternalSorter(
public ShuffleExternalSorter(
TaskMemoryManager memoryManager,
ShuffleMemoryManager shuffleMemoryManager,
BlockManager blockManager,
@ -117,6 +121,8 @@ final class UnsafeShuffleExternalSorter {
this.numPartitions = numPartitions;
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
this.numElementsForSpillThreshold =
conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE);
this.pageSizeBytes = (int) Math.min(
PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, shuffleMemoryManager.pageSizeBytes());
this.maxRecordSizeBytes = pageSizeBytes - 4;
@ -140,7 +146,8 @@ final class UnsafeShuffleExternalSorter {
throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
}
this.inMemSorter = new UnsafeShuffleInMemorySorter(initialSize);
this.inMemSorter = new ShuffleInMemorySorter(initialSize);
numRecordsInsertedSinceLastSpill = 0;
}
/**
@ -166,7 +173,7 @@ final class UnsafeShuffleExternalSorter {
}
// This call performs the actual sort.
final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords =
final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords =
inMemSorter.getSortedIterator();
// Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
@ -406,6 +413,10 @@ final class UnsafeShuffleExternalSorter {
int lengthInBytes,
int partitionId) throws IOException {
if (numRecordsInsertedSinceLastSpill > numElementsForSpillThreshold) {
spill();
}
growPointerArrayIfNecessary();
// Need 4 bytes to store the record length.
final int totalSpaceRequired = lengthInBytes + 4;
@ -453,6 +464,7 @@ final class UnsafeShuffleExternalSorter {
recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes);
assert(inMemSorter != null);
inMemSorter.insertRecord(recordAddress, partitionId);
numRecordsInsertedSinceLastSpill += 1;
}
/**

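Besides the rename and package move, the hunks above add a testing-only escape hatch: the sorter counts records inserted since the last spill and forces a spill once the count exceeds spark.shuffle.spill.numElementsForceSpillThreshold, which defaults to Long.MAX_VALUE (effectively disabled). A hedged sketch of a test configuration that exercises it (the threshold value is hypothetical):

import org.apache.spark.SparkConf

// Force a spill after every 10 inserted records so that spill-merging code paths
// run even on tiny test inputs; leave this unset outside of tests.
val testConf = new SparkConf()
  .set("spark.shuffle.spill.numElementsForceSpillThreshold", "10")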

@ -15,13 +15,13 @@
* limitations under the License.
*/
package org.apache.spark.shuffle.unsafe;
package org.apache.spark.shuffle.sort;
import java.util.Comparator;
import org.apache.spark.util.collection.Sorter;
final class UnsafeShuffleInMemorySorter {
final class ShuffleInMemorySorter {
private final Sorter<PackedRecordPointer, long[]> sorter;
private static final class SortComparator implements Comparator<PackedRecordPointer> {
@ -44,10 +44,10 @@ final class UnsafeShuffleInMemorySorter {
*/
private int pointerArrayInsertPosition = 0;
public UnsafeShuffleInMemorySorter(int initialSize) {
public ShuffleInMemorySorter(int initialSize) {
assert (initialSize > 0);
this.pointerArray = new long[initialSize];
this.sorter = new Sorter<PackedRecordPointer, long[]>(UnsafeShuffleSortDataFormat.INSTANCE);
this.sorter = new Sorter<PackedRecordPointer, long[]>(ShuffleSortDataFormat.INSTANCE);
}
public void expandPointerArray() {
@ -92,14 +92,14 @@ final class UnsafeShuffleInMemorySorter {
/**
* An iterator-like class that's used instead of Java's Iterator in order to facilitate inlining.
*/
public static final class UnsafeShuffleSorterIterator {
public static final class ShuffleSorterIterator {
private final long[] pointerArray;
private final int numRecords;
final PackedRecordPointer packedRecordPointer = new PackedRecordPointer();
private int position = 0;
public UnsafeShuffleSorterIterator(int numRecords, long[] pointerArray) {
public ShuffleSorterIterator(int numRecords, long[] pointerArray) {
this.numRecords = numRecords;
this.pointerArray = pointerArray;
}
@ -117,8 +117,8 @@ final class UnsafeShuffleInMemorySorter {
/**
* Return an iterator over record pointers in sorted order.
*/
public UnsafeShuffleSorterIterator getSortedIterator() {
public ShuffleSorterIterator getSortedIterator() {
sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR);
return new UnsafeShuffleSorterIterator(pointerArrayInsertPosition, pointerArray);
return new ShuffleSorterIterator(pointerArrayInsertPosition, pointerArray);
}
}
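
ShuffleInMemorySorter (renamed above from UnsafeShuffleInMemorySorter) sorts an array of the packed 8-byte pointers by their partition id component only; the record pointer in the low bits never needs to be dereferenced during the sort. A minimal sketch of that ordering, reusing the hypothetical layout from the PackedPointerSketch above rather than the real Sorter/SortDataFormat machinery:

// Return the first numRecords packed pointers ordered by the 24-bit partition id
// stored in the high bits; sketch only.
def sortedByPartition(packedPointers: Array[Long], numRecords: Int): Array[Long] =
  packedPointers.take(numRecords).sortBy(p => p >>> 40)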


@ -15,15 +15,15 @@
* limitations under the License.
*/
package org.apache.spark.shuffle.unsafe;
package org.apache.spark.shuffle.sort;
import org.apache.spark.util.collection.SortDataFormat;
final class UnsafeShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {
final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {
public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat();
public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat();
private UnsafeShuffleSortDataFormat() { }
private ShuffleSortDataFormat() { }
@Override
public PackedRecordPointer getKey(long[] data, int pos) {


@ -1,53 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle.sort;
import java.io.File;
import java.io.IOException;
import scala.Product2;
import scala.collection.Iterator;
import org.apache.spark.annotation.Private;
import org.apache.spark.TaskContext;
import org.apache.spark.storage.BlockId;
/**
* Interface for objects that {@link SortShuffleWriter} uses to write its output files.
*/
@Private
public interface SortShuffleFileWriter<K, V> {
void insertAll(Iterator<Product2<K, V>> records) throws IOException;
/**
* Write all the data added into this shuffle sorter into a file in the disk store. This is
* called by the SortShuffleWriter and can go through an efficient path of just concatenating
* binary files if we decided to avoid merge-sorting.
*
* @param blockId block ID to write to. The index file will be blockId.name + ".index".
* @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
*/
long[] writePartitionedFile(
BlockId blockId,
TaskContext context,
File outputFile) throws IOException;
void stop() throws IOException;
}


@ -15,14 +15,14 @@
* limitations under the License.
*/
package org.apache.spark.shuffle.unsafe;
package org.apache.spark.shuffle.sort;
import java.io.File;
import org.apache.spark.storage.TempShuffleBlockId;
/**
* Metadata for a block of data written by {@link UnsafeShuffleExternalSorter}.
* Metadata for a block of data written by {@link ShuffleExternalSorter}.
*/
final class SpillInfo {
final long[] partitionLengths;


@ -15,7 +15,7 @@
* limitations under the License.
*/
package org.apache.spark.shuffle.unsafe;
package org.apache.spark.shuffle.sort;
import javax.annotation.Nullable;
import java.io.*;
@ -80,7 +80,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private final boolean transferToEnabled;
@Nullable private MapStatus mapStatus;
@Nullable private UnsafeShuffleExternalSorter sorter;
@Nullable private ShuffleExternalSorter sorter;
private long peakMemoryUsedBytes = 0;
/** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
@ -104,15 +104,15 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
IndexShuffleBlockResolver shuffleBlockResolver,
TaskMemoryManager memoryManager,
ShuffleMemoryManager shuffleMemoryManager,
UnsafeShuffleHandle<K, V> handle,
SerializedShuffleHandle<K, V> handle,
int mapId,
TaskContext taskContext,
SparkConf sparkConf) throws IOException {
final int numPartitions = handle.dependency().partitioner().numPartitions();
if (numPartitions > UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS()) {
if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
throw new IllegalArgumentException(
"UnsafeShuffleWriter can only be used for shuffles with at most " +
UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS() + " reduce partitions");
SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() + " reduce partitions");
}
this.blockManager = blockManager;
this.shuffleBlockResolver = shuffleBlockResolver;
@ -195,7 +195,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private void open() throws IOException {
assert (sorter == null);
sorter = new UnsafeShuffleExternalSorter(
sorter = new ShuffleExternalSorter(
memoryManager,
shuffleMemoryManager,
blockManager,

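The renamed limit checked above, SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE, follows directly from the 24-bit partition field of PackedRecordPointer:

// 24 bits of partition id give ids 0 .. 16777215, hence at most 16777216 partitions
// in serialized mode (MAXIMUM_PARTITION_ID + 1, as defined later in SortShuffleManager).
val maximumPartitionId: Int = (1 << 24) - 1                    // 16777215
val maxSerializedModePartitions: Int = maximumPartitionId + 1  // 16777216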

@ -330,7 +330,7 @@ object SparkEnv extends Logging {
val shortShuffleMgrNames = Map(
"hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager",
"sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager",
"tungsten-sort" -> "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager")
"tungsten-sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager")
val shuffleMgrName = conf.get("spark.shuffle.manager", "sort")
val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)
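
Because both short names now resolve to the consolidated org.apache.spark.shuffle.sort.SortShuffleManager, jobs that explicitly request the tungsten-sort manager keep working unchanged. A hedged usage sketch (the application name is hypothetical):

import org.apache.spark.SparkConf

// "tungsten-sort" remains only as a backwards-compatible alias; after this change it
// selects the same manager class as the default "sort".
val conf = new SparkConf()
  .setAppName("shuffle-manager-example")          // hypothetical
  .set("spark.shuffle.manager", "tungsten-sort")  // equivalent to "sort"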


@ -19,9 +19,53 @@ package org.apache.spark.shuffle.sort
import java.util.concurrent.ConcurrentHashMap
import org.apache.spark.{Logging, SparkConf, TaskContext, ShuffleDependency}
import org.apache.spark._
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle._
/**
* In sort-based shuffle, incoming records are sorted according to their target partition ids, then
* written to a single map output file. Reducers fetch contiguous regions of this file in order to
* read their portion of the map output. In cases where the map output data is too large to fit in
* memory, sorted subsets of the output are spilled to disk and those on-disk files are merged
* to produce the final output file.
*
* Sort-based shuffle has two different write paths for producing its map output files:
*
* - Serialized sorting: used when all three of the following conditions hold:
* 1. The shuffle dependency specifies no aggregation or output ordering.
* 2. The shuffle serializer supports relocation of serialized values (this is currently
* supported by KryoSerializer and Spark SQL's custom serializers).
* 3. The shuffle produces fewer than 16777216 output partitions.
* - Deserialized sorting: used to handle all other cases.
*
* -----------------------
* Serialized sorting mode
* -----------------------
*
* In the serialized sorting mode, incoming records are serialized as soon as they are passed to the
* shuffle writer and are buffered in a serialized form during sorting. This write path implements
* several optimizations:
*
* - Its sort operates on serialized binary data rather than Java objects, which reduces memory
* consumption and GC overheads. This optimization requires the record serializer to have certain
* properties to allow serialized records to be re-ordered without requiring deserialization.
* See SPARK-4550, where this optimization was first proposed and implemented, for more details.
*
* - It uses a specialized cache-efficient sorter ([[ShuffleExternalSorter]]) that sorts
* arrays of compressed record pointers and partition ids. By using only 8 bytes of space per
* record in the sorting array, this fits more of the array into cache.
*
* - The spill merging procedure operates on blocks of serialized records that belong to the same
* partition and does not need to deserialize records during the merge.
*
* - When the spill compression codec supports concatenation of compressed data, the spill merge
* simply concatenates the serialized and compressed spill partitions to produce the final output
* partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used
* and avoids the need to allocate decompression or copying buffers during the merge.
*
* For more details on these optimizations, see SPARK-7081.
*/
private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
if (!conf.getBoolean("spark.shuffle.spill", true)) {
@ -30,8 +74,12 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
" Shuffle will continue to spill to disk when necessary.")
}
private val indexShuffleBlockResolver = new IndexShuffleBlockResolver(conf)
private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]()
/**
* A mapping from shuffle ids to the number of mappers producing output for those shuffles.
*/
private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]()
override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
/**
* Register a shuffle with the manager and obtain a handle for it to pass to tasks.
@ -40,7 +88,22 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
shuffleId: Int,
numMaps: Int,
dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
new BaseShuffleHandle(shuffleId, numMaps, dependency)
if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dependency)) {
// If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
// need map-side aggregation, then write numPartitions files directly and just concatenate
// them at the end. This avoids doing serialization and deserialization twice to merge
// together the spilled files, which would happen with the normal code path. The downside is
// having multiple files open at a time and thus more memory allocated to buffers.
new BypassMergeSortShuffleHandle[K, V](
shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
} else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
// Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
new SerializedShuffleHandle[K, V](
shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
} else {
// Otherwise, buffer map outputs in a deserialized form:
new BaseShuffleHandle(shuffleId, numMaps, dependency)
}
}
/**
@ -52,38 +115,114 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
startPartition: Int,
endPartition: Int,
context: TaskContext): ShuffleReader[K, C] = {
// We currently use the same block store shuffle fetcher as the hash-based shuffle.
new BlockStoreShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
}
/** Get a writer for a given partition. Called on executors by map tasks. */
override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
: ShuffleWriter[K, V] = {
val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, V, _]]
shuffleMapNumber.putIfAbsent(baseShuffleHandle.shuffleId, baseShuffleHandle.numMaps)
new SortShuffleWriter(
shuffleBlockResolver, baseShuffleHandle, mapId, context)
override def getWriter[K, V](
handle: ShuffleHandle,
mapId: Int,
context: TaskContext): ShuffleWriter[K, V] = {
numMapsForShuffle.putIfAbsent(
handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps)
val env = SparkEnv.get
handle match {
case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
new UnsafeShuffleWriter(
env.blockManager,
shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
context.taskMemoryManager(),
env.shuffleMemoryManager,
unsafeShuffleHandle,
mapId,
context,
env.conf)
case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
new BypassMergeSortShuffleWriter(
env.blockManager,
shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
bypassMergeSortHandle,
mapId,
context,
env.conf)
case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
}
}
/** Remove a shuffle's metadata from the ShuffleManager. */
override def unregisterShuffle(shuffleId: Int): Boolean = {
if (shuffleMapNumber.containsKey(shuffleId)) {
val numMaps = shuffleMapNumber.remove(shuffleId)
(0 until numMaps).map{ mapId =>
Option(numMapsForShuffle.remove(shuffleId)).foreach { numMaps =>
(0 until numMaps).foreach { mapId =>
shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
}
}
true
}
override val shuffleBlockResolver: IndexShuffleBlockResolver = {
indexShuffleBlockResolver
}
/** Shut down this ShuffleManager. */
override def stop(): Unit = {
shuffleBlockResolver.stop()
}
}
private[spark] object SortShuffleManager extends Logging {
/**
* The maximum number of shuffle output partitions that SortShuffleManager supports when
* buffering map outputs in a serialized form. This is an extreme defensive programming measure,
* since it's extremely unlikely that a single shuffle produces over 16 million output partitions.
* */
val MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE =
PackedRecordPointer.MAXIMUM_PARTITION_ID + 1
/**
* Helper method for determining whether a shuffle should use an optimized serialized shuffle
* path or whether it should fall back to the original path that operates on deserialized objects.
*/
def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _]): Boolean = {
val shufId = dependency.shuffleId
val numPartitions = dependency.partitioner.numPartitions
val serializer = Serializer.getSerializer(dependency.serializer)
if (!serializer.supportsRelocationOfSerializedObjects) {
log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " +
s"${serializer.getClass.getName}, does not support object relocation")
false
} else if (dependency.aggregator.isDefined) {
log.debug(
s"Can't use serialized shuffle for shuffle $shufId because an aggregator is defined")
false
} else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) {
log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " +
s"$MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE partitions")
false
} else {
log.debug(s"Can use serialized shuffle for shuffle $shufId")
true
}
}
}
/**
* Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the
* serialized shuffle.
*/
private[spark] class SerializedShuffleHandle[K, V](
shuffleId: Int,
numMaps: Int,
dependency: ShuffleDependency[K, V, V])
extends BaseShuffleHandle(shuffleId, numMaps, dependency) {
}
/**
* Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the
* bypass merge sort shuffle path.
*/
private[spark] class BypassMergeSortShuffleHandle[K, V](
shuffleId: Int,
numMaps: Int,
dependency: ShuffleDependency[K, V, V])
extends BaseShuffleHandle(shuffleId, numMaps, dependency) {
}
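
As the new scaladoc and canUseSerializedShuffle above spell out, the serialized (formerly "unsafe") path is chosen only when the dependency defines no aggregator or key ordering, the serializer supports relocation of serialized objects, and the shuffle has fewer than 16777216 output partitions. A hedged end-to-end sketch of a job that would satisfy those conditions and therefore receive a SerializedShuffleHandle (application name, master, and sizes are hypothetical; Kryo supports relocation, and plain repartitioning defines no aggregator or ordering):

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

// Kryo serializer plus pure repartitioning: no map-side aggregation, no key ordering,
// and 1024 partitions is above the bypass-merge threshold (200) but far below 16777216.
val conf = new SparkConf()
  .setAppName("serialized-shuffle-sketch")                               // hypothetical
  .setMaster("local[4]")                                                 // hypothetical
  .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
val sc = new SparkContext(conf)

sc.parallelize(0 until 1000000)
  .map(i => (i % 1024, i))                 // key-value records
  .partitionBy(new HashPartitioner(1024))  // shuffle with no aggregator or ordering
  .count()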


@ -20,7 +20,6 @@ package org.apache.spark.shuffle.sort
import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
import org.apache.spark.storage.ShuffleBlockId
import org.apache.spark.util.collection.ExternalSorter
@ -36,7 +35,7 @@ private[spark] class SortShuffleWriter[K, V, C](
private val blockManager = SparkEnv.get.blockManager
private var sorter: SortShuffleFileWriter[K, V] = null
private var sorter: ExternalSorter[K, V, _] = null
// Are we in the process of stopping? Because map tasks can call stop() with success = true
// and then call stop() with success = false if they get an exception, we want to make sure
@ -54,15 +53,6 @@ private[spark] class SortShuffleWriter[K, V, C](
require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
new ExternalSorter[K, V, C](
dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
} else if (SortShuffleWriter.shouldBypassMergeSort(
SparkEnv.get.conf, dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) {
// If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
// need local aggregation and sorting, write numPartitions files directly and just concatenate
// them at the end. This avoids doing serialization and deserialization twice to merge
// together the spilled files, which would happen with the normal code path. The downside is
// having multiple files open at a time and thus more memory allocated to buffers.
new BypassMergeSortShuffleWriter[K, V](SparkEnv.get.conf, blockManager, dep.partitioner,
writeMetrics, Serializer.getSerializer(dep.serializer))
} else {
// In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
// care whether the keys get sorted in each partition; that will be done on the reduce side
@ -111,12 +101,14 @@ private[spark] class SortShuffleWriter[K, V, C](
}
private[spark] object SortShuffleWriter {
def shouldBypassMergeSort(
conf: SparkConf,
numPartitions: Int,
aggregator: Option[Aggregator[_, _, _]],
keyOrdering: Option[Ordering[_]]): Boolean = {
val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
numPartitions <= bypassMergeThreshold && aggregator.isEmpty && keyOrdering.isEmpty
def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
// We cannot bypass sorting if we need to do map-side aggregation.
if (dep.mapSideCombine) {
require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
false
} else {
val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
dep.partitioner.numPartitions <= bypassMergeThreshold
}
}
}
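
The rewritten shouldBypassMergeSort above takes the whole ShuffleDependency: map-side combine rules the bypass path out, otherwise the partition count is compared against spark.shuffle.sort.bypassMergeThreshold (default 200). A hedged illustration of that decision with plain values (hypothetical helper, not Spark API):

// Mirrors the decision logic with plain parameters instead of a ShuffleDependency.
def wouldBypassMergeSort(mapSideCombine: Boolean, numPartitions: Int, threshold: Int = 200): Boolean =
  !mapSideCombine && numPartitions <= threshold

assert(wouldBypassMergeSort(mapSideCombine = false, numPartitions = 100))   // bypass: few partitions, no combine
assert(!wouldBypassMergeSort(mapSideCombine = true, numPartitions = 100))   // no bypass: needs map-side combine
assert(!wouldBypassMergeSort(mapSideCombine = false, numPartitions = 1024)) // no bypass: over the threshold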


@ -1,202 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle.unsafe
import java.util.Collections
import java.util.concurrent.ConcurrentHashMap
import org.apache.spark._
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle._
import org.apache.spark.shuffle.sort.SortShuffleManager
/**
* Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle.
*/
private[spark] class UnsafeShuffleHandle[K, V](
shuffleId: Int,
numMaps: Int,
dependency: ShuffleDependency[K, V, V])
extends BaseShuffleHandle(shuffleId, numMaps, dependency) {
}
private[spark] object UnsafeShuffleManager extends Logging {
/**
* The maximum number of shuffle output partitions that UnsafeShuffleManager supports.
*/
val MAX_SHUFFLE_OUTPUT_PARTITIONS = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1
/**
* Helper method for determining whether a shuffle should use the optimized unsafe shuffle
* path or whether it should fall back to the original sort-based shuffle.
*/
def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = {
val shufId = dependency.shuffleId
val serializer = Serializer.getSerializer(dependency.serializer)
if (!serializer.supportsRelocationOfSerializedObjects) {
log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " +
s"${serializer.getClass.getName}, does not support object relocation")
false
} else if (dependency.aggregator.isDefined) {
log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined")
false
} else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) {
log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " +
s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions")
false
} else {
log.debug(s"Can use UnsafeShuffle for shuffle $shufId")
true
}
}
}
/**
* A shuffle implementation that uses directly-managed memory to implement several performance
* optimizations for certain types of shuffles. In cases where the new performance optimizations
* cannot be applied, this shuffle manager delegates to [[SortShuffleManager]] to handle those
* shuffles.
*
* UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold:
*
* - The shuffle dependency specifies no aggregation or output ordering.
* - The shuffle serializer supports relocation of serialized values (this is currently supported
* by KryoSerializer and Spark SQL's custom serializers).
* - The shuffle produces fewer than 16777216 output partitions.
* - No individual record is larger than 128 MB when serialized.
*
* In addition, extra spill-merging optimizations are automatically applied when the shuffle
* compression codec supports concatenation of serialized streams. This is currently supported by
* Spark's LZF serializer.
*
* At a high-level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager.
* In sort-based shuffle, incoming records are sorted according to their target partition ids, then
* written to a single map output file. Reducers fetch contiguous regions of this file in order to
* read their portion of the map output. In cases where the map output data is too large to fit in
* memory, sorted subsets of the output are spilled to disk and those on-disk files are merged
* to produce the final output file.
*
* UnsafeShuffleManager optimizes this process in several ways:
*
* - Its sort operates on serialized binary data rather than Java objects, which reduces memory
* consumption and GC overheads. This optimization requires the record serializer to have certain
* properties to allow serialized records to be re-ordered without requiring deserialization.
* See SPARK-4550, where this optimization was first proposed and implemented, for more details.
*
* - It uses a specialized cache-efficient sorter ([[UnsafeShuffleExternalSorter]]) that sorts
* arrays of compressed record pointers and partition ids. By using only 8 bytes of space per
* record in the sorting array, this fits more of the array into cache.
*
* - The spill merging procedure operates on blocks of serialized records that belong to the same
* partition and does not need to deserialize records during the merge.
*
* - When the spill compression codec supports concatenation of compressed data, the spill merge
* simply concatenates the serialized and compressed spill partitions to produce the final output
* partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used
* and avoids the need to allocate decompression or copying buffers during the merge.
*
* For more details on UnsafeShuffleManager's design, see SPARK-7081.
*/
private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
if (!conf.getBoolean("spark.shuffle.spill", true)) {
logWarning(
"spark.shuffle.spill was set to false, but this is ignored by the tungsten-sort shuffle " +
"manager; its optimized shuffles will continue to spill to disk when necessary.")
}
private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)
private[this] val shufflesThatFellBackToSortShuffle =
Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]())
private[this] val numMapsForShufflesThatUsedNewPath = new ConcurrentHashMap[Int, Int]()
/**
* Register a shuffle with the manager and obtain a handle for it to pass to tasks.
*/
override def registerShuffle[K, V, C](
shuffleId: Int,
numMaps: Int,
dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
if (UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) {
new UnsafeShuffleHandle[K, V](
shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
} else {
new BaseShuffleHandle(shuffleId, numMaps, dependency)
}
}
/**
* Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
* Called on executors by reduce tasks.
*/
override def getReader[K, C](
handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext): ShuffleReader[K, C] = {
sortShuffleManager.getReader(handle, startPartition, endPartition, context)
}
/** Get a writer for a given partition. Called on executors by map tasks. */
override def getWriter[K, V](
handle: ShuffleHandle,
mapId: Int,
context: TaskContext): ShuffleWriter[K, V] = {
handle match {
case unsafeShuffleHandle: UnsafeShuffleHandle[K @unchecked, V @unchecked] =>
numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps)
val env = SparkEnv.get
new UnsafeShuffleWriter(
env.blockManager,
shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
context.taskMemoryManager(),
env.shuffleMemoryManager,
unsafeShuffleHandle,
mapId,
context,
env.conf)
case other =>
shufflesThatFellBackToSortShuffle.add(handle.shuffleId)
sortShuffleManager.getWriter(handle, mapId, context)
}
}
/** Remove a shuffle's metadata from the ShuffleManager. */
override def unregisterShuffle(shuffleId: Int): Boolean = {
if (shufflesThatFellBackToSortShuffle.remove(shuffleId)) {
sortShuffleManager.unregisterShuffle(shuffleId)
} else {
Option(numMapsForShufflesThatUsedNewPath.remove(shuffleId)).foreach { numMaps =>
(0 until numMaps).foreach { mapId =>
shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
}
}
true
}
}
override val shuffleBlockResolver: IndexShuffleBlockResolver = {
sortShuffleManager.shuffleBlockResolver
}
/** Shut down this ShuffleManager. */
override def stop(): Unit = {
sortShuffleManager.stop()
}
}


@ -1,146 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util.collection
import java.io.OutputStream
import scala.collection.mutable.ArrayBuffer
/**
* A logical byte buffer that wraps a list of byte arrays. All the byte arrays have equal size. The
* advantage of this over a standard ArrayBuffer is that it can grow without claiming large amounts
* of memory and needing to copy the full contents. The disadvantage is that the contents don't
* occupy a contiguous segment of memory.
*/
private[spark] class ChainedBuffer(chunkSize: Int) {
private val chunkSizeLog2: Int = java.lang.Long.numberOfTrailingZeros(
java.lang.Long.highestOneBit(chunkSize))
assert((1 << chunkSizeLog2) == chunkSize,
s"ChainedBuffer chunk size $chunkSize must be a power of two")
private val chunks: ArrayBuffer[Array[Byte]] = new ArrayBuffer[Array[Byte]]()
private var _size: Long = 0
/**
* Feed bytes from this buffer into a DiskBlockObjectWriter.
*
* @param pos Offset in the buffer to read from.
* @param os OutputStream to read into.
* @param len Number of bytes to read.
*/
def read(pos: Long, os: OutputStream, len: Int): Unit = {
if (pos + len > _size) {
throw new IndexOutOfBoundsException(
s"Read of $len bytes at position $pos would go past size ${_size} of buffer")
}
var chunkIndex: Int = (pos >> chunkSizeLog2).toInt
var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt
var written: Int = 0
while (written < len) {
val toRead: Int = math.min(len - written, chunkSize - posInChunk)
os.write(chunks(chunkIndex), posInChunk, toRead)
written += toRead
chunkIndex += 1
posInChunk = 0
}
}
/**
* Read bytes from this buffer into a byte array.
*
* @param pos Offset in the buffer to read from.
* @param bytes Byte array to read into.
* @param offs Offset in the byte array to read to.
* @param len Number of bytes to read.
*/
def read(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = {
if (pos + len > _size) {
throw new IndexOutOfBoundsException(
s"Read of $len bytes at position $pos would go past size of buffer")
}
var chunkIndex: Int = (pos >> chunkSizeLog2).toInt
var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt
var written: Int = 0
while (written < len) {
val toRead: Int = math.min(len - written, chunkSize - posInChunk)
System.arraycopy(chunks(chunkIndex), posInChunk, bytes, offs + written, toRead)
written += toRead
chunkIndex += 1
posInChunk = 0
}
}
/**
* Write bytes from a byte array into this buffer.
*
* @param pos Offset in the buffer to write to.
* @param bytes Byte array to write from.
* @param offs Offset in the byte array to write from.
* @param len Number of bytes to write.
*/
def write(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = {
if (pos > _size) {
throw new IndexOutOfBoundsException(
s"Write at position $pos starts after end of buffer ${_size}")
}
// Grow if needed
val endChunkIndex: Int = ((pos + len - 1) >> chunkSizeLog2).toInt
while (endChunkIndex >= chunks.length) {
chunks += new Array[Byte](chunkSize)
}
var chunkIndex: Int = (pos >> chunkSizeLog2).toInt
var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt
var written: Int = 0
while (written < len) {
val toWrite: Int = math.min(len - written, chunkSize - posInChunk)
System.arraycopy(bytes, offs + written, chunks(chunkIndex), posInChunk, toWrite)
written += toWrite
chunkIndex += 1
posInChunk = 0
}
_size = math.max(_size, pos + len)
}
/**
* Total size of buffer that can be written to without allocating additional memory.
*/
def capacity: Long = chunks.size.toLong * chunkSize
/**
* Size of the logical buffer.
*/
def size: Long = _size
}
/**
* Output stream that writes to a ChainedBuffer.
*/
private[spark] class ChainedBufferOutputStream(chainedBuffer: ChainedBuffer) extends OutputStream {
private var pos: Long = 0
override def write(b: Int): Unit = {
throw new UnsupportedOperationException()
}
override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = {
chainedBuffer.write(pos, bytes, offs, len)
pos += len
}
}


@ -29,7 +29,6 @@ import com.google.common.io.ByteStreams
import org.apache.spark._
import org.apache.spark.serializer._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter}
import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
/**
@ -69,8 +68,8 @@ import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
* At a high level, this class works internally as follows:
*
* - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if
* we want to combine by key, or a PartitionedSerializedPairBuffer or PartitionedPairBuffer if we
* don't. Inside these buffers, we sort elements by partition ID and then possibly also by key.
* we want to combine by key, or a PartitionedPairBuffer if we don't.
* Inside these buffers, we sort elements by partition ID and then possibly also by key.
* To avoid calling the partitioner multiple times with each key, we store the partition ID
* alongside each record.
*
@ -93,8 +92,7 @@ private[spark] class ExternalSorter[K, V, C](
ordering: Option[Ordering[K]] = None,
serializer: Option[Serializer] = None)
extends Logging
with Spillable[WritablePartitionedPairCollection[K, C]]
with SortShuffleFileWriter[K, V] {
with Spillable[WritablePartitionedPairCollection[K, C]] {
private val conf = SparkEnv.get.conf
@ -104,13 +102,6 @@ private[spark] class ExternalSorter[K, V, C](
if (shouldPartition) partitioner.get.getPartition(key) else 0
}
// Since SPARK-7855, bypassMergeSort optimization is no longer performed as part of this class.
// As a sanity check, make sure that we're not handling a shuffle which should use that path.
if (SortShuffleWriter.shouldBypassMergeSort(conf, numPartitions, aggregator, ordering)) {
throw new IllegalArgumentException("ExternalSorter should not be used to handle "
+ " a sort that the BypassMergeSortShuffleWriter should handle")
}
private val blockManager = SparkEnv.get.blockManager
private val diskBlockManager = blockManager.diskBlockManager
private val ser = Serializer.getSerializer(serializer)
@ -128,23 +119,11 @@ private[spark] class ExternalSorter[K, V, C](
// grow internal data structures by growing + copying every time the number of objects doubles.
private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000)
private val useSerializedPairBuffer =
ordering.isEmpty &&
conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) &&
ser.supportsRelocationOfSerializedObjects
private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB
private def newBuffer(): WritablePartitionedPairCollection[K, C] with SizeTracker = {
if (useSerializedPairBuffer) {
new PartitionedSerializedPairBuffer(metaInitialRecords = 256, kvChunkSize, serInstance)
} else {
new PartitionedPairBuffer[K, C]
}
}
// Data structures to store in-memory objects before we spill. Depending on whether we have an
// Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
// store them in an array buffer.
private var map = new PartitionedAppendOnlyMap[K, C]
private var buffer = newBuffer()
private var buffer = new PartitionedPairBuffer[K, C]
// Total spilling statistics
private var _diskBytesSpilled = 0L
@ -192,7 +171,7 @@ private[spark] class ExternalSorter[K, V, C](
*/
private[spark] def numSpills: Int = spills.size
override def insertAll(records: Iterator[Product2[K, V]]): Unit = {
def insertAll(records: Iterator[Product2[K, V]]): Unit = {
// TODO: stop combining if we find that the reduction factor isn't high
val shouldCombine = aggregator.isDefined
@ -236,7 +215,7 @@ private[spark] class ExternalSorter[K, V, C](
} else {
estimatedSize = buffer.estimateSize()
if (maybeSpill(buffer, estimatedSize)) {
buffer = newBuffer()
buffer = new PartitionedPairBuffer[K, C]
}
}
@ -659,7 +638,7 @@ private[spark] class ExternalSorter[K, V, C](
* @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
*/
override def writePartitionedFile(
def writePartitionedFile(
blockId: BlockId,
context: TaskContext,
outputFile: File): Array[Long] = {


@ -1,273 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util.collection
import java.io.InputStream
import java.nio.IntBuffer
import java.util.Comparator
import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance}
import org.apache.spark.storage.DiskBlockObjectWriter
import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._
/**
* Append-only buffer of key-value pairs, each with a corresponding partition ID, that serializes
* its records upon insert and stores them as raw bytes.
*
* We use two data-structures to store the contents. The serialized records are stored in a
* ChainedBuffer that can expand gracefully as records are added. This buffer is accompanied by a
* metadata buffer that stores pointers into the data buffer as well as the partition ID of each
* record. Each entry in the metadata buffer takes up a fixed amount of space.
*
* Sorting the collection means swapping entries in the metadata buffer - the record buffer need not
* be modified at all. Storing the partition IDs in the metadata buffer means that comparisons can
* happen without following any pointers, which should minimize cache misses.
*
* Currently, only sorting by partition is supported.
*
* Each record is laid out inside the metaBuffer as follows. keyStart, a long, is split across
* two integers:
*
* +-------------+------------+------------+-------------+
* | keyStart | keyValLen | partitionId |
* +-------------+------------+------------+-------------+
*
* The buffer can support up to `536870911 (2 ^ 29 - 1)` records.
*
* @param metaInitialRecords The initial number of entries in the metadata buffer.
* @param kvBlockSize The size of each byte buffer in the ChainedBuffer used to store the records.
* @param serializerInstance the serializer used for serializing inserted records.
*/
private[spark] class PartitionedSerializedPairBuffer[K, V](
metaInitialRecords: Int,
kvBlockSize: Int,
serializerInstance: SerializerInstance)
extends WritablePartitionedPairCollection[K, V] with SizeTracker {
if (serializerInstance.isInstanceOf[JavaSerializerInstance]) {
throw new IllegalArgumentException("PartitionedSerializedPairBuffer does not support" +
" Java-serialized objects.")
}
require(metaInitialRecords <= MAXIMUM_RECORDS,
s"Can't make capacity bigger than ${MAXIMUM_RECORDS} records")
private var metaBuffer = IntBuffer.allocate(metaInitialRecords * RECORD_SIZE)
private val kvBuffer: ChainedBuffer = new ChainedBuffer(kvBlockSize)
private val kvOutputStream = new ChainedBufferOutputStream(kvBuffer)
private val kvSerializationStream = serializerInstance.serializeStream(kvOutputStream)
def insert(partition: Int, key: K, value: V): Unit = {
if (metaBuffer.position == metaBuffer.capacity) {
growMetaBuffer()
}
val keyStart = kvBuffer.size
kvSerializationStream.writeKey[Any](key)
kvSerializationStream.writeValue[Any](value)
kvSerializationStream.flush()
val keyValLen = (kvBuffer.size - keyStart).toInt
// keyStart, a long, gets split across two ints
metaBuffer.put(keyStart.toInt)
metaBuffer.put((keyStart >> 32).toInt)
metaBuffer.put(keyValLen)
metaBuffer.put(partition)
}
/** Double the size of the array because we've reached capacity */
private def growMetaBuffer(): Unit = {
if (metaBuffer.capacity >= MAXIMUM_META_BUFFER_CAPACITY) {
throw new IllegalStateException(s"Can't insert more than ${MAXIMUM_RECORDS} records")
}
val newCapacity =
if (metaBuffer.capacity * 2 < 0 || metaBuffer.capacity * 2 > MAXIMUM_META_BUFFER_CAPACITY) {
// Overflow
MAXIMUM_META_BUFFER_CAPACITY
} else {
metaBuffer.capacity * 2
}
val newMetaBuffer = IntBuffer.allocate(newCapacity)
newMetaBuffer.put(metaBuffer.array)
metaBuffer = newMetaBuffer
}
/** Iterate through the data in a given order. For this class this is not really destructive. */
override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
: Iterator[((Int, K), V)] = {
sort(keyComparator)
val is = orderedInputStream
val deserStream = serializerInstance.deserializeStream(is)
new Iterator[((Int, K), V)] {
var metaBufferPos = 0
def hasNext: Boolean = metaBufferPos < metaBuffer.position
def next(): ((Int, K), V) = {
val key = deserStream.readKey[Any]().asInstanceOf[K]
val value = deserStream.readValue[Any]().asInstanceOf[V]
val partition = metaBuffer.get(metaBufferPos + PARTITION)
metaBufferPos += RECORD_SIZE
((partition, key), value)
}
}
}
override def estimateSize: Long = metaBuffer.capacity * 4L + kvBuffer.capacity
override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
: WritablePartitionedIterator = {
sort(keyComparator)
new WritablePartitionedIterator {
// current position in the meta buffer in ints
var pos = 0
def writeNext(writer: DiskBlockObjectWriter): Unit = {
val keyStart = getKeyStartPos(metaBuffer, pos)
val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN)
pos += RECORD_SIZE
kvBuffer.read(keyStart, writer, keyValLen)
writer.recordWritten()
}
def nextPartition(): Int = metaBuffer.get(pos + PARTITION)
def hasNext(): Boolean = pos < metaBuffer.position
}
}
// Visible for testing
def orderedInputStream: OrderedInputStream = {
new OrderedInputStream(metaBuffer, kvBuffer)
}
private def sort(keyComparator: Option[Comparator[K]]): Unit = {
val comparator = if (keyComparator.isEmpty) {
new Comparator[Int]() {
def compare(partition1: Int, partition2: Int): Int = {
partition1 - partition2
}
}
} else {
throw new UnsupportedOperationException()
}
val sorter = new Sorter(new SerializedSortDataFormat)
sorter.sort(metaBuffer, 0, metaBuffer.position / RECORD_SIZE, comparator)
}
}
private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: ChainedBuffer)
extends InputStream {
import PartitionedSerializedPairBuffer._
private var metaBufferPos = 0
private var kvBufferPos =
if (metaBuffer.position > 0) getKeyStartPos(metaBuffer, metaBufferPos) else 0
override def read(bytes: Array[Byte]): Int = read(bytes, 0, bytes.length)
override def read(bytes: Array[Byte], offs: Int, len: Int): Int = {
if (metaBufferPos >= metaBuffer.position) {
return -1
}
val bytesRemainingInRecord = (metaBuffer.get(metaBufferPos + KEY_VAL_LEN) -
(kvBufferPos - getKeyStartPos(metaBuffer, metaBufferPos))).toInt
val toRead = math.min(bytesRemainingInRecord, len)
kvBuffer.read(kvBufferPos, bytes, offs, toRead)
if (toRead == bytesRemainingInRecord) {
metaBufferPos += RECORD_SIZE
if (metaBufferPos < metaBuffer.position) {
kvBufferPos = getKeyStartPos(metaBuffer, metaBufferPos)
}
} else {
kvBufferPos += toRead
}
toRead
}
override def read(): Int = {
throw new UnsupportedOperationException()
}
}
private[spark] class SerializedSortDataFormat extends SortDataFormat[Int, IntBuffer] {
private val META_BUFFER_TMP = new Array[Int](RECORD_SIZE)
/** Return the sort key for the element at the given index. */
override protected def getKey(metaBuffer: IntBuffer, pos: Int): Int = {
metaBuffer.get(pos * RECORD_SIZE + PARTITION)
}
/** Swap two elements. */
override def swap(metaBuffer: IntBuffer, pos0: Int, pos1: Int): Unit = {
val iOff = pos0 * RECORD_SIZE
val jOff = pos1 * RECORD_SIZE
System.arraycopy(metaBuffer.array, iOff, META_BUFFER_TMP, 0, RECORD_SIZE)
System.arraycopy(metaBuffer.array, jOff, metaBuffer.array, iOff, RECORD_SIZE)
System.arraycopy(META_BUFFER_TMP, 0, metaBuffer.array, jOff, RECORD_SIZE)
}
/** Copy a single element from src(srcPos) to dst(dstPos). */
override def copyElement(
src: IntBuffer,
srcPos: Int,
dst: IntBuffer,
dstPos: Int): Unit = {
val srcOff = srcPos * RECORD_SIZE
val dstOff = dstPos * RECORD_SIZE
System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE)
}
/**
* Copy a range of elements starting at src(srcPos) to dst, starting at dstPos.
* Overlapping ranges are allowed.
*/
override def copyRange(
src: IntBuffer,
srcPos: Int,
dst: IntBuffer,
dstPos: Int,
length: Int): Unit = {
val srcOff = srcPos * RECORD_SIZE
val dstOff = dstPos * RECORD_SIZE
System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE * length)
}
/**
* Allocates a Buffer that can hold up to 'length' elements.
* All elements of the buffer should be considered invalid until data is explicitly copied in.
*/
override def allocate(length: Int): IntBuffer = {
IntBuffer.allocate(length * RECORD_SIZE)
}
}
private object PartitionedSerializedPairBuffer {
val KEY_START = 0 // keyStart, a long, gets split across two ints
val KEY_VAL_LEN = 2
val PARTITION = 3
val RECORD_SIZE = PARTITION + 1 // num ints of metadata
val MAXIMUM_RECORDS = Int.MaxValue / RECORD_SIZE // (2 ^ 29) - 1
val MAXIMUM_META_BUFFER_CAPACITY = MAXIMUM_RECORDS * RECORD_SIZE // (2 ^ 31) - 4
def getKeyStartPos(metaBuffer: IntBuffer, metaBufferPos: Int): Long = {
val lower32 = metaBuffer.get(metaBufferPos + KEY_START)
val upper32 = metaBuffer.get(metaBufferPos + KEY_START + 1)
(upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL)
}
}
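A note on the metadata layout above: each record's key start offset is a long split across two consecutive ints in the meta buffer, which is what getKeyStartPos reassembles. A minimal standalone sketch of that packing and unpacking, using illustrative names and assuming nothing beyond plain Scala:

object KeyStartPacking {
  // Split a non-negative long offset into (lower 32 bits, upper 32 bits),
  // mirroring how keyStart is stored in two consecutive ints of the meta buffer.
  def split(offset: Long): (Int, Int) =
    (offset.toInt, (offset >>> 32).toInt)

  // Recombine the two ints into the original long, masking the lower half so
  // that its sign bit is not propagated by the widening conversion.
  def combine(lower32: Int, upper32: Int): Long =
    (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL)

  def main(args: Array[String]): Unit = {
    val offset = (1L << 33) + 42L        // deliberately larger than Int.MaxValue
    val (lo, hi) = split(offset)
    assert(combine(lo, hi) == offset)    // round-trips exactly
    println(s"offset=$offset lo=$lo hi=$hi recombined=${combine(lo, hi)}")
  }
}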

View file

@ -15,8 +15,9 @@
* limitations under the License.
*/
package org.apache.spark.shuffle.unsafe;
package org.apache.spark.shuffle.sort;
import org.apache.spark.shuffle.sort.PackedRecordPointer;
import org.junit.Test;
import static org.junit.Assert.*;
@ -24,7 +25,7 @@ import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
import org.apache.spark.unsafe.memory.MemoryAllocator;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
import static org.apache.spark.shuffle.unsafe.PackedRecordPointer.*;
import static org.apache.spark.shuffle.sort.PackedRecordPointer.*;
public class PackedRecordPointerSuite {

View file

@ -15,7 +15,7 @@
* limitations under the License.
*/
package org.apache.spark.shuffle.unsafe;
package org.apache.spark.shuffle.sort;
import java.util.Arrays;
import java.util.Random;
@ -30,7 +30,7 @@ import org.apache.spark.unsafe.memory.MemoryAllocator;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
public class UnsafeShuffleInMemorySorterSuite {
public class ShuffleInMemorySorterSuite {
private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) {
final byte[] strBytes = new byte[strLength];
@ -40,8 +40,8 @@ public class UnsafeShuffleInMemorySorterSuite {
@Test
public void testSortingEmptyInput() {
final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(100);
final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(100);
final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
assert(!iter.hasNext());
}
@ -62,7 +62,7 @@ public class UnsafeShuffleInMemorySorterSuite {
new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
final MemoryBlock dataPage = memoryManager.allocatePage(2048);
final Object baseObject = dataPage.getBaseObject();
final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4);
final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);
final HashPartitioner hashPartitioner = new HashPartitioner(4);
// Write the records into the data page and store pointers into the sorter
@ -79,7 +79,7 @@ public class UnsafeShuffleInMemorySorterSuite {
}
// Sort the records
final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
int prevPartitionId = -1;
Arrays.sort(dataToSort);
for (int i = 0; i < dataToSort.length; i++) {
@ -103,7 +103,7 @@ public class UnsafeShuffleInMemorySorterSuite {
@Test
public void testSortingManyNumbers() throws Exception {
UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4);
ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);
int[] numbersToSort = new int[128000];
Random random = new Random(16);
for (int i = 0; i < numbersToSort.length; i++) {
@ -112,7 +112,7 @@ public class UnsafeShuffleInMemorySorterSuite {
}
Arrays.sort(numbersToSort);
int[] sorterResult = new int[numbersToSort.length];
UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
int j = 0;
while (iter.hasNext()) {
iter.loadNext();

View file

@ -15,7 +15,7 @@
* limitations under the License.
*/
package org.apache.spark.shuffle.unsafe;
package org.apache.spark.shuffle.sort;
import java.io.*;
import java.nio.ByteBuffer;
@ -23,7 +23,6 @@ import java.util.*;
import scala.*;
import scala.collection.Iterator;
import scala.reflect.ClassTag;
import scala.runtime.AbstractFunction1;
import com.google.common.collect.Iterators;
@ -56,6 +55,7 @@ import org.apache.spark.serializer.*;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.shuffle.sort.SerializedShuffleHandle;
import org.apache.spark.storage.*;
import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
import org.apache.spark.unsafe.memory.MemoryAllocator;
@ -204,7 +204,7 @@ public class UnsafeShuffleWriterSuite {
shuffleBlockResolver,
taskMemoryManager,
shuffleMemoryManager,
new UnsafeShuffleHandle<Object, Object>(0, 1, shuffleDep),
new SerializedShuffleHandle<Object, Object>(0, 1, shuffleDep),
0, // map id
taskContext,
conf
@ -461,7 +461,7 @@ public class UnsafeShuffleWriterSuite {
final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
final ArrayList<Product2<Object, Object>> dataToWrite =
new ArrayList<Product2<Object, Object>>();
final byte[] bytes = new byte[(int) (UnsafeShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)];
final byte[] bytes = new byte[(int) (ShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)];
new Random(42).nextBytes(bytes);
dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(bytes)));
writer.write(dataToWrite.iterator());
@ -516,7 +516,7 @@ public class UnsafeShuffleWriterSuite {
shuffleBlockResolver,
taskMemoryManager,
shuffleMemoryManager,
new UnsafeShuffleHandle<>(0, 1, shuffleDep),
new SerializedShuffleHandle<>(0, 1, shuffleDep),
0, // map id
taskContext,
conf);

View file

@ -17,13 +17,78 @@
package org.apache.spark
import java.io.File
import scala.collection.JavaConverters._
import org.apache.commons.io.FileUtils
import org.apache.commons.io.filefilter.TrueFileFilter
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.rdd.ShuffledRDD
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.util.Utils
class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
// This test suite should run all tests in ShuffleSuite with sort-based shuffle.
private var tempDir: File = _
override def beforeAll() {
conf.set("spark.shuffle.manager", "sort")
}
override def beforeEach(): Unit = {
tempDir = Utils.createTempDir()
conf.set("spark.local.dir", tempDir.getAbsolutePath)
}
override def afterEach(): Unit = {
try {
Utils.deleteRecursively(tempDir)
} finally {
super.afterEach()
}
}
test("SortShuffleManager properly cleans up files for shuffles that use the serialized path") {
sc = new SparkContext("local", "test", conf)
// Create a shuffled RDD and verify that it actually uses the new serialized map output path
val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
.setSerializer(new KryoSerializer(conf))
val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
assert(SortShuffleManager.canUseSerializedShuffle(shuffleDep))
ensureFilesAreCleanedUp(shuffledRdd)
}
test("SortShuffleManager properly cleans up files for shuffles that use the deserialized path") {
sc = new SparkContext("local", "test", conf)
// Create a shuffled RDD and verify that it actually uses the old deserialized map output path
val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
.setSerializer(new JavaSerializer(conf))
val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
assert(!SortShuffleManager.canUseSerializedShuffle(shuffleDep))
ensureFilesAreCleanedUp(shuffledRdd)
}
private def ensureFilesAreCleanedUp(shuffledRdd: ShuffledRDD[_, _, _]): Unit = {
def getAllFiles: Set[File] =
FileUtils.listFiles(tempDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
val filesBeforeShuffle = getAllFiles
// Force the shuffle to be performed
shuffledRdd.count()
// Ensure that the shuffle actually created files that will need to be cleaned up
val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
val expectedShuffleFileNames = Set("shuffle_0_0_0.data", "shuffle_0_0_0.index")
filesCreatedByShuffle.map(_.getName) should be (expectedShuffleFileNames)
// Check that the cleanup actually removes the files
sc.env.blockManager.master.removeShuffle(0, blocking = true)
for (file <- filesCreatedByShuffle) {
assert (!file.exists(), s"Shuffle file $file was not cleaned up")
}
}
}
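The two new tests above hinge on SortShuffleManager.canUseSerializedShuffle: Kryo supports relocation of serialized objects, so it takes the serialized path, while the Java serializer does not. As a rough, self-contained paraphrase of the criteria visible in this commit (relocation-friendly serializer, no aggregator, and at most the serialized-mode partition limit), with field names that are illustrative rather than Spark's own:

// Simplified stand-in for the fields of ShuffleDependency that the check inspects.
case class DepInfo(
    serializerSupportsRelocation: Boolean,
    hasAggregator: Boolean,
    numPartitions: Int)

object SerializedShuffleCheck {
  // 16,777,216: the partition cap of the serialized sort path (24-bit partition ids).
  val MaxSerializedModePartitions: Int = 1 << 24

  def canUseSerializedShuffle(dep: DepInfo): Boolean =
    dep.serializerSupportsRelocation &&
      !dep.hasAggregator &&
      dep.numPartitions <= MaxSerializedModePartitions

  def main(args: Array[String]): Unit = {
    println(canUseSerializedShuffle(DepInfo(true, false, 4)))              // true, Kryo-like
    println(canUseSerializedShuffle(DepInfo(false, false, 4)))             // false, Java-serializer-like
    println(canUseSerializedShuffle(DepInfo(true, false, (1 << 24) + 1)))  // false, too many partitions
  }
}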

View file

@ -1062,10 +1062,10 @@ class DAGSchedulerSuite
*/
test("don't submit stage until its dependencies map outputs are registered (SPARK-5259)") {
val firstRDD = new MyRDD(sc, 3, Nil)
val firstShuffleDep = new ShuffleDependency(firstRDD, null)
val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2))
val firstShuffleId = firstShuffleDep.shuffleId
val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep))
val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
val reduceRdd = new MyRDD(sc, 1, List(shuffleDep))
submit(reduceRdd, Array(0))
@ -1175,7 +1175,7 @@ class DAGSchedulerSuite
*/
test("register map outputs correctly after ExecutorLost and task Resubmitted") {
val firstRDD = new MyRDD(sc, 3, Nil)
val firstShuffleDep = new ShuffleDependency(firstRDD, null)
val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2))
val reduceRdd = new MyRDD(sc, 5, List(firstShuffleDep))
submit(reduceRdd, Array(0))

View file
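For context on the DAGSchedulerSuite change above: after consolidation the shuffle manager inspects the dependency's partitioner while registering a shuffle, so a null partitioner is likely no longer tolerated in these fixtures. Roughly, the three-way handle selection introduced by this commit behaves like the following sketch (simplified names and signatures, not a verbatim copy of registerShuffle):

object HandleChoiceSketch {
  // Sketch of the handle selection performed by the merged SortShuffleManager.
  def chooseHandle(
      numPartitions: Int,              // read from dependency.partitioner.numPartitions
      mapSideCombine: Boolean,
      serializedShuffleSupported: Boolean,
      bypassMergeThreshold: Int = 200): String = {
    if (!mapSideCombine && numPartitions <= bypassMergeThreshold) {
      "BypassMergeSortShuffleHandle"   // one file per reduce partition, then concatenated
    } else if (serializedShuffleSupported) {
      "SerializedShuffleHandle"        // the path formerly owned by UnsafeShuffleManager
    } else {
      "BaseShuffleHandle"              // deserialized ExternalSorter path
    }
  }

  def main(args: Array[String]): Unit = {
    println(chooseHandle(numPartitions = 2, mapSideCombine = false, serializedShuffleSupported = true))
    println(chooseHandle(numPartitions = 5000, mapSideCombine = false, serializedShuffleSupported = true))
    println(chooseHandle(numPartitions = 5000, mapSideCombine = true, serializedShuffleSupported = false))
  }
}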

@ -33,7 +33,8 @@ import org.scalatest.BeforeAndAfterEach
import org.apache.spark._
import org.apache.spark.executor.{TaskMetrics, ShuffleWriteMetrics}
import org.apache.spark.serializer.{SerializerInstance, Serializer, JavaSerializer}
import org.apache.spark.shuffle.IndexShuffleBlockResolver
import org.apache.spark.serializer.{JavaSerializer, SerializerInstance}
import org.apache.spark.storage._
import org.apache.spark.util.Utils
@ -42,25 +43,31 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
@Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _
@Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _
@Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _
@Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _
@Mock(answer = RETURNS_SMART_NULLS) private var dependency: ShuffleDependency[Int, Int, Int] = _
private var taskMetrics: TaskMetrics = _
private var shuffleWriteMetrics: ShuffleWriteMetrics = _
private var tempDir: File = _
private var outputFile: File = _
private val conf: SparkConf = new SparkConf(loadDefaults = false)
private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]()
private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File]
private val shuffleBlockId: ShuffleBlockId = new ShuffleBlockId(0, 0, 0)
private val serializer: Serializer = new JavaSerializer(conf)
private var shuffleHandle: BypassMergeSortShuffleHandle[Int, Int] = _
override def beforeEach(): Unit = {
tempDir = Utils.createTempDir()
outputFile = File.createTempFile("shuffle", null, tempDir)
shuffleWriteMetrics = new ShuffleWriteMetrics
taskMetrics = new TaskMetrics
taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics)
MockitoAnnotations.initMocks(this)
shuffleHandle = new BypassMergeSortShuffleHandle[Int, Int](
shuffleId = 0,
numMaps = 2,
dependency = dependency
)
when(dependency.partitioner).thenReturn(new HashPartitioner(7))
when(dependency.serializer).thenReturn(Some(new JavaSerializer(conf)))
when(taskContext.taskMetrics()).thenReturn(taskMetrics)
when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
when(blockManager.getDiskWriter(
any[BlockId],
@ -107,18 +114,20 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
test("write empty iterator") {
val writer = new BypassMergeSortShuffleWriter[Int, Int](
new SparkConf(loadDefaults = false),
blockManager,
new HashPartitioner(7),
shuffleWriteMetrics,
serializer
blockResolver,
shuffleHandle,
0, // MapId
taskContext,
conf
)
writer.insertAll(Iterator.empty)
val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile)
assert(partitionLengths.sum === 0)
writer.write(Iterator.empty)
writer.stop( /* success = */ true)
assert(writer.getPartitionLengths.sum === 0)
assert(outputFile.exists())
assert(outputFile.length() === 0)
assert(temporaryFilesCreated.isEmpty)
val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get
assert(shuffleWriteMetrics.shuffleBytesWritten === 0)
assert(shuffleWriteMetrics.shuffleRecordsWritten === 0)
assert(taskMetrics.diskBytesSpilled === 0)
@ -129,17 +138,19 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
def records: Iterator[(Int, Int)] =
Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
val writer = new BypassMergeSortShuffleWriter[Int, Int](
new SparkConf(loadDefaults = false),
blockManager,
new HashPartitioner(7),
shuffleWriteMetrics,
serializer
blockResolver,
shuffleHandle,
0, // MapId
taskContext,
conf
)
writer.insertAll(records)
writer.write(records)
writer.stop( /* success = */ true)
assert(temporaryFilesCreated.nonEmpty)
val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile)
assert(partitionLengths.sum === outputFile.length())
assert(writer.getPartitionLengths.sum === outputFile.length())
assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted
val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get
assert(shuffleWriteMetrics.shuffleBytesWritten === outputFile.length())
assert(shuffleWriteMetrics.shuffleRecordsWritten === records.length)
assert(taskMetrics.diskBytesSpilled === 0)
@ -148,14 +159,15 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
test("cleanup of intermediate files after errors") {
val writer = new BypassMergeSortShuffleWriter[Int, Int](
new SparkConf(loadDefaults = false),
blockManager,
new HashPartitioner(7),
shuffleWriteMetrics,
serializer
blockResolver,
shuffleHandle,
0, // MapId
taskContext,
conf
)
intercept[SparkException] {
writer.insertAll((0 until 100000).iterator.map(i => {
writer.write((0 until 100000).iterator.map(i => {
if (i == 99990) {
throw new SparkException("Intentional failure")
}
@ -163,7 +175,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
}))
}
assert(temporaryFilesCreated.nonEmpty)
writer.stop()
writer.stop( /* success = */ false)
assert(temporaryFilesCreated.count(_.exists()) === 0)
}
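The rewritten tests above also illustrate what the bypass-merge strategy does: stream every record into one temporary file per reduce partition, then concatenate those files into the single map output while recording each partition's length. A deliberately simplified, self-contained model of that idea (the real writer additionally goes through DiskBlockObjectWriter, serializers, and shuffle write metrics):

import java.io.{File, FileInputStream, FileOutputStream}

object BypassMergeSketch {
  def writePartitioned(
      records: Iterator[(Int, String)],   // (key, value); keys assumed non-negative
      numPartitions: Int,
      output: File): Array[Long] = {
    val temps = Array.fill(numPartitions)(File.createTempFile("bypass-part", ".tmp"))
    val outs = temps.map(new FileOutputStream(_))
    records.foreach { case (k, v) =>
      val p = k % numPartitions
      outs(p).write(s"$k=$v\n".getBytes("UTF-8"))
    }
    outs.foreach(_.close())

    val lengths = new Array[Long](numPartitions)
    val merged = new FileOutputStream(output)
    try {
      for (p <- 0 until numPartitions) {
        val in = new FileInputStream(temps(p))
        try {
          val buf = new Array[Byte](8192)
          var n = in.read(buf)
          while (n != -1) { merged.write(buf, 0, n); lengths(p) += n; n = in.read(buf) }
        } finally in.close()
        temps(p).delete()                 // temporary files are cleaned up after the merge
      }
    } finally merged.close()
    lengths
  }

  def main(args: Array[String]): Unit = {
    val out = File.createTempFile("bypass-merged", ".data")
    val lengths = writePartitioned(Iterator((1, "a"), (5, "b"), (2, "c")), numPartitions = 4, out)
    println(lengths.toVector)             // per-partition byte counts, summing to out.length()
    assert(lengths.sum == out.length())
    out.delete()
  }
}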

View file

@ -15,7 +15,7 @@
* limitations under the License.
*/
package org.apache.spark.shuffle.unsafe
package org.apache.spark.shuffle.sort
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
@ -29,9 +29,9 @@ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer}
* Tests for the fallback logic in UnsafeShuffleManager. Actual tests of shuffling data are
* performed in other suites.
*/
class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
class SortShuffleManagerSuite extends SparkFunSuite with Matchers {
import UnsafeShuffleManager.canUseUnsafeShuffle
import SortShuffleManager.canUseSerializedShuffle
private class RuntimeExceptionAnswer extends Answer[Object] {
override def answer(invocation: InvocationOnMock): Object = {
@ -55,10 +55,10 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
dep
}
test("supported shuffle dependencies") {
test("supported shuffle dependencies for serialized shuffle") {
val kryo = Some(new KryoSerializer(new SparkConf()))
assert(canUseUnsafeShuffle(shuffleDep(
assert(canUseSerializedShuffle(shuffleDep(
partitioner = new HashPartitioner(2),
serializer = kryo,
keyOrdering = None,
@ -68,7 +68,7 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
val rangePartitioner = mock(classOf[RangePartitioner[Any, Any]])
when(rangePartitioner.numPartitions).thenReturn(2)
assert(canUseUnsafeShuffle(shuffleDep(
assert(canUseSerializedShuffle(shuffleDep(
partitioner = rangePartitioner,
serializer = kryo,
keyOrdering = None,
@ -77,7 +77,7 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
)))
// Shuffles with key orderings are supported as long as no aggregator is specified
assert(canUseUnsafeShuffle(shuffleDep(
assert(canUseSerializedShuffle(shuffleDep(
partitioner = new HashPartitioner(2),
serializer = kryo,
keyOrdering = Some(mock(classOf[Ordering[Any]])),
@ -87,12 +87,12 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
}
test("unsupported shuffle dependencies") {
test("unsupported shuffle dependencies for serialized shuffle") {
val kryo = Some(new KryoSerializer(new SparkConf()))
val java = Some(new JavaSerializer(new SparkConf()))
// We only support serializers that support object relocation
assert(!canUseUnsafeShuffle(shuffleDep(
assert(!canUseSerializedShuffle(shuffleDep(
partitioner = new HashPartitioner(2),
serializer = java,
keyOrdering = None,
@ -100,9 +100,11 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
mapSideCombine = false
)))
// We do not support shuffles with more than 16 million output partitions
assert(!canUseUnsafeShuffle(shuffleDep(
partitioner = new HashPartitioner(UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS + 1),
// The serialized shuffle path does not support shuffles with more than 16 million output
// partitions, due to a limitation in its sorter implementation.
assert(!canUseSerializedShuffle(shuffleDep(
partitioner = new HashPartitioner(
SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE + 1),
serializer = kryo,
keyOrdering = None,
aggregator = None,
@ -110,14 +112,14 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
)))
// We do not support shuffles that perform aggregation
assert(!canUseUnsafeShuffle(shuffleDep(
assert(!canUseSerializedShuffle(shuffleDep(
partitioner = new HashPartitioner(2),
serializer = kryo,
keyOrdering = None,
aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])),
mapSideCombine = false
)))
assert(!canUseUnsafeShuffle(shuffleDep(
assert(!canUseSerializedShuffle(shuffleDep(
partitioner = new HashPartitioner(2),
serializer = kryo,
keyOrdering = Some(mock(classOf[Ordering[Any]])),

View file
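The 16-million-partition ceiling mentioned in the suite above comes from the packed record pointers used by the serialized sort path: the partition id has to fit in 24 bits of a 64-bit word (PackedRecordPointer describes a 24-bit partition id, 13-bit page number, 27-bit page offset split). A standalone sketch of that style of packing, illustrative rather than a drop-in replacement:

object PackedPointerSketch {
  // Assumed field widths: 24-bit partition id, 13-bit page number,
  // 27-bit offset within the page (24 + 13 + 27 = 64).
  val PartitionBits = 24
  val PageBits = 13
  val OffsetBits = 27

  val MaxPartitionId: Int = (1 << PartitionBits) - 1   // 16,777,215

  def pack(partitionId: Int, pageNumber: Int, offsetInPage: Long): Long = {
    require(partitionId >= 0 && partitionId <= MaxPartitionId, "partition id does not fit in 24 bits")
    (partitionId.toLong << (PageBits + OffsetBits)) |
      (pageNumber.toLong << OffsetBits) |
      (offsetInPage & ((1L << OffsetBits) - 1))
  }

  def partitionId(packed: Long): Int = (packed >>> (PageBits + OffsetBits)).toInt

  def main(args: Array[String]): Unit = {
    val packed = pack(partitionId = 12345, pageNumber = 7, offsetInPage = 1024L)
    println(partitionId(packed))          // 12345
    println(MaxPartitionId + 1)           // 16777216, the serialized-mode partition cap
  }
}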

@ -1,45 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle.sort
import org.mockito.Mockito._
import org.apache.spark.{Aggregator, SparkConf, SparkFunSuite}
class SortShuffleWriterSuite extends SparkFunSuite {
import SortShuffleWriter._
test("conditions for bypassing merge-sort") {
val conf = new SparkConf(loadDefaults = false)
val agg = mock(classOf[Aggregator[_, _, _]], RETURNS_SMART_NULLS)
val ord = implicitly[Ordering[Int]]
// Numbers of partitions that are above and below the default bypassMergeThreshold
val FEW_PARTITIONS = 50
val MANY_PARTITIONS = 10000
// Shuffles with no ordering or aggregator: should bypass unless # of partitions is high
assert(shouldBypassMergeSort(conf, FEW_PARTITIONS, None, None))
assert(!shouldBypassMergeSort(conf, MANY_PARTITIONS, None, None))
// Shuffles with an ordering or aggregator: should not bypass even if they have few partitions
assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, None, Some(ord)))
assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, Some(agg), None))
}
}

View file

@ -1,102 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle.unsafe
import java.io.File
import scala.collection.JavaConverters._
import org.apache.commons.io.FileUtils
import org.apache.commons.io.filefilter.TrueFileFilter
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkContext, ShuffleSuite}
import org.apache.spark.rdd.ShuffledRDD
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.util.Utils
class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
// This test suite should run all tests in ShuffleSuite with unsafe-based shuffle.
override def beforeAll() {
conf.set("spark.shuffle.manager", "tungsten-sort")
}
test("UnsafeShuffleManager properly cleans up files for shuffles that use the new shuffle path") {
val tmpDir = Utils.createTempDir()
try {
val myConf = conf.clone()
.set("spark.local.dir", tmpDir.getAbsolutePath)
sc = new SparkContext("local", "test", myConf)
// Create a shuffled RDD and verify that it will actually use the new UnsafeShuffle path
val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
.setSerializer(new KryoSerializer(myConf))
val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
assert(UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
def getAllFiles: Set[File] =
FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
val filesBeforeShuffle = getAllFiles
// Force the shuffle to be performed
shuffledRdd.count()
// Ensure that the shuffle actually created files that will need to be cleaned up
val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
filesCreatedByShuffle.map(_.getName) should be
Set("shuffle_0_0_0.data", "shuffle_0_0_0.index")
// Check that the cleanup actually removes the files
sc.env.blockManager.master.removeShuffle(0, blocking = true)
for (file <- filesCreatedByShuffle) {
assert (!file.exists(), s"Shuffle file $file was not cleaned up")
}
} finally {
Utils.deleteRecursively(tmpDir)
}
}
test("UnsafeShuffleManager properly cleans up files for shuffles that use the old shuffle path") {
val tmpDir = Utils.createTempDir()
try {
val myConf = conf.clone()
.set("spark.local.dir", tmpDir.getAbsolutePath)
sc = new SparkContext("local", "test", myConf)
// Create a shuffled RDD and verify that it will actually use the old SortShuffle path
val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
.setSerializer(new JavaSerializer(myConf))
val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
assert(!UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
def getAllFiles: Set[File] =
FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
val filesBeforeShuffle = getAllFiles
// Force the shuffle to be performed
shuffledRdd.count()
// Ensure that the shuffle actually created files that will need to be cleaned up
val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
filesCreatedByShuffle.map(_.getName) should be
Set("shuffle_0_0_0.data", "shuffle_0_0_0.index")
// Check that the cleanup actually removes the files
sc.env.blockManager.master.removeShuffle(0, blocking = true)
for (file <- filesCreatedByShuffle) {
assert (!file.exists(), s"Shuffle file $file was not cleaned up")
}
} finally {
Utils.deleteRecursively(tmpDir)
}
}
}

View file

@ -1,144 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util.collection
import java.nio.ByteBuffer
import org.scalatest.Matchers._
import org.apache.spark.SparkFunSuite
class ChainedBufferSuite extends SparkFunSuite {
test("write and read at start") {
// write from start of source array
val buffer = new ChainedBuffer(8)
buffer.capacity should be (0)
verifyWriteAndRead(buffer, 0, 0, 0, 4)
buffer.capacity should be (8)
// write from middle of source array
verifyWriteAndRead(buffer, 0, 5, 0, 4)
buffer.capacity should be (8)
// read to middle of target array
verifyWriteAndRead(buffer, 0, 0, 5, 4)
buffer.capacity should be (8)
// write up to border
verifyWriteAndRead(buffer, 0, 0, 0, 8)
buffer.capacity should be (8)
// expand into second buffer
verifyWriteAndRead(buffer, 0, 0, 0, 12)
buffer.capacity should be (16)
// expand into multiple buffers
verifyWriteAndRead(buffer, 0, 0, 0, 28)
buffer.capacity should be (32)
}
test("write and read at middle") {
val buffer = new ChainedBuffer(8)
// fill to a middle point
verifyWriteAndRead(buffer, 0, 0, 0, 3)
// write from start of source array
verifyWriteAndRead(buffer, 3, 0, 0, 4)
buffer.capacity should be (8)
// write from middle of source array
verifyWriteAndRead(buffer, 3, 5, 0, 4)
buffer.capacity should be (8)
// read to middle of target array
verifyWriteAndRead(buffer, 3, 0, 5, 4)
buffer.capacity should be (8)
// write up to border
verifyWriteAndRead(buffer, 3, 0, 0, 5)
buffer.capacity should be (8)
// expand into second buffer
verifyWriteAndRead(buffer, 3, 0, 0, 12)
buffer.capacity should be (16)
// expand into multiple buffers
verifyWriteAndRead(buffer, 3, 0, 0, 28)
buffer.capacity should be (32)
}
test("write and read at later buffer") {
val buffer = new ChainedBuffer(8)
// fill to a middle point
verifyWriteAndRead(buffer, 0, 0, 0, 11)
// write from start of source array
verifyWriteAndRead(buffer, 11, 0, 0, 4)
buffer.capacity should be (16)
// write from middle of source array
verifyWriteAndRead(buffer, 11, 5, 0, 4)
buffer.capacity should be (16)
// read to middle of target array
verifyWriteAndRead(buffer, 11, 0, 5, 4)
buffer.capacity should be (16)
// write up to border
verifyWriteAndRead(buffer, 11, 0, 0, 5)
buffer.capacity should be (16)
// expand into second buffer
verifyWriteAndRead(buffer, 11, 0, 0, 12)
buffer.capacity should be (24)
// expand into multiple buffers
verifyWriteAndRead(buffer, 11, 0, 0, 28)
buffer.capacity should be (40)
}
// Used to make sure we're writing different bytes each time
var rangeStart = 0
/**
* @param buffer The buffer to write to and read from.
* @param offsetInBuffer The offset to write to in the buffer.
* @param offsetInSource The offset in the array that the bytes are written from.
* @param offsetInTarget The offset in the array to read the bytes into.
* @param length The number of bytes to read and write
*/
def verifyWriteAndRead(
buffer: ChainedBuffer,
offsetInBuffer: Int,
offsetInSource: Int,
offsetInTarget: Int,
length: Int): Unit = {
val source = new Array[Byte](offsetInSource + length)
(rangeStart until rangeStart + length).map(_.toByte).copyToArray(source, offsetInSource)
buffer.write(offsetInBuffer, source, offsetInSource, length)
val target = new Array[Byte](offsetInTarget + length)
buffer.read(offsetInBuffer, target, offsetInTarget, length)
ByteBuffer.wrap(source, offsetInSource, length) should be
(ByteBuffer.wrap(target, offsetInTarget, length))
rangeStart += 100
}
}
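ChainedBuffer, removed together with this suite, is essentially one logical byte array backed by a chain of fixed-size chunks, which is why the capacities asserted above always land on multiples of the chunk size. A minimal re-implementation of the idea, written from the behaviour the tests exercise rather than from the removed source:

import scala.collection.mutable.ArrayBuffer

// Minimal chunked byte buffer: capacity grows one fixed-size chunk at a time,
// and reads and writes are split across chunk boundaries as needed.
class MiniChainedBuffer(chunkSize: Int) {
  private val chunks = ArrayBuffer.empty[Array[Byte]]

  def capacity: Long = chunks.length.toLong * chunkSize

  private def ensureCapacity(upTo: Long): Unit =
    while (capacity < upTo) chunks += new Array[Byte](chunkSize)

  def write(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = {
    ensureCapacity(pos + len)
    copy(pos, bytes, offs, len, toBuffer = true)
  }

  def read(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit =
    copy(pos, bytes, offs, len, toBuffer = false)

  private def copy(pos: Long, bytes: Array[Byte], offs: Int, len: Int, toBuffer: Boolean): Unit = {
    var remaining = len
    var bufPos = pos
    var arrPos = offs
    while (remaining > 0) {
      val chunk = chunks((bufPos / chunkSize).toInt)
      val chunkOffset = (bufPos % chunkSize).toInt
      val n = math.min(remaining, chunkSize - chunkOffset)
      if (toBuffer) System.arraycopy(bytes, arrPos, chunk, chunkOffset, n)
      else System.arraycopy(chunk, chunkOffset, bytes, arrPos, n)
      remaining -= n; bufPos += n; arrPos += n
    }
  }
}

object MiniChainedBufferDemo {
  def main(args: Array[String]): Unit = {
    val buf = new MiniChainedBuffer(8)
    buf.write(0, "hello chained buffer".getBytes("UTF-8"), 0, 20)
    println(buf.capacity)                 // 24: three 8-byte chunks
    val out = new Array[Byte](20)
    buf.read(0, out, 0, 20)
    println(new String(out, "UTF-8"))     // hello chained buffer
  }
}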

View file

@ -1,148 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util.collection
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import com.google.common.io.ByteStreams
import org.mockito.Matchers.any
import org.mockito.Mockito._
import org.mockito.Mockito.RETURNS_SMART_NULLS
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.Matchers._
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.storage.DiskBlockObjectWriter
class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
test("OrderedInputStream single record") {
val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
val struct = SomeStruct("something", 5)
buffer.insert(4, 10, struct)
val bytes = ByteStreams.toByteArray(buffer.orderedInputStream)
val baos = new ByteArrayOutputStream()
val stream = serializerInstance.serializeStream(baos)
stream.writeObject(10)
stream.writeObject(struct)
stream.close()
baos.toByteArray should be (bytes)
}
test("insert single record") {
val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
val struct = SomeStruct("something", 5)
buffer.insert(4, 10, struct)
val elements = buffer.partitionedDestructiveSortedIterator(None).toArray
elements.size should be (1)
elements.head should be (((4, 10), struct))
}
test("insert multiple records") {
val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
val struct1 = SomeStruct("something1", 8)
buffer.insert(6, 1, struct1)
val struct2 = SomeStruct("something2", 9)
buffer.insert(4, 2, struct2)
val struct3 = SomeStruct("something3", 10)
buffer.insert(5, 3, struct3)
val elements = buffer.partitionedDestructiveSortedIterator(None).toArray
elements.size should be (3)
elements(0) should be (((4, 2), struct2))
elements(1) should be (((5, 3), struct3))
elements(2) should be (((6, 1), struct1))
}
test("write single record") {
val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
val struct = SomeStruct("something", 5)
buffer.insert(4, 10, struct)
val it = buffer.destructiveSortedWritablePartitionedIterator(None)
val (writer, baos) = createMockWriter()
assert(it.hasNext)
it.nextPartition should be (4)
it.writeNext(writer)
assert(!it.hasNext)
val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray))
stream.readObject[AnyRef]() should be (10)
stream.readObject[AnyRef]() should be (struct)
}
test("write multiple records") {
val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
val struct1 = SomeStruct("something1", 8)
buffer.insert(6, 1, struct1)
val struct2 = SomeStruct("something2", 9)
buffer.insert(4, 2, struct2)
val struct3 = SomeStruct("something3", 10)
buffer.insert(5, 3, struct3)
val it = buffer.destructiveSortedWritablePartitionedIterator(None)
val (writer, baos) = createMockWriter()
assert(it.hasNext)
it.nextPartition should be (4)
it.writeNext(writer)
assert(it.hasNext)
it.nextPartition should be (5)
it.writeNext(writer)
assert(it.hasNext)
it.nextPartition should be (6)
it.writeNext(writer)
assert(!it.hasNext)
val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray))
val iter = stream.asIterator
iter.next() should be (2)
iter.next() should be (struct2)
iter.next() should be (3)
iter.next() should be (struct3)
iter.next() should be (1)
iter.next() should be (struct1)
assert(!iter.hasNext)
}
def createMockWriter(): (DiskBlockObjectWriter, ByteArrayOutputStream) = {
val writer = mock(classOf[DiskBlockObjectWriter], RETURNS_SMART_NULLS)
val baos = new ByteArrayOutputStream()
when(writer.write(any(), any(), any())).thenAnswer(new Answer[Unit] {
override def answer(invocationOnMock: InvocationOnMock): Unit = {
val args = invocationOnMock.getArguments
val bytes = args(0).asInstanceOf[Array[Byte]]
val offset = args(1).asInstanceOf[Int]
val length = args(2).asInstanceOf[Int]
baos.write(bytes, offset, length)
}
})
(writer, baos)
}
}
case class SomeStruct(str: String, num: Int)

View file

@ -437,12 +437,9 @@ Apart from these, the following properties are also available, and may be useful
<td><code>spark.shuffle.manager</code></td>
<td>sort</td>
<td>
Implementation to use for shuffling data. There are three implementations available:
<code>sort</code>, <code>hash</code> and the new (1.5+) <code>tungsten-sort</code>.
Implementation to use for shuffling data. There are two implementations available:
<code>sort</code> and <code>hash</code>.
Sort-based shuffle is more memory-efficient and is the default option starting in 1.2.
Tungsten-sort is similar to the sort based shuffle, with a direct binary cache-friendly
implementation with a fall back to regular sort based shuffle if its requirements are not
met.
</td>
</tr>
<tr>

View file
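Since tungsten-sort disappears from this table as a separate implementation, choosing a shuffle manager now comes down to the two documented values; setting the property looks like the sketch below. The note that "tungsten-sort" still resolves to the consolidated sort manager is an assumption about this commit's compatibility alias, not something shown in the hunk above.

import org.apache.spark.{SparkConf, SparkContext}

object ShuffleManagerConfigExample {
  def main(args: Array[String]): Unit = {
    // "sort" is the default; "hash" selects the hash-based implementation.
    // "tungsten-sort" is assumed to remain only as an alias for the sort manager.
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("shuffle-manager-config")
      .set("spark.shuffle.manager", "sort")
    val sc = new SparkContext(conf)
    try {
      val counts = sc.parallelize(1 to 1000, 4).map(x => (x % 10, 1)).reduceByKey(_ + _)
      println(counts.collect().toSeq.sortBy(_._1))
    } finally {
      sc.stop()
    }
  }
}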

@ -37,6 +37,7 @@ object MimaExcludes {
Seq(
MimaBuild.excludeSparkPackage("deploy"),
MimaBuild.excludeSparkPackage("network"),
MimaBuild.excludeSparkPackage("unsafe"),
// These are needed if checking against the sbt build, since they are part of
// the maven-generated artifacts in 1.3.
excludePackage("org.spark-project.jetty"),
@ -44,7 +45,11 @@ object MimaExcludes {
// SQL execution is considered private.
excludePackage("org.apache.spark.sql.execution"),
// SQL columnar is considered private.
excludePackage("org.apache.spark.sql.columnar")
excludePackage("org.apache.spark.sql.columnar"),
// The shuffle package is considered private.
excludePackage("org.apache.spark.shuffle"),
// The collection utilities are considered private.
excludePackage("org.apache.spark.util.collection")
) ++
MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++
MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++
@ -750,4 +755,4 @@ object MimaExcludes {
MimaBuild.excludeSparkClass("mllib.regression.LinearRegressionWithSGD")
case _ => Seq()
}
}
}

View file

@ -24,7 +24,6 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.shuffle.unsafe.UnsafeShuffleManager
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
@ -87,10 +86,8 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
// fewer partitions (like RangePartitioner, for example).
val conf = child.sqlContext.sparkContext.conf
val shuffleManager = SparkEnv.get.shuffleManager
val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] ||
shuffleManager.isInstanceOf[UnsafeShuffleManager]
val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager]
val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
val serializeMapOutputs = conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true)
if (sortBasedShuffleOn) {
val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) {
@ -99,22 +96,18 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
// doesn't buffer deserialized records.
// Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass.
false
} else if (serializeMapOutputs && serializer.supportsRelocationOfSerializedObjects) {
// SPARK-4550 extended sort-based shuffle to serialize individual records prior to sorting
// them. This optimization is guarded by a feature-flag and is only applied in cases where
// shuffle dependency does not specify an aggregator or ordering and the record serializer
// has certain properties. If this optimization is enabled, we can safely avoid the copy.
} else if (serializer.supportsRelocationOfSerializedObjects) {
// SPARK-4550 and SPARK-7081 extended sort-based shuffle to serialize individual records
// prior to sorting them. This optimization is only applied in cases where shuffle
// dependency does not specify an aggregator or ordering and the record serializer has
// certain properties. If this optimization is enabled, we can safely avoid the copy.
//
// Exchange never configures its ShuffledRDDs with aggregators or key orderings, so we only
// need to check whether the optimization is enabled and supported by our serializer.
//
// This optimization also applies to UnsafeShuffleManager (added in SPARK-7081).
false
} else {
// Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory. This code
// path is used both when SortShuffleManager is used and when UnsafeShuffleManager falls
// back to SortShuffleManager to perform a shuffle that the new fast path can't handle. In
// both cases, we must copy.
// Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory, so we must
// copy.
true
}
} else if (shuffleManager.isInstanceOf[HashShuffleManager]) {

View file
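The Exchange hunk above reduces the "copy rows before shuffling?" question, under sort-based shuffle, to three cases. A small self-contained restatement of that branch, with the SparkEnv and serializer lookups replaced by plain parameters (other shuffle managers are out of scope here):

object ExchangeCopyDecisionSketch {
  def needToCopyUnderSortShuffle(
      numPartitions: Int,
      bypassMergeThreshold: Int,
      serializerSupportsRelocation: Boolean): Boolean = {
    if (numPartitions <= bypassMergeThreshold) {
      // The bypass-merge path streams each row to a per-partition file without
      // buffering deserialized rows, so no defensive copy is needed.
      false
    } else if (serializerSupportsRelocation) {
      // Rows are serialized before sorting (SPARK-4550 / SPARK-7081), so the
      // in-memory buffers hold bytes rather than mutable row objects.
      false
    } else {
      // ExternalSorter buffers deserialized rows in memory, so copy.
      true
    }
  }

  def main(args: Array[String]): Unit = {
    println(needToCopyUnderSortShuffle(100, 200, serializerSupportsRelocation = false))  // false
    println(needToCopyUnderSortShuffle(1000, 200, serializerSupportsRelocation = false)) // true
    println(needToCopyUnderSortShuffle(1000, 200, serializerSupportsRelocation = true))  // false
  }
}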

@ -101,7 +101,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten
Utils.tryWithSafeFinally {
val conf = new SparkConf()
.set("spark.shuffle.spill.initialMemoryThreshold", "1024")
.set("spark.shuffle.spill.initialMemoryThreshold", "1")
.set("spark.shuffle.sort.bypassMergeThreshold", "0")
.set("spark.testing.memory", "80000")
@ -109,7 +109,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "")
// prepare data
val converter = unsafeRowConverter(Array(IntegerType))
val data = (1 to 1000).iterator.map { i =>
val data = (1 to 10000).iterator.map { i =>
(i, converter(Row(i)))
}
val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
@ -141,9 +141,8 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
}
}
test("SPARK-10403: unsafe row serializer with UnsafeShuffleManager") {
val conf = new SparkConf()
.set("spark.shuffle.manager", "tungsten-sort")
test("SPARK-10403: unsafe row serializer with SortShuffleManager") {
val conf = new SparkConf().set("spark.shuffle.manager", "sort")
sc = new SparkContext("local", "test", conf)
val row = Row("Hello", 123)
val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType))