[SPARK-17714][CORE][TEST-MAVEN][TEST-HADOOP2.6] Avoid using ExecutorClassLoader to load Netty generated classes

## What changes were proposed in this pull request?

Netty's `MessageToMessageEncoder` uses Javassist (see Netty's `JavassistTypeParameterMatcherGenerator`) to generate a matcher class, and the implementation calls `Class.forName` to check whether this class has already been generated. If `MessageEncoder` or `MessageDecoder` is created inside `ExecutorClassLoader.findClass`, it causes a `ClassCircularityError`: loading the Netty-generated class calls `ExecutorClassLoader.findClass` to look the class up, and `ExecutorClassLoader` tries to load it over RPC, which triggers another attempt to load the not-yet-generated matcher class. The JVM reports `ClassCircularityError` to prevent such infinite recursion.
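
To make the failure mode concrete, here is a minimal, hypothetical sketch (not Spark's code; the class and names are invented) of a class loader whose `findClass` re-enters the load of the very class being resolved, which is the shape of the `ExecutorClassLoader`/Javassist interaction described above:

```java
// Hypothetical demo of the re-entrant load. The JVM tracks in-progress loads
// per (thread, class loader, class name), so the nested Class.forName below is
// expected to fail fast with ClassCircularityError instead of recursing forever.
public class CircularLoaderDemo extends ClassLoader {
  @Override
  protected Class<?> findClass(String name) throws ClassNotFoundException {
    // Stands in for ExecutorClassLoader.findClass: resolving `name` runs code
    // that calls Class.forName for the very class currently being loaded.
    return Class.forName(name, false, this);
  }

  public static void main(String[] args) {
    try {
      new CircularLoaderDemo().loadClass("com.example.GeneratedMatcher");
    } catch (Throwable t) {
      System.out.println(t); // expected: java.lang.ClassCircularityError
    }
  }
}
```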

##### Why it only happens in Maven builds

This happens because Maven and SBT build different class loader trees. The Maven build sets a `URLClassLoader` as the current context class loader to run the tests, which exposes the issue. The class loader tree looks like this:

```
bootstrap class loader ------ ... ----- REPL class loader ---- ExecutorClassLoader
|
|
URLClassLoader
```

The SBT build uses the bootstrap class loader directly, and `ReplSuite.test("propagation of local properties")` is the first test in `ReplSuite`, which happens to load `io/netty/util/internal/__matchers__/org/apache/spark/network/protocol/MessageMatcher` into the bootstrap class loader (note: in the Maven build it is loaded into the `URLClassLoader`, so it cannot be found by `ExecutorClassLoader`). The issue can be reproduced in SBT as well. Here are the reproduction steps:
- Enable `hadoop.caller.context.enabled`.
- Replace `Class.forName` with `Utils.classForName` in `object CallerContext` (see the sketch after this list).
- Ignore `ReplSuite.test("propagation of local properties")`.
- Run `ReplSuite` using SBT.
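
For step 2, the relevant difference is which loader resolves the class. A simplified sketch (assuming, as in Spark's `Utils`, that `Utils.classForName` resolves via the thread context class loader, which on an executor running REPL code is the `ExecutorClassLoader`; the method names here are illustrative, not Spark APIs):

```java
// Simplified sketch; plainForName and sparkStyleClassForName are
// illustrative names, not Spark APIs.
public final class ClassForNameSketch {
  static Class<?> plainForName(String name) throws ClassNotFoundException {
    // Resolves against the caller's defining class loader.
    return Class.forName(name);
  }

  static Class<?> sparkStyleClassForName(String name) throws ClassNotFoundException {
    // Resolves against the context class loader, so a class that is missing
    // there (e.g. Netty's generated matcher) is searched through
    // ExecutorClassLoader.findClass -- the re-entrant path described above.
    return Class.forName(name, true, Thread.currentThread().getContextClassLoader());
  }
}
```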

This PR makes `MessageEncoder` and `MessageDecoder` singletons and ensures they are created before the switch to `ExecutorClassLoader`. `TransportContext` is created when the `RpcEnv` is created, which happens before `ExecutorClassLoader` is created.

## How was this patch tested?

Jenkins

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #16859 from zsxwing/SPARK-17714.
6 changed files with 38 additions and 27 deletions


`TransportContext.java` (`org.apache.spark.network`):

```diff
@@ -62,8 +62,20 @@ public class TransportContext {
   private final RpcHandler rpcHandler;
   private final boolean closeIdleConnections;
 
-  private final MessageEncoder encoder;
-  private final MessageDecoder decoder;
+  /**
+   * Force to create MessageEncoder and MessageDecoder so that we can make sure they will be
+   * created before switching the current context class loader to ExecutorClassLoader.
+   *
+   * Netty's MessageToMessageEncoder uses Javassist to generate a matcher class and the
+   * implementation calls "Class.forName" to check if this class is already generated. If the
+   * following two objects are created in "ExecutorClassLoader.findClass", it will cause
+   * "ClassCircularityError". This is because loading this Netty generated class will call
+   * "ExecutorClassLoader.findClass" to search this class, and "ExecutorClassLoader" will try to
+   * use RPC to load it, triggering a load of the not-yet-generated matcher class again. The JVM
+   * reports "ClassCircularityError" to prevent such infinite recursion. (See SPARK-17714)
+   */
+  private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE;
+  private static final MessageDecoder DECODER = MessageDecoder.INSTANCE;
 
   public TransportContext(TransportConf conf, RpcHandler rpcHandler) {
     this(conf, rpcHandler, false);
@@ -75,8 +87,6 @@ public class TransportContext {
       boolean closeIdleConnections) {
     this.conf = conf;
     this.rpcHandler = rpcHandler;
-    this.encoder = new MessageEncoder();
-    this.decoder = new MessageDecoder();
     this.closeIdleConnections = closeIdleConnections;
   }
 
@@ -135,9 +145,9 @@ public class TransportContext {
     try {
       TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
       channel.pipeline()
-        .addLast("encoder", encoder)
+        .addLast("encoder", ENCODER)
         .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
-        .addLast("decoder", decoder)
+        .addLast("decoder", DECODER)
         .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
         // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
         // would require more logic to guarantee if this were not part of the same event loop.
```

`MessageDecoder.java` (`org.apache.spark.network.protocol`):

```diff
@@ -35,6 +35,10 @@ public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> {
 
   private static final Logger logger = LoggerFactory.getLogger(MessageDecoder.class);
 
+  public static final MessageDecoder INSTANCE = new MessageDecoder();
+
+  private MessageDecoder() {}
+
   @Override
   public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
     Message.Type msgType = Message.Type.decode(in);
```

`MessageEncoder.java` (`org.apache.spark.network.protocol`):

```diff
@@ -35,6 +35,10 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {
 
   private static final Logger logger = LoggerFactory.getLogger(MessageEncoder.class);
 
+  public static final MessageEncoder INSTANCE = new MessageEncoder();
+
+  private MessageEncoder() {}
+
   /***
    * Encodes a Message by invoking its encode() method. For non-data messages, we will add one
    * ByteBuf to 'out' containing the total frame length, the message type, and the message itself.
```

`TransportChannelHandler.java` (`org.apache.spark.network.server`):

```diff
@@ -18,7 +18,7 @@
 package org.apache.spark.network.server;
 
 import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.SimpleChannelInboundHandler;
+import io.netty.channel.ChannelInboundHandlerAdapter;
 import io.netty.handler.timeout.IdleState;
 import io.netty.handler.timeout.IdleStateEvent;
 import org.slf4j.Logger;
@@ -26,7 +26,6 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.client.TransportResponseHandler;
-import org.apache.spark.network.protocol.Message;
 import org.apache.spark.network.protocol.RequestMessage;
 import org.apache.spark.network.protocol.ResponseMessage;
 import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
@@ -48,7 +47,7 @@ import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
  * on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not
  * timeout if the client is continuously sending but getting no responses, for simplicity.
  */
-public class TransportChannelHandler extends SimpleChannelInboundHandler<Message> {
+public class TransportChannelHandler extends ChannelInboundHandlerAdapter {
   private static final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class);
 
   private final TransportClient client;
@@ -114,11 +113,13 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message
   }
 
   @Override
-  public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception {
+  public void channelRead(ChannelHandlerContext ctx, Object request) throws Exception {
     if (request instanceof RequestMessage) {
       requestHandler.handle((RequestMessage) request);
-    } else {
+    } else if (request instanceof ResponseMessage) {
       responseHandler.handle((ResponseMessage) request);
+    } else {
+      ctx.fireChannelRead(request);
     }
   }
```
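
The base-class swap above is not spelled out in the PR description, but it fits the same theme: `SimpleChannelInboundHandler`'s constructor resolves its type parameter through Netty's `TypeParameterMatcher`, which is the same Javassist matcher-generation path this PR avoids, while `ChannelInboundHandlerAdapter` needs no matcher and dispatches explicitly. A minimal illustration of where that generation is triggered (hypothetical handler, standard Netty 4.x API):

```java
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;

// Merely *constructing* this handler runs the superclass constructor, which
// calls TypeParameterMatcher.find(...) to resolve <String>; with Javassist on
// the classpath, Netty may generate and load an
// io.netty.util.internal.__matchers__... class at that moment via Class.forName.
final class StringHandler extends SimpleChannelInboundHandler<String> {
  @Override
  protected void channelRead0(ChannelHandlerContext ctx, String msg) {
    // no-op: the point is the construction-time matcher resolution above
  }
}
```

Since `TransportChannelHandler` instances are created per connection, possibly while `ExecutorClassLoader` is the context class loader, moving off `SimpleChannelInboundHandler` keeps handler construction away from that path.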

`ProtocolSuite.java` (test, `org.apache.spark.network`):

```diff
@@ -49,11 +49,11 @@ import org.apache.spark.network.util.NettyUtils;
 public class ProtocolSuite {
   private void testServerToClient(Message msg) {
     EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(),
-      new MessageEncoder());
+      MessageEncoder.INSTANCE);
     serverChannel.writeOutbound(msg);
 
     EmbeddedChannel clientChannel = new EmbeddedChannel(
-      NettyUtils.createFrameDecoder(), new MessageDecoder());
+      NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE);
 
     while (!serverChannel.outboundMessages().isEmpty()) {
       clientChannel.writeInbound(serverChannel.readOutbound());
@@ -65,11 +65,11 @@ public class ProtocolSuite {
   private void testClientToServer(Message msg) {
     EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(),
-      new MessageEncoder());
+      MessageEncoder.INSTANCE);
     clientChannel.writeOutbound(msg);
 
     EmbeddedChannel serverChannel = new EmbeddedChannel(
-      NettyUtils.createFrameDecoder(), new MessageDecoder());
+      NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE);
 
     while (!clientChannel.outboundMessages().isEmpty()) {
       serverChannel.writeInbound(clientChannel.readOutbound());
```

`Utils.scala` (`org.apache.spark.util`):

```diff
@@ -2608,12 +2608,8 @@ private[util] object CallerContext extends Logging {
   val callerContextSupported: Boolean = {
     SparkHadoopUtil.get.conf.getBoolean("hadoop.caller.context.enabled", false) && {
       try {
-        // `Utils.classForName` will make `ReplSuite` fail with `ClassCircularityError` in
-        // master Maven build, so do not use it before resolving SPARK-17714.
-        // scalastyle:off classforname
-        Class.forName("org.apache.hadoop.ipc.CallerContext")
-        Class.forName("org.apache.hadoop.ipc.CallerContext$Builder")
-        // scalastyle:on classforname
+        Utils.classForName("org.apache.hadoop.ipc.CallerContext")
+        Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder")
         true
       } catch {
         case _: ClassNotFoundException =>
@@ -2688,12 +2684,8 @@ private[spark] class CallerContext(
   def setCurrentContext(): Unit = {
     if (CallerContext.callerContextSupported) {
       try {
-        // `Utils.classForName` will make `ReplSuite` fail with `ClassCircularityError` in
-        // master Maven build, so do not use it before resolving SPARK-17714.
-        // scalastyle:off classforname
-        val callerContext = Class.forName("org.apache.hadoop.ipc.CallerContext")
-        val builder = Class.forName("org.apache.hadoop.ipc.CallerContext$Builder")
-        // scalastyle:on classforname
+        val callerContext = Utils.classForName("org.apache.hadoop.ipc.CallerContext")
+        val builder = Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder")
         val builderInst = builder.getConstructor(classOf[String]).newInstance(context)
         val hdfsContext = builder.getMethod("build").invoke(builderInst)
         callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext)
```