Spark shuffle writer

This article explains how the ShuffleWriters work. For how shuffle tasks are run, see the companion article.

DiskBlockObjectWriter

Let's start with DiskBlockObjectWriter, which all three ShuffleWriters discussed below rely on. DiskBlockObjectWriter writes Java objects directly to a file on disk and encapsulates the underlying file streams.
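
To make the stream layering concrete, here is a small standalone sketch that mimics the same wrapping order with plain JDK classes. GZIPOutputStream and ObjectOutputStream stand in for Spark's SerializerManager/SerializerInstance wrappers, so this only mirrors the idea rather than the actual Spark implementation:

import java.io.{BufferedOutputStream, File, FileOutputStream, ObjectOutputStream}
import java.util.zip.GZIPOutputStream

object StreamLayeringSketch {
  def main(args: Array[String]): Unit = {
    val file = File.createTempFile("shuffle-demo", ".bin")
    // Layer 1: raw file stream (DiskBlockObjectWriter keeps a FileOutputStream plus its channel)
    val fos = new FileOutputStream(file, true)
    // Layer 2: buffering + compression (Spark wraps the stream via SerializerManager.wrapStream)
    val compressed = new GZIPOutputStream(new BufferedOutputStream(fos, 32 * 1024))
    // Layer 3: serialization (Spark wraps it via SerializerInstance.serializeStream)
    val objOut = new ObjectOutputStream(compressed)

    objOut.writeObject(("key-1", "value-1")) // analogous to writeKey/writeValue
    objOut.close()                           // closing the outermost layer flushes and closes the whole chain

    println(s"wrote ${file.length()} bytes to ${file.getPath}")
  }
}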


private def initialize(): Unit = {
fos = new FileOutputStream(file, true)
channel = fos.getChannel()
ts = new TimeTrackingOutputStream(writeMetrics, fos)
class ManualCloseBufferedOutputStream
extends BufferedOutputStream(ts, bufferSize) with ManualCloseOutputStream
mcs = new ManualCloseBufferedOutputStream
}

def open(): DiskBlockObjectWriter = {
if (hasBeenClosed) {
throw new IllegalStateException("Writer already closed. Cannot be reopened.")
}
if (!initialized) {
initialize()
initialized = true
}

bs = serializerManager.wrapStream(blockId, mcs) // wrap the file stream with compression and encryption
objOut = serializerInstance.serializeStream(bs) // wrap the (compressed) stream with a serialization stream
streamOpen = true
this
}

// The compression codec is configured via spark.io.compression.codec
private def shouldCompress(blockId: BlockId): Boolean = {
blockId match {
case _: ShuffleBlockId => compressShuffle // spark.shuffle.compress
case _: BroadcastBlockId => compressBroadcast //spark.broadcast.compress
case _: RDDBlockId => compressRdds // spark.rdd.compress
case _: TempLocalBlockId => compressShuffleSpill // spark.shuffle.spill.compress
case _: TempShuffleBlockId => compressShuffle // spark.shuffle.compress
case _ => false
}
}
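
As a hedged illustration, the configuration keys referenced in shouldCompress() can be set on a SparkConf as below; the values shown are only examples, not a recommendation:

import org.apache.spark.SparkConf

object CompressionConfSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .set("spark.io.compression.codec", "lz4")    // codec used whenever a block is compressed
      .set("spark.shuffle.compress", "true")       // ShuffleBlockId / TempShuffleBlockId
      .set("spark.shuffle.spill.compress", "true") // TempLocalBlockId (spill data)
      .set("spark.broadcast.compress", "true")     // BroadcastBlockId
      .set("spark.rdd.compress", "false")          // RDDBlockId (cached RDD blocks)
    println(conf.get("spark.io.compression.codec"))
  }
}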

// Writes a key-value pair as Java objects.
// Note that it writes through objOut, which layers serialization and compression on top of the file stream.
// BypassMergeSortShuffleWriter and SortShuffleWriter use this method.
def write(key: Any, value: Any) {
if (!streamOpen) {
open()
}

objOut.writeKey(key)
objOut.writeValue(value)
recordWritten()
}

// Writes a byte array directly to the stream.
// Note that it writes through bs, which only applies compression; no serialization is performed.
// UnsafeShuffleWriter uses this method.
override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = {
if (!streamOpen) {
open()
}

bs.write(kvBytes, offs, len)
}

How BypassMergeSortShuffleWriter works

BypassMergeSortShuffleWriter does not use Spark's execution memory. Based on each record's key, it writes the record directly into a separate file, one file per reduce partition, and finally concatenates these files in partition-ID order into the final output file.

Advantage: when the serialized temporary shuffle files are concatenated into the final file, no deserialization is needed; the bytes are copied as-is, which performs well.

Drawback: this ShuffleWriter performs poorly when there are many reduce partitions, because it opens a serializer and a file stream for every reduce partition at the same time. A sketch of the condition under which Spark chooses it is shown below.
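
The following is a paraphrased sketch of the condition checked by SortShuffleWriter.shouldBypassMergeSort in Spark 2.4, where spark.shuffle.sort.bypassMergeThreshold defaults to 200; treat it as a simplified restatement, not the exact source:

object BypassDecisionSketch {
  def shouldBypassMergeSort(
      mapSideCombine: Boolean,
      numReducePartitions: Int,
      bypassMergeThreshold: Int = 200): Boolean = {
    if (mapSideCombine) {
      false // map-side combining needs an in-memory map, so the bypass path is never taken
    } else {
      // spark.shuffle.sort.bypassMergeThreshold (default 200)
      numReducePartitions <= bypassMergeThreshold
    }
  }

  def main(args: Array[String]): Unit = {
    println(shouldBypassMergeSort(mapSideCombine = false, numReducePartitions = 150)) // true
    println(shouldBypassMergeSort(mapSideCombine = false, numReducePartitions = 500)) // false
  }
}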


public void write(Iterator<Product2<K, V>> records) throws IOException {
assert (partitionWriters == null);
if (!records.hasNext()) {
partitionLengths = new long[numPartitions];
shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null);
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
return;
}
final SerializerInstance serInstance = serializer.newInstance();
final long openStartTime = System.nanoTime();
partitionWriters = new DiskBlockObjectWriter[numPartitions];
partitionWriterSegments = new FileSegment[numPartitions];
for (int i = 0; i < numPartitions; i++) {
// Create a temporary shuffle block and file for each partition
final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
blockManager.diskBlockManager().createTempShuffleBlock();
final File file = tempShuffleBlockIdPlusFile._2();
final BlockId blockId = tempShuffleBlockIdPlusFile._1();
// Create one disk writer per partition
partitionWriters[i] =
blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
}
// Creating the file to write to and creating a disk writer both involve interacting with
// the disk, and can take a long time in aggregate when we open many files, so should be
// included in the shuffle write time.
writeMetrics.incWriteTime(System.nanoTime() - openStartTime);

while (records.hasNext()) {
final Product2<K, V> record = records.next();
final K key = record._1();
// Route the record to its partition file based on the partitioner and the record's key
partitionWriters[partitioner.getPartition(key)].write(key, record._2());
}

// All records written: commit each writer and close it
for (int i = 0; i < numPartitions; i++) {
final DiskBlockObjectWriter writer = partitionWriters[i];
partitionWriterSegments[i] = writer.commitAndGet();
writer.close();
}

File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
File tmp = Utils.tempFileWith(output);
try {
// Concatenate the per-partition temporary files into a single final file
partitionLengths = writePartitionedFile(tmp);
// Write the index file
shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
} finally {
if (tmp.exists() && !tmp.delete()) {
logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
}
}
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}
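
writeIndexFileAndCommit records where each partition starts in the data file. As a sketch of that bookkeeping (the real IndexShuffleBlockResolver writes numPartitions + 1 longs of cumulative offsets; the sample lengths below are made up):

object IndexOffsetsSketch {
  // Turn per-partition byte lengths into cumulative offsets: offsets(i) .. offsets(i + 1) is partition i.
  def offsetsFor(partitionLengths: Array[Long]): Array[Long] =
    partitionLengths.scanLeft(0L)(_ + _) // numPartitions + 1 entries

  def main(args: Array[String]): Unit = {
    val lengths = Array(120L, 0L, 340L, 75L)
    println(offsetsFor(lengths).mkString(", ")) // 0, 120, 120, 460, 535
  }
}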

/**
* Concatenate the per-partition files into a single combined data file.
*
* @return an array in which each element is the byte length of the corresponding partition's data
*/
private long[] writePartitionedFile(File outputFile) throws IOException {
// Track location of the partition starts in the output file
final long[] lengths = new long[numPartitions];
if (partitionWriters == null) {
// We were passed an empty iterator
return lengths;
}

final FileOutputStream out = new FileOutputStream(outputFile, true);
final long writeStartTime = System.nanoTime();
boolean threwException = true;
try {
for (int i = 0; i < numPartitions; i++) { // copy each partition file into the final file, in partition-ID order
final File file = partitionWriterSegments[i].file();
if (file.exists()) {
final FileInputStream in = new FileInputStream(file);
boolean copyThrewException = true;
try {
lengths[i] = Utils.copyStream(in, out, false, transferToEnabled);
copyThrewException = false;
} finally {
Closeables.close(in, copyThrewException);
}
if (!file.delete()) {
logger.error("Unable to delete file for partition {}", i);
}
}
}
threwException = false;
} finally {
Closeables.close(out, threwException);
writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
}
partitionWriters = null;
return lengths;
}

How UnsafeShuffleWriter works

This writer is built on top of sun.misc.Unsafe, hence the name UnsafeShuffleWriter.

Internally, it serializes each record and hands the bytes to ShuffleExternalSorter. ShuffleExternalSorter copies the serialized data into memory pages and, at the same time, encodes the record's in-memory address (memory page number + offset within the page) together with its partition ID into a pointer, which it inserts into the in-memory ShuffleInMemorySorter.

When a spill is triggered, ShuffleInMemorySorter sorts the record pointers (only by the partition ID encoded in each pointer), which effectively sorts the records stored in the memory pages, and the sorted records are then written out to a temporary spill file in partition-ID order. While writing through DiskBlockObjectWriter, the writer commits once per partition, and each commit produces one FileSegment.

ShuffleExternalSorter also writes the last in-memory pages to disk, so the merge phase operates entirely on spill files. During the merge, either a fast merge or a slow merge is performed, depending on whether fast merging is possible (it is when a) compression is disabled, or b) compression is enabled but the codec supports concatenation of serialized streams: Snappy, LZF, LZ4, ZStd). The slow merge has to decompress each spill stream and re-compress the merged output into the final shuffle file, which is considerably more expensive.

The figure below illustrates the implementation details of UnsafeShuffleWriter:

TaskMemoryManager

// Page table: each MemoryBlock carries a pageNumber field whose value is that block's index in this table
private final MemoryBlock[] pageTable = new MemoryBlock[PAGE_TABLE_SIZE];

UnsafeShuffleWriter

Writing records

// Write all records of the current map-side partition
public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
// Keep track of success so we know if we encountered an exception
// We do this rather than a standard try/catch/re-throw to handle
// generic throwables.
boolean success = false;
try {
while (records.hasNext()) {
insertRecordIntoSorter(records.next());
}
closeAndWriteOutput();
success = true;
} finally {
...
}
}

MyByteArrayOutputStream extends ByteArrayOutputStream and serves as an in-memory byte stream. A SerializationStream wraps the MyByteArrayOutputStream, so serializing a record writes the resulting bytes into the underlying MyByteArrayOutputStream. MyByteArrayOutputStream exposes getBuf(), which gives direct access to the serialized bytes.
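
The trick relies simply on the fact that java.io.ByteArrayOutputStream keeps its buffer in a protected field named buf, so a subclass can expose it without copying. A minimal standalone equivalent (the class name here is just illustrative):

import java.io.ByteArrayOutputStream

// Expose the protected `buf` field so the serialized bytes can be read without an extra copy.
class ExposedByteArrayOutputStream(initialSize: Int) extends ByteArrayOutputStream(initialSize) {
  def getBuf: Array[Byte] = buf // only the first size() bytes are valid data
}

object ExposedBufDemo {
  def main(args: Array[String]): Unit = {
    val out = new ExposedByteArrayOutputStream(1024 * 1024)
    out.write("hello".getBytes("UTF-8"))
    println(s"valid bytes: ${out.size()}, backing array length: ${out.getBuf.length}")
  }
}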

void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
assert(sorter != null);
final K key = record._1();
final int partitionId = partitioner.getPartition(key);
serBuffer.reset();
serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG); // serialize the record into serBuffer
serOutputStream.flush();

final int serializedRecordSize = serBuffer.size();
assert (serializedRecordSize > 0);

sorter.insertRecord(
serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
}

Merging spill files

void closeAndWriteOutput() throws IOException {
assert(sorter != null);
updatePeakMemoryUsed();
serBuffer = null;
serOutputStream = null;
// Write any remaining in-memory pages to disk and free their memory
final SpillInfo[] spills = sorter.closeAndGetSpills();
sorter = null;
final long[] partitionLengths;
final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
final File tmp = Utils.tempFileWith(output);
try {
try {
partitionLengths = mergeSpills(spills, tmp);
} finally {
for (SpillInfo spill : spills) {
if (spill.file.exists() && ! spill.file.delete()) {
logger.error("Error while deleting spill file {}", spill.file.getPath());
}
}
}
shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
} finally {
if (tmp.exists() && !tmp.delete()) {
logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
}
}
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}

private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
final boolean fastMergeEnabled =
sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
// Fast merge is supported when compression is off, or when the codec (Snappy, LZF, LZ4, ZStd) supports stream concatenation
final boolean fastMergeIsSupported = !compressionEnabled ||
CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled();
try {
if (spills.length == 0) {
new FileOutputStream(outputFile).close(); // Create an empty file
return new long[partitioner.numPartitions()];
} else if (spills.length == 1) {
// Here, we don't need to perform any metrics updates because the bytes written to this
// output file would have already been counted as shuffle bytes written.
Files.move(spills[0].file, outputFile);
return spills[0].partitionLengths;
} else {
final long[] partitionLengths;

// Fast merge requires either:
// 1. compression is disabled; or
// 2. compression is enabled with a codec (Snappy, LZF, LZ4, ZStd) whose serialized streams can be concatenated without decompression
if (fastMergeEnabled && fastMergeIsSupported) {
if (transferToEnabled && !encryptionEnabled) {
logger.debug("Using transferTo-based fast merge");
// Fast merge using NIO transferTo
partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
} else {
logger.debug("Using fileStream-based fast merge");
// Merge using Java file streams; usually slower than mergeSpillsWithTransferTo()
partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null);
}
} else {
logger.debug("Using slow merge");
// Decompress the spill streams and re-compress the merged output; this is the slow path
partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
}

writeMetrics.decBytesWritten(spills[spills.length - 1].file.length());
writeMetrics.incBytesWritten(outputFile.length());
return partitionLengths;
}
} catch (IOException e) {
if (outputFile.exists() && !outputFile.delete()) {
logger.error("Unable to delete output file {}", outputFile.getPath());
}
throw e;
}
}

ShuffleExternalSorter

// Memory pages holding the record data to be sorted; they are freed once their contents have been spilled to disk
private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<>();

// Each spill is described by one SpillInfo, which holds metadata about the spill file
private final LinkedList<SpillInfo> spills = new LinkedList<>();


// In-memory sorter over the current data pages; reset via inMemSorter.reset() after each spill
@Nullable private ShuffleInMemorySorter inMemSorter;
// The page currently being written; set to currentPage = null after each spill
@Nullable private MemoryBlock currentPage = null;
// Write offset within the current page; reset to pageCursor = 0 after each spill
private long pageCursor = -1;

Inserting a record

public void insertRecord(Object recordBase, long recordOffset, int length, int partitionId)
throws IOException {

// for tests
assert(inMemSorter != null);
if (inMemSorter.numRecords() >= numElementsForSpillThreshold) {
logger.info("Spilling data because number of spilledRecords crossed the threshold " +
numElementsForSpillThreshold);
spill();
}

growPointerArrayIfNecessary();
final int uaoSize = UnsafeAlignedOffset.getUaoSize();
// Need 4 or 8 bytes to store the record length.
final int required = length + uaoSize;
acquireNewPageIfNecessary(required);

assert(currentPage != null);
final Object base = currentPage.getBaseObject();
// Encode the current memory page and pageCursor into the address of the record being written.
// In the encoded address, the high 13 bits hold the page number and the low 51 bits hold the offset within the page.
final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor);
UnsafeAlignedOffset.putSize(base, pageCursor, length);
pageCursor += uaoSize;
// 1. Copy `length` bytes of serialized data from recordBase + recordOffset into the page at base + pageCursor;
//    note that uaoSize bytes were reserved beforehand to store the record length.
Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
pageCursor += length;
// 2. Insert the pointer record into inMemSorter, which is what actually gets sorted
inMemSorter.insertRecord(recordAddress, partitionId);
}

Spilling

/**
* Sort and spill the current records in response to memory pressure.
*/
@Override
public long spill(long size, MemoryConsumer trigger) throws IOException {
if (trigger != this || inMemSorter == null || inMemSorter.numRecords() == 0) {
return 0L;
}

logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
Thread.currentThread().getId(),
Utils.bytesToString(getMemoryUsage()),
spills.size(),
spills.size() > 1 ? " times" : " time");

writeSortedFile(false);
final long spillSize = freeMemory();
inMemSorter.reset();
// Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the
// records. Otherwise, if the task is over allocated memory, then without freeing the memory
// pages, we might not be able to get memory for the pointer array.
taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
return spillSize;
}

/**
* Sorts the in-memory records and writes the sorted records to a file on disk.
*/
private void writeSortedFile(boolean isLastFile) {

final ShuffleWriteMetrics writeMetricsToUse;

if (isLastFile) {
// We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes.
writeMetricsToUse = writeMetrics;
} else {
// We're spilling, so bytes written should be counted towards spill rather than write.
// Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count
// them towards shuffle bytes written.
writeMetricsToUse = new ShuffleWriteMetrics();
}

// Sort the pointer array for the current data pages
final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords =
inMemSorter.getSortedIterator();

// Writing tiny amounts of data to DiskBlockObjectWriter is very inefficient, so a byte array is used as a write buffer
final byte[] writeBuffer = new byte[diskWriteBufferSize];

// Because this output will be read during shuffle, its compression codec must be controlled by
// spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
// createTempShuffleBlock here; see SPARK-3426 for more details.
final Tuple2<TempShuffleBlockId, File> spilledFileInfo =
blockManager.diskBlockManager().createTempShuffleBlock();
final File file = spilledFileInfo._2();
final TempShuffleBlockId blockId = spilledFileInfo._1();
final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId);

// A serializer instance is required to construct a DiskBlockObjectWriter, which wraps the file streams,
// but UnsafeShuffleWriter never actually uses it: the write below calls
// DiskBlockObjectWriter.write(kvBytes: Array[Byte], offs: Int, len: Int),
// which writes raw bytes straight into the wrapped compression stream and bypasses the serialization stream,
// since the data in the pages is already serialized. Hence a dummy, do-nothing serializer is passed in.
final SerializerInstance ser = DummySerializerInstance.INSTANCE;

final DiskBlockObjectWriter writer =
blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);

int currentPartition = -1;
final int uaoSize = UnsafeAlignedOffset.getUaoSize();
while (sortedRecords.hasNext()) {
sortedRecords.loadNext();
final int partition = sortedRecords.packedRecordPointer.getPartitionId();
assert (partition >= currentPartition);
if (partition != currentPartition) {
// Switch to the new partition
if (currentPartition != -1) {
final FileSegment fileSegment = writer.commitAndGet(); // one FileSegment per partition
spillInfo.partitionLengths[currentPartition] = fileSegment.length();
}
currentPartition = partition;
}

final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
final Object recordPage = taskMemoryManager.getPage(recordPointer);
final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer);
int dataRemaining = UnsafeAlignedOffset.getSize(recordPage, recordOffsetInPage);
long recordReadPosition = recordOffsetInPage + uaoSize; // skip over record length
while (dataRemaining > 0) {
final int toTransfer = Math.min(diskWriteBufferSize, dataRemaining);
// Copy data from the memory page into the byte-array buffer
Platform.copyMemory(
recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer);
// Write the buffered bytes to the disk writer
writer.write(writeBuffer, 0, toTransfer);
recordReadPosition += toTransfer;
dataRemaining -= toTransfer;
}
writer.recordWritten();
}

final FileSegment committedSegment = writer.commitAndGet();
writer.close();
// If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted,
// then the file might be empty. Note that it might be better to avoid calling
// writeSortedFile() in that case.
if (currentPartition != -1) {
spillInfo.partitionLengths[currentPartition] = committedSegment.length();
spills.add(spillInfo);
}

if (!isLastFile) { // i.e. this is a spill file
// The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records
// are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter
// relies on its `recordWritten()` method being called in order to trigger periodic updates to
// `shuffleBytesWritten`. If we were to remove the `recordWritten()` call and increment that
// counter at a higher-level, then the in-progress metrics for records written and bytes
// written would get out of sync.
//
// When writing the last file, we pass `writeMetrics` directly to the DiskBlockObjectWriter;
// in all other cases, we pass in a dummy write metrics to capture metrics, then copy those
// metrics to the true write metrics here. The reason for performing this copying is so that
// we can avoid reporting spilled bytes as shuffle write bytes.
//
// Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`.
// Consistent with ExternalSorter, we do not count this IO towards shuffle write time.
// This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this.
writeMetrics.incRecordsWritten(writeMetricsToUse.recordsWritten());
taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.bytesWritten());
}
}

ShuffleInMemorySorter

/**
* The pointer array:
* 1. each entry is a record address plus its partition ID, packed together by PackedRecordPointer;
* 2. sorting is performed on this array, not directly on the underlying record data;
* 3. only part of the array stores pointers; the remaining space is kept as scratch buffer for sorting.
*/
private LongArray array;
/**
* Position in the pointer array where the next record pointer will be inserted; incremented by one per insert.
*/
private int pos = 0;

Inserting a record pointer

public void insertRecord(long recordPointer, int partitionId) {
if (!hasSpaceForAnotherRecord()) {
throw new IllegalStateException("There is no space for new record");
}
array.set(pos, PackedRecordPointer.packPointer(recordPointer, partitionId));
pos++;
}

The layout of the pointer produced by PackedRecordPointer.packPointer() is as follows:
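
The original figure is not reproduced here. In short, the packed pointer is a single 64-bit long laid out as [24-bit partition ID | 13-bit page number | 27-bit offset in page], which is also why this sorter caps its page size at 2^27 bytes (128 MB). Below is a simplified standalone pack/unpack sketch of that layout, not the actual PackedRecordPointer code:

object PackedPointerSketch {
  def pack(partitionId: Int, pageNumber: Int, offsetInPage: Long): Long = {
    require(partitionId < (1 << 24) && pageNumber < (1 << 13) && offsetInPage < (1L << 27))
    (partitionId.toLong << 40) | (pageNumber.toLong << 27) | offsetInPage
  }

  def partitionId(packed: Long): Int   = ((packed >>> 40) & 0xFFFFFF).toInt
  def pageNumber(packed: Long): Int    = ((packed >>> 27) & 0x1FFF).toInt
  def offsetInPage(packed: Long): Long = packed & ((1L << 27) - 1)

  def main(args: Array[String]): Unit = {
    val p = pack(partitionId = 7, pageNumber = 3, offsetInPage = 12345L)
    println((partitionId(p), pageNumber(p), offsetInPage(p))) // (7,3,12345)
  }
}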

Sorting

/**
* Sorts the record-pointer array and returns an iterator over it.
*/
public ShuffleSorterIterator getSortedIterator() {
int offset = 0;
if (useRadixSort) {
offset = RadixSort.sort(
array, pos,
PackedRecordPointer.PARTITION_ID_START_BYTE_INDEX,
PackedRecordPointer.PARTITION_ID_END_BYTE_INDEX, false, false);
} else {
MemoryBlock unused = new MemoryBlock(
array.getBaseObject(),
array.getBaseOffset() + pos * 8L,
(array.size() - pos) * 8L);
LongArray buffer = new LongArray(unused);
Sorter<PackedRecordPointer, LongArray> sorter =
new Sorter<>(new ShuffleSortDataFormat(buffer));

sorter.sort(array, 0, pos, SORT_COMPARATOR);
}
return new ShuffleSorterIterator(pos, array, offset);
}

The SortComparator used by getSortedIterator() when radix sort (RadixSort) is not enabled:

private static final class SortComparator implements Comparator<PackedRecordPointer> {
@Override
public int compare(PackedRecordPointer left, PackedRecordPointer right) {
int leftId = left.getPartitionId();
int rightId = right.getPartitionId();
return leftId < rightId ? -1 : (leftId > rightId ? 1 : 0);
}
}
private static final SortComparator SORT_COMPARATOR = new SortComparator();

How SortShuffleWriter works

Internally, SortShuffleWriter hands its records to ExternalSorter. If the shuffle dependency requires map-side combining, ExternalSorter inserts the records into an in-memory PartitionedAppendOnlyMap; otherwise it inserts them into a PartitionedPairBuffer.

When a spill is triggered, ExternalSorter sorts the in-memory collection (the PartitionedAppendOnlyMap or PartitionedPairBuffer) and obtains an iterator over the sorted data. The partition ID is the primary sort key; within a partition, records are further sorted by key. If the shuffle neither sorts nor aggregates, records are ordered by partition ID only. While writing a spill file, the writer commits once per batch (controlled by spark.shuffle.spill.batchSize, default 10000), and each batch corresponds to one FileSegment; later, when ExternalSorter merges the spill files, it reads them back batch by batch. The resulting spill files are serialized, compressed, and sorted.

In the merge phase, ExternalSorter merge-sorts whatever is present (spilled data and/or in-memory data) and writes the result to the final shuffle file; if an aggregator is defined, the merged streams are aggregated as well. For the idea behind the merge sort, see the merge-sort illustration later in this article.

Drawback: when merging spill files, this ShuffleWriter has to deserialize the spill streams, merge-sort the records, and serialize the result into the final file, so it pays roughly twice the serialization/deserialization cost of the other ShuffleWriters.

The figure below illustrates the implementation details of SortShuffleWriter:

SortShuffleWriter

override def write(records: Iterator[Product2[K, V]]): Unit = {
sorter = if (dep.mapSideCombine) {
new ExternalSorter[K, V, C](
context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
} else {
// If map-side combining is not required, pass None for both aggregator and ordering,
// because we don't care whether keys get sorted within each partition;
// if this is a sortByKey, the key ordering is applied on the reduce side.
new ExternalSorter[K, V, V](
context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
}
sorter.insertAll(records)

// Don't bother including the time to open the merged output file in the shuffle write time,
// because it just opens a single file, so is typically too fast to measure accurately
// (see SPARK-3570).
val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
val tmp = Utils.tempFileWith(output)
try {
val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
// Write the index file
shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
} finally {
if (tmp.exists() && !tmp.delete()) {
logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
}
}
}

ExternalSorter

// When aggregation is required, records go into this map;
// otherwise they go into this buffer; note that in the buffer case C is the same type as the records' V
@volatile private var map = new PartitionedAppendOnlyMap[K, C]
@volatile private var buffer = new PartitionedPairBuffer[K, C]

//
private val spills = new ArrayBuffer[SpilledFile]
// Total amount of in-memory data spilled, in bytes
@volatile private[this] var _memoryBytesSpilled = 0L

// Number of spills so far
private[this] var _spillCount = 0


// Within a partition, records are ordered by key as follows:
// 1. if an ordering is defined, use it;
// 2. otherwise, compare the hash codes of the keys.
private val keyComparator: Comparator[K] = ordering.getOrElse(new Comparator[K] {
override def compare(a: K, b: K): Int = {
val h1 = if (a == null) 0 else a.hashCode()
val h2 = if (b == null) 0 else b.hashCode()
if (h1 < h2) -1 else if (h1 == h2) 0 else 1
}
})

private def comparator: Option[Comparator[K]] = {
if (ordering.isDefined || aggregator.isDefined) {
Some(keyComparator)
} else {
None
}
}

Inserting records

def insertAll(records: Iterator[Product2[K, V]]): Unit = {
// TODO: stop combining if we find that the reduction factor isn't high
val shouldCombine = aggregator.isDefined

if (shouldCombine) {
// Combine values in-memory first using our AppendOnlyMap
val mergeValue = aggregator.get.mergeValue
val createCombiner = aggregator.get.createCombiner
var kv: Product2[K, V] = null
val update = (hadValue: Boolean, oldValue: C) => {
if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
}
while (records.hasNext) {
addElementsRead()
kv = records.next()
map.changeValue((getPartition(kv._1), kv._1), update) // upsert into the map
maybeSpillCollection(usingMap = true)
}
} else {
// Stick values into our buffer
while (records.hasNext) {
addElementsRead()
val kv = records.next()
buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C]) // append to the buffer
maybeSpillCollection(usingMap = false)
}
}
}

Spilling

private def maybeSpillCollection(usingMap: Boolean): Unit = {
var estimatedSize = 0L
if (usingMap) {
// Estimate the current size of the in-memory collection
estimatedSize = map.estimateSize()
if (maybeSpill(map, estimatedSize)) {
map = new PartitionedAppendOnlyMap[K, C]
}
} else {
// Estimate the current size of the in-memory collection
estimatedSize = buffer.estimateSize()
if (maybeSpill(buffer, estimatedSize)) {
buffer = new PartitionedPairBuffer[K, C]
}
}

if (estimatedSize > _peakMemoryUsedBytes) {
_peakMemoryUsedBytes = estimatedSize
}
}

/**
* Spills the current in-memory collection to disk if needed. Attempts to acquire more
* memory before spilling.
*
* @param collection collection to spill to disk
* @param currentMemory estimated size of the collection in bytes
* @return true if `collection` was spilled to disk; false otherwise
*/
protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
var shouldSpill = false
// Check for a spill only once every 32 records added to the in-memory collection, and only when the
// estimated size has reached the current threshold
if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
// Request execution memory so that the threshold grows to twice the current collection size
val amountToRequest = 2 * currentMemory - myMemoryThreshold
val granted = acquireMemory(amountToRequest)
myMemoryThreshold += granted
// If not enough memory was granted, spill the current collection to disk
shouldSpill = currentMemory >= myMemoryThreshold
}
shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
// Actually spill
if (shouldSpill) {
_spillCount += 1
logSpillage(currentMemory)
spill(collection)
_elementsRead = 0
_memoryBytesSpilled += currentMemory
releaseMemory()
}
shouldSpill
}
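
To make the threshold arithmetic concrete, here is a small worked example; the 5 MB starting threshold corresponds (as far as I know) to the default of spark.shuffle.spill.initialMemoryThreshold, and all other numbers are made up for illustration:

object MaybeSpillArithmetic {
  def main(args: Array[String]): Unit = {
    val currentMemory     = 12L * 1024 * 1024                     // estimated collection size: 12 MB
    var myMemoryThreshold = 5L * 1024 * 1024                      // e.g. the initial 5 MB threshold
    val amountToRequest   = 2 * currentMemory - myMemoryThreshold // 19 MB requested from the memory manager
    val granted           = 4L * 1024 * 1024                      // suppose only 4 MB could be granted
    myMemoryThreshold += granted                                  // threshold becomes 9 MB
    val shouldSpill = currentMemory >= myMemoryThreshold          // 12 MB >= 9 MB, so spill
    println(s"request=${amountToRequest / 1024 / 1024}MB, newThreshold=${myMemoryThreshold / 1024 / 1024}MB, spill=$shouldSpill")
  }
}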

/**
* Spill the in-memory collection to a sorted file on disk.
*/
override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
val inMemoryIterator = collection.destructiveSortedWritablePartitionedIterator(comparator)
val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
spills += spillFile
}

private[this] def spillMemoryIteratorToDisk(inMemoryIterator: WritablePartitionedIterator)
: SpilledFile = {
// Because these files may be read during shuffle, their compression must be controlled by
// spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
// createTempShuffleBlock here; see SPARK-3426 for more context.
val (blockId, file) = diskBlockManager.createTempShuffleBlock()

// These variables are reset after each flush
var objectsWritten: Long = 0
val spillMetrics: ShuffleWriteMetrics = new ShuffleWriteMetrics
val writer: DiskBlockObjectWriter =
blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)

// List of batch sizes (bytes) in the order they are written to disk
val batchSizes = new ArrayBuffer[Long]

// How many elements we have in each partition
val elementsPerPartition = new Array[Long](numPartitions)

// Flush the disk writer's contents to disk, and update relevant variables.
// The writer is committed at the end of this process.
def flush(): Unit = {
val segment = writer.commitAndGet()
batchSizes += segment.length
_diskBytesSpilled += segment.length
objectsWritten = 0
}

var success = false
try {
while (inMemoryIterator.hasNext) {
val partitionId = inMemoryIterator.nextPartition()
require(partitionId >= 0 && partitionId < numPartitions,
s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})")
inMemoryIterator.writeNext(writer)
elementsPerPartition(partitionId) += 1
objectsWritten += 1

if (objectsWritten == serializerBatchSize) {
flush()
}
}
if (objectsWritten > 0) {
flush()
} else {
writer.revertPartialWritesAndClose()
}
success = true
} finally {
if (success) {
writer.close()
} else {
// This code path only happens if an exception was thrown above before we set success;
// close our stuff and let the exception be thrown further
writer.revertPartialWritesAndClose()
if (file.exists()) {
if (!file.delete()) {
logWarning(s"Error deleting ${file}")
}
}
}
}

SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)
}

Merging spill files

def writePartitionedFile(
blockId: BlockId,
outputFile: File): Array[Long] = {

// Track location of each range in the output file
val lengths = new Array[Long](numPartitions)
val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
context.taskMetrics().shuffleWriteMetrics)

if (spills.isEmpty) { // only in-memory data, no spilled data
// Case where we only have in-memory data
val collection = if (aggregator.isDefined) map else buffer
val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
while (it.hasNext) {
val partitionId = it.nextPartition()
while (it.hasNext && it.nextPartition() == partitionId) {
it.writeNext(writer)
}
val segment = writer.commitAndGet()
lengths(partitionId) = segment.length
}
} else { // there is spilled data: merge-sort the spilled data together with the in-memory data
// We must perform merge-sort; get an iterator by partition and write everything directly.
for ((id, elements) <- this.partitionedIterator) {
if (elements.hasNext) {
for (elem <- elements) { // write this partition's records to the file stream in order
writer.write(elem._1, elem._2)
}
val segment = writer.commitAndGet() // one commit per partition
lengths(id) = segment.length
}
}
}

writer.close()
context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)

lengths
}

/**
* Return an iterator over all the data written to this object, grouped by partition and
* aggregated by the requested aggregator. For each partition we then have an iterator over its
* contents, and these are expected to be accessed in order (you can't "skip ahead" to one
* partition without reading the previous one). Guaranteed to return a key-value pair for each
* partition, in order of partition ID.
*
* For now, we just merge all the spilled files in once pass, but this can be modified to
* support hierarchical merging.
* Exposed for testing.
*/
def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
val usingMap = aggregator.isDefined
val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
if (spills.isEmpty) {
// Special case: if we have only in-memory data, we don't need to merge streams, and perhaps
// we don't even need to sort by anything other than partition ID
if (!ordering.isDefined) {
// The user hasn't requested sorted keys, so only sort by partition ID, not key
groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None)))
} else {
// We do need to sort by both partition ID and key
groupByPartition(destructiveIterator(
collection.partitionedDestructiveSortedIterator(Some(keyComparator))))
}
} else {
// Merge spilled and in-memory data
merge(spills, destructiveIterator(
collection.partitionedDestructiveSortedIterator(comparator)))
}
}

/**
* Returns a destructive iterator over the records in this map.
* If memory runs short, this iterator can be forced to spill to disk to release memory, in which case it keeps returning key-value pairs read back from the on-disk data.
*/
def destructiveIterator(memoryIterator: Iterator[((Int, K), C)]): Iterator[((Int, K), C)] = {
if (isShuffleSort) {
memoryIterator
} else {
readingIterator = new SpillableIterator(memoryIterator)
readingIterator
}
}

/**
* Merge all spilled files together with the in-memory data, returning an iterator whose items are (partition ID, iterator over that partition's records).
*
*/
private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
: Iterator[(Int, Iterator[Product2[K, C]])] = {
// Create one SpillReader per spill file
val readers = spills.map(new SpillReader(_))
val inMemBuffered = inMemory.buffered
(0 until numPartitions).iterator.map { p =>
// Wrap the in-memory data as an IteratorForPartition for partition p
val inMemIterator = new IteratorForPartition(p, inMemBuffered)
// Combine the per-partition iterators from the spill files with the in-memory iterator
val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
if (aggregator.isDefined) { // aggregation required: aggregate this partition
// Perform partial aggregation across partitions
(p, mergeWithAggregation(
iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined))
} else if (ordering.isDefined) { // no aggregation, but an ordering is defined: merge-sort this partition
// No aggregator given, but we have an ordering (e.g. used by reduce tasks in sortByKey);
// sort the elements without trying to merge them
(p, mergeSort(iterators, ordering.get))
} else { // otherwise simply concatenate the iterators as this partition's data
(p, iterators.iterator.flatten)
}
}
}

 /**
* Merge-sort a sequence of (K, C) iterators using a given a comparator for the keys.
*/
private def mergeSort(iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K])
: Iterator[Product2[K, C]] =
{
val bufferedIters = iterators.filter(_.hasNext).map(_.buffered)
type Iter = BufferedIterator[Product2[K, C]]
val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
// Compare by the head elements' keys; the arguments are swapped so that the max-heap yields the smallest key first
override def compare(x: Iter, y: Iter): Int = comparator.compare(y.head._1, x.head._1)
})
heap.enqueue(bufferedIters: _*) // Will contain only the iterators with hasNext = true
new Iterator[Product2[K, C]] {
override def hasNext: Boolean = !heap.isEmpty

override def next(): Product2[K, C] = {
if (!hasNext) {
throw new NoSuchElementException
}
val firstBuf = heap.dequeue() // dequeue the iterator whose head has the smallest key
val firstPair = firstBuf.next() // take that head element and advance the buffered iterator
if (firstBuf.hasNext) { // if the dequeued iterator still has records, put it back on the heap
heap.enqueue(firstBuf)
}
firstPair
}
}
}

Merge-sort illustration:
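
The illustration is not reproduced here. The following standalone toy shows the same k-way merge idea used by mergeSort(): keep one buffered iterator per sorted input in a priority queue ordered by the head element, repeatedly take the smallest head, and re-enqueue the iterator if it still has elements.

import scala.collection.mutable

object KWayMergeSketch {
  // Merge already-sorted inputs, mirroring the shape of ExternalSorter.mergeSort() above.
  def mergeSort(inputs: Seq[Iterator[Int]]): Iterator[Int] = {
    type Iter = BufferedIterator[Int]
    // Reverse the ordering so the iterator with the smallest head is dequeued first
    val heap = new mutable.PriorityQueue[Iter]()(Ordering.by[Iter, Int](_.head).reverse)
    inputs.map(_.buffered).filter(_.hasNext).foreach(it => heap.enqueue(it))
    new Iterator[Int] {
      override def hasNext: Boolean = heap.nonEmpty
      override def next(): Int = {
        val it = heap.dequeue() // iterator holding the smallest head
        val v = it.next()
        if (it.hasNext) heap.enqueue(it) // put it back if it still has elements
        v
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val merged = mergeSort(Seq(Iterator(1, 4, 7), Iterator(2, 5), Iterator(3, 6, 8)))
    println(merged.mkString(", ")) // 1, 2, 3, 4, 5, 6, 7, 8
  }
}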


SpillableIterator

private[this] class SpillableIterator(var upstream: Iterator[((Int, K), C)])
extends Iterator[((Int, K), C)] {

private val SPILL_LOCK = new Object()

private var nextUpstream: Iterator[((Int, K), C)] = null

private var cur: ((Int, K), C) = readNext()

private var hasSpilled: Boolean = false

def spill(): Boolean = SPILL_LOCK.synchronized {
if (hasSpilled) {
false
} else {
val inMemoryIterator = new WritablePartitionedIterator {
private[this] var cur = if (upstream.hasNext) upstream.next() else null

def writeNext(writer: DiskBlockObjectWriter): Unit = {
writer.write(cur._1._2, cur._2)
cur = if (upstream.hasNext) upstream.next() else null
}

def hasNext(): Boolean = cur != null

def nextPartition(): Int = cur._1._1
}
logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " +
s" it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory")
val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
forceSpillFiles += spillFile
val spillReader = new SpillReader(spillFile)
nextUpstream = (0 until numPartitions).iterator.flatMap { p =>
val iterator = spillReader.readNextPartition()
iterator.map(cur => ((p, cur._1), cur._2))
}
hasSpilled = true
true
}
}

def readNext(): ((Int, K), C) = SPILL_LOCK.synchronized {
if (nextUpstream != null) {
upstream = nextUpstream
nextUpstream = null
}
if (upstream.hasNext) {
upstream.next()
} else {
null
}
}

override def hasNext(): Boolean = cur != null

override def next(): ((Int, K), C) = {
val r = cur
cur = readNext()
r
}
}

SpillReader

/** Construct a stream that only reads from the next batch */
def nextBatchStream(): DeserializationStream = {
// Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether
// we're still in a valid batch.
if (batchId < batchOffsets.length - 1) {
if (deserializeStream != null) {
deserializeStream.close()
fileStream.close()
deserializeStream = null
fileStream = null
}

val start = batchOffsets(batchId) // starting byte offset of this batch
fileStream = new FileInputStream(spill.file)
fileStream.getChannel.position(start)
batchId += 1

val end = batchOffsets(batchId) // ending byte offset of this batch

assert(end >= start, "start = " + start + ", end = " + end +
", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))

// Read only (end - start) bytes from the file stream
val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))

val wrappedStream = serializerManager.wrapStream(spill.blockId, bufferedStream)
// Return a deserialization stream over this batch's data
serInstance.deserializeStream(wrappedStream)
} else {
// No more batches left
cleanup()
null
}
}

def readNextPartition(): Iterator[Product2[K, C]] = new Iterator[Product2[K, C]] { // create the iterator for this partition
val myPartition = nextPartitionToRead
nextPartitionToRead += 1

// Each time ExternalSorter.mergeSort() dequeues this iterator it first calls next(), which hands back the
// cached nextItem, and then hasNext(), which reads the following element and stores it in nextItem.
override def hasNext: Boolean = {
if (nextItem == null) {
nextItem = readNextItem()
if (nextItem == null) {
return false
}
}
assert(lastPartitionId >= myPartition)
// Check that we're still in the right partition; note that readNextItem will have returned
// null at EOF above so we would've returned false there
lastPartitionId == myPartition
}

// ExternalSorter.mergeSort() calls this each time it dequeues this iterator
override def next(): Product2[K, C] = {
if (!hasNext) {
throw new NoSuchElementException
}
val item = nextItem // return the cached nextItem as the next element
nextItem = null // clear it so that the next hasNext() call triggers readNextItem() to read the following element
item
}
}

/**
* Read the next (K, C) pair from the deserialization stream; if the current batch is exhausted, start reading the next batch.
*/
private def readNextItem(): (K, C) = {
if (finished || deserializeStream == null) {
return null
}
val k = deserializeStream.readKey().asInstanceOf[K]
val c = deserializeStream.readValue().asInstanceOf[C]
lastPartitionId = partitionId
// Start reading the next batch if we're done with this one
indexInBatch += 1
if (indexInBatch == serializerBatchSize) {
indexInBatch = 0
deserializeStream = nextBatchStream()
}
// Update the partition location of the element we're reading
indexInPartition += 1
skipToNextPartition()
// If we've finished reading the last partition, remember that we're done
if (partitionId == numPartitions) {
finished = true
if (deserializeStream != null) {
deserializeStream.close()
}
}
(k, c)
}

WritablePartitionedPairCollection

/**
* Iterate through the data and write out the elements instead of returning them. Records are
* returned in order of their partition ID and then the given comparator.
* This may destroy the underlying collection.
*
* The keyComparator here is the one defined in ExternalSorter.
*/
def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
: WritablePartitionedIterator = {
val it = partitionedDestructiveSortedIterator(keyComparator)
new WritablePartitionedIterator {
private[this] var cur = if (it.hasNext) it.next() else null

def writeNext(writer: DiskBlockObjectWriter): Unit = {
writer.write(cur._1._2, cur._2)
cur = if (it.hasNext) it.next() else null
}

def hasNext(): Boolean = cur != null

def nextPartition(): Int = cur._1._1
}
}

Helper methods

private[spark] object WritablePartitionedPairCollection {
/**
* A comparator for (Int, K) pairs that orders them by only their partition ID.
*/
def partitionComparator[K]: Comparator[(Int, K)] = new Comparator[(Int, K)] {
override def compare(a: (Int, K), b: (Int, K)): Int = {
a._1 - b._1
}
}

/**
* A comparator for (Int, K) pairs that orders them both by their partition ID and a key ordering.
*/
def partitionKeyComparator[K](keyComparator: Comparator[K]): Comparator[(Int, K)] = {
new Comparator[(Int, K)] {
override def compare(a: (Int, K), b: (Int, K)): Int = {
val partitionDiff = a._1 - b._1
if (partitionDiff != 0) { // when sorting the in-memory collection, compare partitions first: the partition ID is the primary sort key
partitionDiff
} else {
keyComparator.compare(a._2, b._2) // same partition: order records within the partition by key
}
}
}
}
}

PartitionedAppendOnlyMap

This class extends SizeTrackingAppendOnlyMap[(Int, K), V], where (Int, K) is the key type of the underlying map (the Int being the record's partition ID) and V is its value type.

private[spark] class PartitionedAppendOnlyMap[K, V]
extends SizeTrackingAppendOnlyMap[(Int, K), V] with WritablePartitionedPairCollection[K, V] {

def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
: Iterator[((Int, K), V)] = {
// When sorting the in-memory collection, compare partitions first: the partition ID is the primary sort key;
// within the same partition, records are sorted by key. If the shuffle neither sorts nor aggregates, sort by partition ID only.
val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator)
destructiveSortedIterator(comparator)
}

def insert(partition: Int, key: K, value: V): Unit = {
update((partition, key), value)
}
}

/**
* Returns a sorted iterator over this map. It works by destroying the array layout that backs the hash map:
* the key-value entries are compacted to the front of the underlying array, and the compacted prefix is then sorted.
*/
def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)] = {
destroyed = true
// Move the key-value entries to the front of the underlying array; this destroys the array's usability as a
// hash map, which is why the method is called a "destructive sorted iterator"
var keyIndex, newIndex = 0
while (keyIndex < capacity) {
if (data(2 * keyIndex) != null) {
data(2 * newIndex) = data(2 * keyIndex)
data(2 * newIndex + 1) = data(2 * keyIndex + 1)
newIndex += 1
}
keyIndex += 1
}
assert(curSize == newIndex + (if (haveNullValue) 1 else 0))

// Internally sorted with TimSort
new Sorter(new KVArraySortDataFormat[K, AnyRef]).sort(data, 0, newIndex, keyComparator)

new Iterator[(K, V)] {
var i = 0
var nullValueReady = haveNullValue
def hasNext: Boolean = (i < newIndex || nullValueReady)
def next(): (K, V) = {
if (nullValueReady) {
nullValueReady = false
(null.asInstanceOf[K], nullValue)
} else {
val item = (data(2 * i).asInstanceOf[K], data(2 * i + 1).asInstanceOf[V])
i += 1
item
}
}
}
}

Here is a before/after look at how this "destroys" the underlying array: the live entries are moved to the front of the array.
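
Since the original figure is not reproduced here, the following standalone toy shows the same compaction on a flat Array[AnyRef] with keys at even indices, values at odd indices, and null marking unused hash slots (the layout AppendOnlyMap uses); the sample data is made up:

object DestructiveCompactionSketch {
  def main(args: Array[String]): Unit = {
    // keys at even indices, values at odd indices, null = empty hash slot
    val data: Array[AnyRef] = Array("b", "2", null, null, "a", "1", null, null, "c", "3", null, null)
    val capacity = data.length / 2

    // Same compaction loop as destructiveSortedIterator(): move live (key, value) slots to the front
    var keyIndex, newIndex = 0
    while (keyIndex < capacity) {
      if (data(2 * keyIndex) != null) {
        data(2 * newIndex) = data(2 * keyIndex)
        data(2 * newIndex + 1) = data(2 * keyIndex + 1)
        newIndex += 1
      }
      keyIndex += 1
    }

    // Before: [b, 2, null, null, a, 1, null, null, c, 3, null, null]
    // After compaction the first 2 * newIndex slots hold the live entries, ready to be sorted in place
    println(data.take(2 * newIndex).mkString("[", ", ", "]")) // [b, 2, a, 1, c, 3]
  }
}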

PartitionedPairBuffer

private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64)
extends WritablePartitionedPairCollection[K, V] with SizeTracker {

private var capacity = initialCapacity
private var curSize = 0
private var data = new Array[AnyRef](2 * initialCapacity)


/** Iterate through the data in a given order. For this class this is not really destructive. */
override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
: Iterator[((Int, K), V)] = {
// When sorting the in-memory collection, compare partitions first: the partition ID is the primary sort key;
// within the same partition, records are sorted by key. If the shuffle neither sorts nor aggregates, sort by partition ID only.
val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator)
// Internally sorted with TimSort
new Sorter(new KVArraySortDataFormat[(Int, K), AnyRef]).sort(data, 0, curSize, comparator)
iterator
}

private def iterator(): Iterator[((Int, K), V)] = new Iterator[((Int, K), V)] {
var pos = 0

override def hasNext: Boolean = pos < curSize

override def next(): ((Int, K), V) = {
if (!hasNext) {
throw new NoSuchElementException
}
val pair = (data(2 * pos).asInstanceOf[(Int, K)], data(2 * pos + 1).asInstanceOf[V])
pos += 1
pair
}
}
}

NOTEs

This article is based on Spark 2.4.3.