Spark RPC Internals

Spark's RPC framework is organized in two layers: a dispatch layer and a transport layer. This article walks through the Spark RPC framework in detail.

Class Diagram

Message Dispatching

The diagram above illustrates how messages are handled inside the Dispatcher:

  1. On an endpoint-registration request, the RpcEndpoint is wrapped into an EndpointData and registered in the Dispatcher's map (keyed by endpoint name); the EndpointData is also offered to the receivers queue so that the OnStart message seeded in its Inbox gets processed;
  2. When a message is posted, the Dispatcher looks up the EndpointData for the given endpointName in the registry, appends the InboxMessage to its Inbox, and again offers the EndpointData to the receivers queue;
  3. The MessageLoop threads consume from the receivers queue; each MessageLoop takes one EndpointData at a time, polls one message from its Inbox, and invokes the matching RpcEndpoint method based on the message type.

RpcEndpoint

First, what is an endpoint? An RpcEndpoint is the final method carrier of an RPC call: it defines the handler methods used to respond to the different InboxMessage types.

RpcEndpointRef

RpcEndpointRef is a reference to an RpcEndpoint: an entity bound to an address and a concrete name. Most importantly, it defines the interfaces for sending messages (a minimal endpoint sketch follows the list):

  • Messages sent via send() are ultimately handled by RpcEndpoint.receive();
  • Messages sent via ask()/askSync() are ultimately handled by RpcEndpoint.receiveAndReply();
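To make the mapping concrete, here is a minimal sketch of a custom endpoint. EchoEndpoint is a made-up name, and RpcEnv/RpcEndpoint are private[spark], so code like this only compiles inside Spark's own source tree or tests:

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
  // Invoked for messages sent with RpcEndpointRef.send() (fire-and-forget).
  override def receive: PartialFunction[Any, Unit] = {
    case s: String => println(s"one-way: $s")
  }

  // Invoked for messages sent with ask()/askSync(); must reply via the context.
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case s: String => context.reply(s"echo: $s")
  }
}

// Usage sketch:
// val conf = new SparkConf()
// val rpcEnv = RpcEnv.create("demo", "localhost", 0, conf, new SecurityManager(conf))
// val ref = rpcEnv.setupEndpoint("echo", new EchoEndpoint(rpcEnv))
// ref.send("hi")                        // handled by receive()
// println(ref.askSync[String]("hi"))    // handled by receiveAndReply()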

NettyRpcEndpointRef

(TBD)

Inbox

private[netty] class Inbox(
    val endpointRef: NettyRpcEndpointRef,
    val endpoint: RpcEndpoint)
  extends Logging {

  inbox =>  // Give this an alias so we can use it more clearly in closures.

  @GuardedBy("this")
  protected val messages = new java.util.LinkedList[InboxMessage]()

  // OnStart should be the first message to process
  inbox.synchronized {
    messages.add(OnStart)  // an OnStart message is added by default when the Inbox is instantiated
  }
}
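Messages enter the Inbox through its post() method, which simply appends to the list under the lock (abridged from the Spark 2.4.3 source; onDrop logs the dropped message):

def post(message: InboxMessage): Unit = inbox.synchronized {
  if (stopped) {
    // We already put "OnStop" into "messages", so we should drop further messages
    onDrop(message)
  } else {
    messages.add(message)
  }
}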

process() defines which endpoint method to invoke for each type of message. Because the Inbox inside every newly registered EndpointData is seeded with an OnStart message, endpoint.onStart() is guaranteed to be processed first. Note also the enableConcurrent/numActiveThreads handshake below: only after OnStart, and only for endpoints that are not ThreadSafeRpcEndpoint, is enableConcurrent set to true, allowing several MessageLoop threads to drain the same Inbox concurrently; a ThreadSafeRpcEndpoint keeps it false and is therefore always served by a single thread at a time.

def process(dispatcher: Dispatcher): Unit = {
  var message: InboxMessage = null
  inbox.synchronized {
    if (!enableConcurrent && numActiveThreads != 0) {
      return
    }
    message = messages.poll()
    if (message != null) {
      numActiveThreads += 1
    } else {
      return
    }
  }
  while (true) {
    safelyCall(endpoint) {
      message match {
        case RpcMessage(_sender, content, context) =>
          try {
            endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg =>
              throw new SparkException(s"Unsupported message $message from ${_sender}")
            })
          } catch {
            case e: Throwable =>
              context.sendFailure(e)
              // Throw the exception -- this exception will be caught by the safelyCall function.
              // The endpoint's onError function will be called.
              throw e
          }

        case OneWayMessage(_sender, content) =>
          endpoint.receive.applyOrElse[Any, Unit](content, { msg =>
            throw new SparkException(s"Unsupported message $message from ${_sender}")
          })

        case OnStart =>
          endpoint.onStart()
          if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
            inbox.synchronized {
              if (!stopped) {
                enableConcurrent = true
              }
            }
          }

        case OnStop =>
          val activeThreads = inbox.synchronized { inbox.numActiveThreads }
          assert(activeThreads == 1,
            s"There should be only a single active thread but found $activeThreads threads.")
          dispatcher.removeRpcEndpointRef(endpoint)
          endpoint.onStop()
          assert(isEmpty, "OnStop should be the last message")

        case RemoteProcessConnected(remoteAddress) =>
          endpoint.onConnected(remoteAddress)

        case RemoteProcessDisconnected(remoteAddress) =>
          endpoint.onDisconnected(remoteAddress)

        case RemoteProcessConnectionError(cause, remoteAddress) =>
          endpoint.onNetworkError(cause, remoteAddress)
      }
    }

    inbox.synchronized {
      // "enableConcurrent" will be set to false after `onStop` is called, so we should check it
      // every time.
      if (!enableConcurrent && numActiveThreads != 1) {
        // If we are not the only one worker, exit
        numActiveThreads -= 1
        return
      }
      message = messages.poll()
      if (message == null) {
        numActiveThreads -= 1
        return
      }
    }
  }
}

EndpointData
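EndpointData is a thin wrapper inside Dispatcher that ties together a name, the endpoint, its ref, and a freshly created Inbox (as in the Spark 2.4.3 source):

private class EndpointData(
    val name: String,
    val endpoint: RpcEndpoint,
    val ref: NettyRpcEndpointRef) {
  val inbox = new Inbox(ref, endpoint)
}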

Dispatcher

private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extends Logging {

  private val endpoints: ConcurrentMap[String, EndpointData] =
    new ConcurrentHashMap[String, EndpointData]
  private val endpointRefs: ConcurrentMap[RpcEndpoint, RpcEndpointRef] =
    new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]

  // Track the receivers whose inboxes may contain messages.
  private val receivers = new LinkedBlockingQueue[EndpointData]

  /** Thread pool used for dispatching messages. */
  private val threadpool: ThreadPoolExecutor = {
    val availableCores =
      if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
    val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads",
      math.max(2, availableCores))
    val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
    for (i <- 0 until numThreads) {
      pool.execute(new MessageLoop)
    }
    pool
  }
}

Here:

  • endpoints holds the registered EndpointData entries, keyed by endpoint name;
  • endpointRefs holds the mapping from each registered RpcEndpoint to its RpcEndpointRef;
  • receivers is a blocking queue that tracks the EndpointData instances whose inboxes may still contain unprocessed messages;
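The dispatcher pool size defaults to max(2, available cores) but can be overridden via spark.rpc.netty.dispatcher.numThreads, as read by the code above; a brief sketch (the value 8 is purely illustrative):

import org.apache.spark.SparkConf

// Overrides the size of the "dispatcher-event-loop" thread pool.
val conf = new SparkConf().set("spark.rpc.netty.dispatcher.numThreads", "8")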

Registering an Endpoint

def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
  val addr = RpcEndpointAddress(nettyEnv.address, name)
  val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
  synchronized {
    if (stopped) {
      throw new IllegalStateException("RpcEnv has been stopped")
    }
    if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) {
      throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
    }
    val data = endpoints.get(name)
    endpointRefs.put(data.endpoint, data.ref)
    receivers.offer(data)  // for the OnStart message
  }
  endpointRef
}
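Its counterpart, unregisterRpcEndpoint, mirrors this by stopping the Inbox (which enqueues OnStop) and offering the data one last time (abridged from the Spark 2.4.3 source):

// Should be idempotent
private def unregisterRpcEndpoint(name: String): Unit = {
  val data = endpoints.remove(name)
  if (data != null) {
    data.inbox.stop()
    receivers.offer(data)  // for the OnStop message
  }
  // `endpointRefs` is not cleaned here; the Inbox removes the entry via
  // removeRpcEndpointRef while processing OnStop.
}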

Posting Messages

private def postMessage(
    endpointName: String,
    message: InboxMessage,
    callbackIfStopped: (Exception) => Unit): Unit = {
  val error = synchronized {
    val data = endpoints.get(endpointName)
    if (stopped) {
      Some(new RpcEnvStoppedException())
    } else if (data == null) {
      Some(new SparkException(s"Could not find $endpointName."))
    } else {
      data.inbox.post(message)  // add the message to the target EndpointData's Inbox
      receivers.offer(data)     // also enqueue the EndpointData so a MessageLoop thread will pick it up
      None
    }
  }
  // We don't need to call `onStop` in the `synchronized` block
  error.foreach(callbackIfStopped)
}
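postMessage itself is private; it is reached through thin wrappers such as postLocalMessage, postRemoteMessage, and postOneWayMessage. The one-way path, slightly simplified from the Spark 2.4.3 source:

/** Posts a one-way message. */
def postOneWayMessage(message: RequestMessage): Unit = {
  postMessage(message.receiver.name,
    OneWayMessage(message.senderAddress, message.content),
    (e) => throw new RpcEnvStoppedException())
}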

Processing EndpointData Messages

/** Message loop used for dispatching messages. */
private class MessageLoop extends Runnable {
  override def run(): Unit = {
    try {
      while (true) {
        try {
          val data = receivers.take()
          if (data == PoisonPill) {
            // Put PoisonPill back so that other MessageLoops can see it.
            receivers.offer(PoisonPill)
            return
          }
          data.inbox.process(Dispatcher.this)
        } catch {
          case NonFatal(e) => logError(e.getMessage, e)
        }
      }
    } catch {
      ... ...
    }
  }
}
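PoisonPill is a sentinel EndpointData defined in Dispatcher; when stop() offers it to receivers, each MessageLoop re-offers it before returning, so every loop thread eventually sees it and exits. In the Spark 2.4.3 source it is simply:

/** A poison endpoint that indicates MessageLoop should exit its message loop. */
private val PoisonPill = new EndpointData(null, null, null)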

Message Transport

(TBD)

TODO

Notes

This article is based on the Spark 2.4.3 source code.