[SPARK-17714][CORE][TEST-MAVEN][TEST-HADOOP2.6] Avoid using ExecutorClassLoader to load Netty generated classes
## What changes were proposed in this pull request?
Netty's `MessageToMessageEncoder` uses [Javassist](91a0bdc17a/common/src/main/java/io/netty/util/internal/JavassistTypeParameterMatcherGenerator.java (L62)
) to generate a matcher class and the implementation calls `Class.forName` to check if this class is already generated. If `MessageEncoder` or `MessageDecoder` is created in `ExecutorClassLoader.findClass`, it will cause `ClassCircularityError`. This is because loading this Netty generated class will call `ExecutorClassLoader.findClass` to search this class, and `ExecutorClassLoader` will try to use RPC to load it and cause to load the non-exist matcher class again. JVM will report `ClassCircularityError` to prevent such infinite recursion.
##### Why it only happens in Maven builds
It's because Maven and SBT have different class loader tree. The Maven build will set a URLClassLoader as the current context class loader to run the tests and expose this issue. The class loader tree is as following:
```
bootstrap class loader ------ ... ----- REPL class loader ---- ExecutorClassLoader
|
|
URLClasssLoader
```
The SBT build uses the bootstrap class loader directly and `ReplSuite.test("propagation of local properties")` is the first test in ReplSuite, which happens to load `io/netty/util/internal/__matchers__/org/apache/spark/network/protocol/MessageMatcher` into the bootstrap class loader (Note: in maven build, it's loaded into URLClasssLoader so it cannot be found in ExecutorClassLoader). This issue can be reproduced in SBT as well. Here are the produce steps:
- Enable `hadoop.caller.context.enabled`.
- Replace `Class.forName` with `Utils.classForName` in `object CallerContext`.
- Ignore `ReplSuite.test("propagation of local properties")`.
- Run `ReplSuite` using SBT.
This PR just creates a singleton MessageEncoder and MessageDecoder and makes sure they are created before switching to ExecutorClassLoader. TransportContext will be created when creating RpcEnv and that happens before creating ExecutorClassLoader.
## How was this patch tested?
Jenkins
Author: Shixiong Zhu <shixiong@databricks.com>
Closes #16859 from zsxwing/SPARK-17714.
This commit is contained in:
parent
3dbff9be06
commit
905fdf0c24
|
@ -62,8 +62,20 @@ public class TransportContext {
|
||||||
private final RpcHandler rpcHandler;
|
private final RpcHandler rpcHandler;
|
||||||
private final boolean closeIdleConnections;
|
private final boolean closeIdleConnections;
|
||||||
|
|
||||||
private final MessageEncoder encoder;
|
/**
|
||||||
private final MessageDecoder decoder;
|
* Force to create MessageEncoder and MessageDecoder so that we can make sure they will be created
|
||||||
|
* before switching the current context class loader to ExecutorClassLoader.
|
||||||
|
*
|
||||||
|
* Netty's MessageToMessageEncoder uses Javassist to generate a matcher class and the
|
||||||
|
* implementation calls "Class.forName" to check if this calls is already generated. If the
|
||||||
|
* following two objects are created in "ExecutorClassLoader.findClass", it will cause
|
||||||
|
* "ClassCircularityError". This is because loading this Netty generated class will call
|
||||||
|
* "ExecutorClassLoader.findClass" to search this class, and "ExecutorClassLoader" will try to use
|
||||||
|
* RPC to load it and cause to load the non-exist matcher class again. JVM will report
|
||||||
|
* `ClassCircularityError` to prevent such infinite recursion. (See SPARK-17714)
|
||||||
|
*/
|
||||||
|
private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE;
|
||||||
|
private static final MessageDecoder DECODER = MessageDecoder.INSTANCE;
|
||||||
|
|
||||||
public TransportContext(TransportConf conf, RpcHandler rpcHandler) {
|
public TransportContext(TransportConf conf, RpcHandler rpcHandler) {
|
||||||
this(conf, rpcHandler, false);
|
this(conf, rpcHandler, false);
|
||||||
|
@ -75,8 +87,6 @@ public class TransportContext {
|
||||||
boolean closeIdleConnections) {
|
boolean closeIdleConnections) {
|
||||||
this.conf = conf;
|
this.conf = conf;
|
||||||
this.rpcHandler = rpcHandler;
|
this.rpcHandler = rpcHandler;
|
||||||
this.encoder = new MessageEncoder();
|
|
||||||
this.decoder = new MessageDecoder();
|
|
||||||
this.closeIdleConnections = closeIdleConnections;
|
this.closeIdleConnections = closeIdleConnections;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -135,9 +145,9 @@ public class TransportContext {
|
||||||
try {
|
try {
|
||||||
TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
|
TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
|
||||||
channel.pipeline()
|
channel.pipeline()
|
||||||
.addLast("encoder", encoder)
|
.addLast("encoder", ENCODER)
|
||||||
.addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
|
.addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
|
||||||
.addLast("decoder", decoder)
|
.addLast("decoder", DECODER)
|
||||||
.addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
|
.addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
|
||||||
// NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
|
// NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
|
||||||
// would require more logic to guarantee if this were not part of the same event loop.
|
// would require more logic to guarantee if this were not part of the same event loop.
|
||||||
|
|
|
@ -35,6 +35,10 @@ public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> {
|
||||||
|
|
||||||
private static final Logger logger = LoggerFactory.getLogger(MessageDecoder.class);
|
private static final Logger logger = LoggerFactory.getLogger(MessageDecoder.class);
|
||||||
|
|
||||||
|
public static final MessageDecoder INSTANCE = new MessageDecoder();
|
||||||
|
|
||||||
|
private MessageDecoder() {}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
|
public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
|
||||||
Message.Type msgType = Message.Type.decode(in);
|
Message.Type msgType = Message.Type.decode(in);
|
||||||
|
|
|
@ -35,6 +35,10 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {
|
||||||
|
|
||||||
private static final Logger logger = LoggerFactory.getLogger(MessageEncoder.class);
|
private static final Logger logger = LoggerFactory.getLogger(MessageEncoder.class);
|
||||||
|
|
||||||
|
public static final MessageEncoder INSTANCE = new MessageEncoder();
|
||||||
|
|
||||||
|
private MessageEncoder() {}
|
||||||
|
|
||||||
/***
|
/***
|
||||||
* Encodes a Message by invoking its encode() method. For non-data messages, we will add one
|
* Encodes a Message by invoking its encode() method. For non-data messages, we will add one
|
||||||
* ByteBuf to 'out' containing the total frame length, the message type, and the message itself.
|
* ByteBuf to 'out' containing the total frame length, the message type, and the message itself.
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
package org.apache.spark.network.server;
|
package org.apache.spark.network.server;
|
||||||
|
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
import io.netty.channel.SimpleChannelInboundHandler;
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
import io.netty.handler.timeout.IdleState;
|
import io.netty.handler.timeout.IdleState;
|
||||||
import io.netty.handler.timeout.IdleStateEvent;
|
import io.netty.handler.timeout.IdleStateEvent;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
|
@ -26,7 +26,6 @@ import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import org.apache.spark.network.client.TransportClient;
|
import org.apache.spark.network.client.TransportClient;
|
||||||
import org.apache.spark.network.client.TransportResponseHandler;
|
import org.apache.spark.network.client.TransportResponseHandler;
|
||||||
import org.apache.spark.network.protocol.Message;
|
|
||||||
import org.apache.spark.network.protocol.RequestMessage;
|
import org.apache.spark.network.protocol.RequestMessage;
|
||||||
import org.apache.spark.network.protocol.ResponseMessage;
|
import org.apache.spark.network.protocol.ResponseMessage;
|
||||||
import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
|
import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
|
||||||
|
@ -48,7 +47,7 @@ import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
|
||||||
* on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not
|
* on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not
|
||||||
* timeout if the client is continuously sending but getting no responses, for simplicity.
|
* timeout if the client is continuously sending but getting no responses, for simplicity.
|
||||||
*/
|
*/
|
||||||
public class TransportChannelHandler extends SimpleChannelInboundHandler<Message> {
|
public class TransportChannelHandler extends ChannelInboundHandlerAdapter {
|
||||||
private static final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class);
|
private static final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class);
|
||||||
|
|
||||||
private final TransportClient client;
|
private final TransportClient client;
|
||||||
|
@ -114,11 +113,13 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception {
|
public void channelRead(ChannelHandlerContext ctx, Object request) throws Exception {
|
||||||
if (request instanceof RequestMessage) {
|
if (request instanceof RequestMessage) {
|
||||||
requestHandler.handle((RequestMessage) request);
|
requestHandler.handle((RequestMessage) request);
|
||||||
} else {
|
} else if (request instanceof ResponseMessage) {
|
||||||
responseHandler.handle((ResponseMessage) request);
|
responseHandler.handle((ResponseMessage) request);
|
||||||
|
} else {
|
||||||
|
ctx.fireChannelRead(request);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -49,11 +49,11 @@ import org.apache.spark.network.util.NettyUtils;
|
||||||
public class ProtocolSuite {
|
public class ProtocolSuite {
|
||||||
private void testServerToClient(Message msg) {
|
private void testServerToClient(Message msg) {
|
||||||
EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(),
|
EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(),
|
||||||
new MessageEncoder());
|
MessageEncoder.INSTANCE);
|
||||||
serverChannel.writeOutbound(msg);
|
serverChannel.writeOutbound(msg);
|
||||||
|
|
||||||
EmbeddedChannel clientChannel = new EmbeddedChannel(
|
EmbeddedChannel clientChannel = new EmbeddedChannel(
|
||||||
NettyUtils.createFrameDecoder(), new MessageDecoder());
|
NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE);
|
||||||
|
|
||||||
while (!serverChannel.outboundMessages().isEmpty()) {
|
while (!serverChannel.outboundMessages().isEmpty()) {
|
||||||
clientChannel.writeInbound(serverChannel.readOutbound());
|
clientChannel.writeInbound(serverChannel.readOutbound());
|
||||||
|
@ -65,11 +65,11 @@ public class ProtocolSuite {
|
||||||
|
|
||||||
private void testClientToServer(Message msg) {
|
private void testClientToServer(Message msg) {
|
||||||
EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(),
|
EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(),
|
||||||
new MessageEncoder());
|
MessageEncoder.INSTANCE);
|
||||||
clientChannel.writeOutbound(msg);
|
clientChannel.writeOutbound(msg);
|
||||||
|
|
||||||
EmbeddedChannel serverChannel = new EmbeddedChannel(
|
EmbeddedChannel serverChannel = new EmbeddedChannel(
|
||||||
NettyUtils.createFrameDecoder(), new MessageDecoder());
|
NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE);
|
||||||
|
|
||||||
while (!clientChannel.outboundMessages().isEmpty()) {
|
while (!clientChannel.outboundMessages().isEmpty()) {
|
||||||
serverChannel.writeInbound(clientChannel.readOutbound());
|
serverChannel.writeInbound(clientChannel.readOutbound());
|
||||||
|
|
|
@ -2608,12 +2608,8 @@ private[util] object CallerContext extends Logging {
|
||||||
val callerContextSupported: Boolean = {
|
val callerContextSupported: Boolean = {
|
||||||
SparkHadoopUtil.get.conf.getBoolean("hadoop.caller.context.enabled", false) && {
|
SparkHadoopUtil.get.conf.getBoolean("hadoop.caller.context.enabled", false) && {
|
||||||
try {
|
try {
|
||||||
// `Utils.classForName` will make `ReplSuite` fail with `ClassCircularityError` in
|
Utils.classForName("org.apache.hadoop.ipc.CallerContext")
|
||||||
// master Maven build, so do not use it before resolving SPARK-17714.
|
Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder")
|
||||||
// scalastyle:off classforname
|
|
||||||
Class.forName("org.apache.hadoop.ipc.CallerContext")
|
|
||||||
Class.forName("org.apache.hadoop.ipc.CallerContext$Builder")
|
|
||||||
// scalastyle:on classforname
|
|
||||||
true
|
true
|
||||||
} catch {
|
} catch {
|
||||||
case _: ClassNotFoundException =>
|
case _: ClassNotFoundException =>
|
||||||
|
@ -2688,12 +2684,8 @@ private[spark] class CallerContext(
|
||||||
def setCurrentContext(): Unit = {
|
def setCurrentContext(): Unit = {
|
||||||
if (CallerContext.callerContextSupported) {
|
if (CallerContext.callerContextSupported) {
|
||||||
try {
|
try {
|
||||||
// `Utils.classForName` will make `ReplSuite` fail with `ClassCircularityError` in
|
val callerContext = Utils.classForName("org.apache.hadoop.ipc.CallerContext")
|
||||||
// master Maven build, so do not use it before resolving SPARK-17714.
|
val builder = Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder")
|
||||||
// scalastyle:off classforname
|
|
||||||
val callerContext = Class.forName("org.apache.hadoop.ipc.CallerContext")
|
|
||||||
val builder = Class.forName("org.apache.hadoop.ipc.CallerContext$Builder")
|
|
||||||
// scalastyle:on classforname
|
|
||||||
val builderInst = builder.getConstructor(classOf[String]).newInstance(context)
|
val builderInst = builder.getConstructor(classOf[String]).newInstance(context)
|
||||||
val hdfsContext = builder.getMethod("build").invoke(builderInst)
|
val hdfsContext = builder.getMethod("build").invoke(builderInst)
|
||||||
callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext)
|
callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext)
|
||||||
|
|
Loading…
Reference in a new issue