Spark Shuffle Reader

BlockStoreShuffleReader

BlockStoreShuffleReader.read() drives the reduce side of a shuffle: it builds a ShuffleBlockFetcherIterator over the map output blocks for this task's partition range, deserializes each fetched stream into key/value records, tracks shuffle read metrics, and then applies the dependency's aggregator and key ordering when they are defined.

override def read(): Iterator[Product2[K, C]] = {
  val wrappedStreams = new ShuffleBlockFetcherIterator(
    context,
    blockManager.shuffleClient,
    blockManager,
    mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
    serializerManager.wrapStream,
    // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
    SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
    SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
    SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
    SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
    SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))

  val serializerInstance = dep.serializer.newInstance()

  // Create a key/value iterator for each stream
  val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
    // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
    // NextIterator. The NextIterator makes sure that close() is called on the
    // underlying InputStream when all records have been read.
    serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
  }

  // Update the context task metrics for each record read.
  val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
  val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
    recordIter.map { record =>
      readMetrics.incRecordsRead(1)
      record
    },
    context.taskMetrics().mergeShuffleReadMetrics())

  // An interruptible iterator must be used here in order to support task cancellation
  val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

  val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
    if (dep.mapSideCombine) {
      // We are reading values that are already combined
      val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
      dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
    } else {
      // We don't know the value type, but also don't care -- the dependency *should*
      // have made sure its compatible w/ this aggregator, which will convert the value
      // type to the combined type C
      val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
      dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
    }
  } else {
    interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
  }

  // Sort the output if there is a sort ordering defined.
  val resultIter = dep.keyOrdering match {
    case Some(keyOrd: Ordering[K]) =>
      // Create an ExternalSorter to sort the data.
      val sorter =
        new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
      sorter.insertAll(aggregatedIter)
      context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
      context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
      context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
      // Use completion callback to stop sorter if task was finished/cancelled.
      context.addTaskCompletionListener[Unit](_ => {
        sorter.stop()
      })
      CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
    case None =>
      aggregatedIter
  }

  resultIter match {
    case _: InterruptibleIterator[Product2[K, C]] => resultIter
    case _ =>
      // Use another interruptible iterator here to support task cancellation as aggregator
      // or(and) sorter may have consumed previous interruptible iterator.
      new InterruptibleIterator[Product2[K, C]](context, resultIter)
  }
}
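
For context, below is a minimal sketch of a driver program that exercises both branches of read(). The object name ShuffleReadDemo and the local-mode setup are illustrative only: reduceByKey defines dep.aggregator with mapSideCombine = true, so its reduce side takes the combineCombinersByKey branch, while sortByKey defines dep.keyOrdering, so its reader feeds the records through an ExternalSorter.

import org.apache.spark.{SparkConf, SparkContext}

object ShuffleReadDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("shuffle-read-demo")
      // Same property read() passes to ShuffleBlockFetcherIterator (default 48m).
      .set("spark.reducer.maxSizeInFlight", "48m")
    val sc = new SparkContext(conf)

    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3), ("c", 4)), numSlices = 2)

    // dep.aggregator defined, mapSideCombine = true -> combineCombinersByKey branch.
    val reduced = pairs.reduceByKey(_ + _)

    // dep.keyOrdering defined -> ExternalSorter branch in read().
    val sorted = reduced.sortByKey()

    sorted.collect().foreach(println)
    sc.stop()
  }
}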

ShuffleBlockFetcherIterator

Inside ShuffleBlockFetcherIterator, splitLocalRemoteBlocks() separates the blocks served by the local BlockManager from remote blocks, and packs the remote blocks into FetchRequests so that the amount of data in flight stays bounded by maxBytesInFlight.

private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
  // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
  // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
  // nodes, rather than blocking on reading output from one node.
  val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
  logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize
    + ", maxBlocksInFlightPerAddress: " + maxBlocksInFlightPerAddress)

  // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
  // at most maxBytesInFlight in order to limit the amount of data in flight.
  val remoteRequests = new ArrayBuffer[FetchRequest]

  for ((address, blockInfos) <- blocksByAddress) {
    if (address.executorId == blockManager.blockManagerId.executorId) {
      blockInfos.find(_._2 <= 0) match {
        case Some((blockId, size)) if size < 0 =>
          throw new BlockException(blockId, "Negative block size " + size)
        case Some((blockId, size)) if size == 0 =>
          throw new BlockException(blockId, "Zero-sized blocks should be excluded.")
        case None => // do nothing.
      }
      localBlocks ++= blockInfos.map(_._1)
      numBlocksToFetch += localBlocks.size
    } else {
      val iterator = blockInfos.iterator
      var curRequestSize = 0L
      var curBlocks = new ArrayBuffer[(BlockId, Long)]
      while (iterator.hasNext) {
        val (blockId, size) = iterator.next()
        if (size < 0) {
          throw new BlockException(blockId, "Negative block size " + size)
        } else if (size == 0) {
          throw new BlockException(blockId, "Zero-sized blocks should be excluded.")
        } else {
          curBlocks += ((blockId, size))
          remoteBlocks += blockId
          numBlocksToFetch += 1
          curRequestSize += size
        }
        if (curRequestSize >= targetRequestSize ||
            curBlocks.size >= maxBlocksInFlightPerAddress) {
          // Add this FetchRequest
          remoteRequests += new FetchRequest(address, curBlocks)
          logDebug(s"Creating fetch request of $curRequestSize at $address "
            + s"with ${curBlocks.size} blocks")
          curBlocks = new ArrayBuffer[(BlockId, Long)]
          curRequestSize = 0
        }
      }
      // Add in the final request
      if (curBlocks.nonEmpty) {
        remoteRequests += new FetchRequest(address, curBlocks)
      }
    }
  }
  logInfo(s"Getting $numBlocksToFetch non-empty blocks including ${localBlocks.size}" +
    s" local blocks and ${remoteBlocks.size} remote blocks")
  remoteRequests
}
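
To make the packing rule concrete, here is a standalone sketch (not Spark code; the object and method names are hypothetical, and the zero-size checks are omitted) of how blocks from a single remote address are grouped: blocks accumulate into the current request until its size reaches targetRequestSize = maxBytesInFlight / 5 or its block count reaches maxBlocksInFlightPerAddress, at which point a new request is started.

import scala.collection.mutable.ArrayBuffer

// Hypothetical standalone sketch of the grouping loop in splitLocalRemoteBlocks().
object FetchRequestPackingSketch {

  // Groups (blockId, size) pairs from one address into requests, mirroring the loop above:
  // close the current request once it reaches the size target or the per-address block
  // limit, and flush any remainder at the end.
  def pack(
      blocks: Seq[(String, Long)],
      maxBytesInFlight: Long,
      maxBlocksInFlightPerAddress: Int): Seq[Seq[(String, Long)]] = {
    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
    val requests = ArrayBuffer.empty[Seq[(String, Long)]]
    var curBlocks = ArrayBuffer.empty[(String, Long)]
    var curRequestSize = 0L
    for ((blockId, size) <- blocks) {
      curBlocks += ((blockId, size))
      curRequestSize += size
      if (curRequestSize >= targetRequestSize ||
          curBlocks.size >= maxBlocksInFlightPerAddress) {
        requests += curBlocks.toSeq
        curBlocks = ArrayBuffer.empty[(String, Long)]
        curRequestSize = 0L
      }
    }
    if (curBlocks.nonEmpty) {
      requests += curBlocks.toSeq
    }
    requests.toSeq
  }

  def main(args: Array[String]): Unit = {
    // With the default 48 MB in flight, each request targets about 9.6 MB.
    val maxBytesInFlight = 48L * 1024 * 1024
    // Ten 4 MB blocks: each request closes after 3 blocks (12 MB >= 9.6 MB),
    // so we expect 3 + 3 + 3 + 1 = 4 requests.
    val blocks = (0 until 10).map(i => (s"shuffle_0_${i}_0", 4L * 1024 * 1024))
    val requests = pack(blocks, maxBytesInFlight, maxBlocksInFlightPerAddress = Int.MaxValue)
    requests.zipWithIndex.foreach { case (req, i) =>
      println(s"request $i: ${req.map(_._1).mkString(", ")}")
    }
  }
}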

NOTES

This article is based on Spark 2.4.3.