[SPARK-2936] Migrate Netty network module from Java to Scala

The Netty network module was originally written when Scala 2.9.x had a bug that prevented a pure Scala implementation, so a subset of the files was written in Java. We have since upgraded to Scala 2.10 and can now migrate all the Java files to Scala.

https://github.com/netty/netty/issues/781

https://github.com/mesos/spark/pull/522
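For context, the new Scala sources pass explicitly boxed values to Netty's generic `option[T](ChannelOption[T], T)` setters rather than leaning on Scala's inference, which is the pattern visible throughout the diff below. A minimal sketch of that pattern (the timeout value is hypothetical):

import io.netty.bootstrap.Bootstrap
import io.netty.channel.ChannelOption

val bootstrap = new Bootstrap
// ChannelOption.SO_KEEPALIVE is a ChannelOption[java.lang.Boolean], so the value
// is boxed explicitly instead of relying on Scala's boolean/numeric conversions.
bootstrap
  .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE)
  .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, Integer.valueOf(10000))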

Author: Reynold Xin <rxin@apache.org>

Closes #1865 from rxin/netty and squashes the following commits:

332422f [Reynold Xin] Code review feedback
ca9eeee [Reynold Xin] Minor update.
7f1434b [Reynold Xin] [SPARK-2936] Migrate Netty network module from Java to Scala
Authored by Reynold Xin on 2014-08-10 20:36:54 -07:00, committed by Aaron Davidson
parent b715aa0c80
commit ba28a8fcbc
12 changed files with 292 additions and 364 deletions

core/src/main/java/org/apache/spark/network/netty/FileClient.java (deleted)

@@ -1,100 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.netty;
import java.util.concurrent.TimeUnit;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.oio.OioEventLoopGroup;
import io.netty.channel.socket.oio.OioSocketChannel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
class FileClient {

  private static final Logger LOG = LoggerFactory.getLogger(FileClient.class.getName());

  private final FileClientHandler handler;
  private Channel channel = null;
  private Bootstrap bootstrap = null;
  private EventLoopGroup group = null;
  private final int connectTimeout;
  private final int sendTimeout = 60; // 1 min

  FileClient(FileClientHandler handler, int connectTimeout) {
    this.handler = handler;
    this.connectTimeout = connectTimeout;
  }

  public void init() {
    group = new OioEventLoopGroup();
    bootstrap = new Bootstrap();
    bootstrap.group(group)
      .channel(OioSocketChannel.class)
      .option(ChannelOption.SO_KEEPALIVE, true)
      .option(ChannelOption.TCP_NODELAY, true)
      .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout)
      .handler(new FileClientChannelInitializer(handler));
  }

  public void connect(String host, int port) {
    try {
      // Start the connection attempt.
      channel = bootstrap.connect(host, port).sync().channel();
      // ChannelFuture cf = channel.closeFuture();
      //cf.addListener(new ChannelCloseListener(this));
    } catch (InterruptedException e) {
      LOG.warn("FileClient interrupted while trying to connect", e);
      close();
    }
  }

  public void waitForClose() {
    try {
      channel.closeFuture().sync();
    } catch (InterruptedException e) {
      LOG.warn("FileClient interrupted", e);
    }
  }

  public void sendRequest(String file) {
    //assert(file == null);
    //assert(channel == null);
    try {
      // Should be able to send the message to network link channel.
      boolean bSent = channel.writeAndFlush(file + "\r\n").await(sendTimeout, TimeUnit.SECONDS);
      if (!bSent) {
        throw new RuntimeException("Failed to send");
      }
    } catch (InterruptedException e) {
      LOG.error("Error", e);
    }
  }

  public void close() {
    if (group != null) {
      group.shutdownGracefully();
      group = null;
      bootstrap = null;
    }
  }
}

core/src/main/java/org/apache/spark/network/netty/FileServer.java (deleted)

@@ -1,111 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.netty;
import java.net.InetSocketAddress;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.oio.OioEventLoopGroup;
import io.netty.channel.socket.oio.OioServerSocketChannel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Server that accepts the path of a file and echoes back its content.
 */
class FileServer {

  private static final Logger LOG = LoggerFactory.getLogger(FileServer.class.getName());

  private EventLoopGroup bossGroup = null;
  private EventLoopGroup workerGroup = null;
  private ChannelFuture channelFuture = null;
  private int port = 0;

  FileServer(PathResolver pResolver, int port) {
    InetSocketAddress addr = new InetSocketAddress(port);

    // Configure the server.
    bossGroup = new OioEventLoopGroup();
    workerGroup = new OioEventLoopGroup();

    ServerBootstrap bootstrap = new ServerBootstrap();
    bootstrap.group(bossGroup, workerGroup)
      .channel(OioServerSocketChannel.class)
      .option(ChannelOption.SO_BACKLOG, 100)
      .option(ChannelOption.SO_RCVBUF, 1500)
      .childHandler(new FileServerChannelInitializer(pResolver));
    // Start the server.
    channelFuture = bootstrap.bind(addr);
    try {
      // Get the address we bound to.
      InetSocketAddress boundAddress =
        ((InetSocketAddress) channelFuture.sync().channel().localAddress());
      this.port = boundAddress.getPort();
    } catch (InterruptedException ie) {
      this.port = 0;
    }
  }

  /**
   * Start the file server asynchronously in a new thread.
   */
  public void start() {
    Thread blockingThread = new Thread() {
      @Override
      public void run() {
        try {
          channelFuture.channel().closeFuture().sync();
          LOG.info("FileServer exiting");
        } catch (InterruptedException e) {
          LOG.error("File server start got interrupted", e);
        }
        // NOTE: bootstrap is shutdown in stop()
      }
    };
    blockingThread.setDaemon(true);
    blockingThread.start();
  }

  public int getPort() {
    return port;
  }

  public void stop() {
    // Close the bound channel.
    if (channelFuture != null) {
      channelFuture.channel().close().awaitUninterruptibly();
      channelFuture = null;
    }

    // Shutdown event groups
    if (bossGroup != null) {
      bossGroup.shutdownGracefully();
      bossGroup = null;
    }

    if (workerGroup != null) {
      workerGroup.shutdownGracefully();
      workerGroup = null;
    }
    // TODO: Shutdown all accepted channels as well ?
  }
}

core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java (deleted)

@@ -1,83 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.netty;
import java.io.File;
import java.io.FileInputStream;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.DefaultFileRegion;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.FileSegment;
class FileServerHandler extends SimpleChannelInboundHandler<String> {

  private static final Logger LOG = LoggerFactory.getLogger(FileServerHandler.class.getName());

  private final PathResolver pResolver;

  FileServerHandler(PathResolver pResolver) {
    this.pResolver = pResolver;
  }

  @Override
  public void channelRead0(ChannelHandlerContext ctx, String blockIdString) {
    BlockId blockId = BlockId.apply(blockIdString);
    FileSegment fileSegment = pResolver.getBlockLocation(blockId);
    // if getBlockLocation returns null, close the channel
    if (fileSegment == null) {
      //ctx.close();
      return;
    }
    File file = fileSegment.file();
    if (file.exists()) {
      if (!file.isFile()) {
        ctx.write(new FileHeader(0, blockId).buffer());
        ctx.flush();
        return;
      }
      long length = fileSegment.length();
      if (length > Integer.MAX_VALUE || length <= 0) {
        ctx.write(new FileHeader(0, blockId).buffer());
        ctx.flush();
        return;
      }
      int len = (int) length;
      ctx.write((new FileHeader(len, blockId)).buffer());
      try {
        ctx.write(new DefaultFileRegion(new FileInputStream(file)
          .getChannel(), fileSegment.offset(), fileSegment.length()));
      } catch (Exception e) {
        LOG.error("Exception: ", e);
      }
    } else {
      ctx.write(new FileHeader(0, blockId).buffer());
    }
    ctx.flush();
  }

  @Override
  public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
    LOG.error("Exception: ", cause);
    ctx.close();
  }
}

core/src/main/scala/org/apache/spark/network/netty/FileClient.scala (new file)

@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.netty
import java.util.concurrent.TimeUnit
import io.netty.bootstrap.Bootstrap
import io.netty.channel.{Channel, ChannelOption, EventLoopGroup}
import io.netty.channel.oio.OioEventLoopGroup
import io.netty.channel.socket.oio.OioSocketChannel
import org.apache.spark.Logging
class FileClient(handler: FileClientHandler, connectTimeout: Int) extends Logging {

  private var channel: Channel = _
  private var bootstrap: Bootstrap = _
  private var group: EventLoopGroup = _
  private val sendTimeout = 60

  def init(): Unit = {
    group = new OioEventLoopGroup
    bootstrap = new Bootstrap
    bootstrap.group(group)
      .channel(classOf[OioSocketChannel])
      .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE)
      .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE)
      .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, Integer.valueOf(connectTimeout))
      .handler(new FileClientChannelInitializer(handler))
  }

  def connect(host: String, port: Int) {
    try {
      channel = bootstrap.connect(host, port).sync().channel()
    } catch {
      case e: InterruptedException =>
        logWarning("FileClient interrupted while trying to connect", e)
        close()
    }
  }

  def waitForClose(): Unit = {
    try {
      channel.closeFuture.sync()
    } catch {
      case e: InterruptedException =>
        logWarning("FileClient interrupted", e)
    }
  }

  def sendRequest(file: String): Unit = {
    try {
      val bSent = channel.writeAndFlush(file + "\r\n").await(sendTimeout, TimeUnit.SECONDS)
      if (!bSent) {
        throw new RuntimeException("Failed to send")
      }
    } catch {
      case e: InterruptedException =>
        logError("Error", e)
    }
  }

  def close(): Unit = {
    if (group != null) {
      group.shutdownGracefully()
      group = null
      bootstrap = null
    }
  }
}
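A minimal usage sketch of the client lifecycle above, assuming a concrete FileClientHandler subclass and a server already listening (the handler, host, port, and block id are hypothetical):

val client = new FileClient(myHandler, 60000)  // 60 s connect timeout, in ms
client.init()                                  // build the OIO bootstrap
client.connect("localhost", 12345)             // blocks until connected
client.sendRequest("test_my_block")            // goes on the wire as "test_my_block\r\n"
client.waitForClose()                          // handler closes the channel once the block arrives
client.close()                                 // shut down the event loop group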

core/src/main/{java → scala}/org/apache/spark/network/netty/FileClientChannelInitializer.{java → scala}

@@ -15,25 +15,17 @@
  * limitations under the License.
  */
-package org.apache.spark.network.netty;
+package org.apache.spark.network.netty

-import io.netty.channel.ChannelInitializer;
-import io.netty.channel.socket.SocketChannel;
-import io.netty.handler.codec.string.StringEncoder;
+import io.netty.channel.ChannelInitializer
+import io.netty.channel.socket.SocketChannel
+import io.netty.handler.codec.string.StringEncoder

-class FileClientChannelInitializer extends ChannelInitializer<SocketChannel> {
-
-  private final FileClientHandler fhandler;
+class FileClientChannelInitializer(handler: FileClientHandler)
+  extends ChannelInitializer[SocketChannel] {

-  FileClientChannelInitializer(FileClientHandler handler) {
-    fhandler = handler;
-  }
-
-  @Override
-  public void initChannel(SocketChannel channel) {
-    // file no more than 2G
-    channel.pipeline()
-      .addLast("encoder", new StringEncoder())
-      .addLast("handler", fhandler);
+  def initChannel(channel: SocketChannel) {
+    channel.pipeline.addLast("encoder", new StringEncoder).addLast("handler", handler)
   }
 }

core/src/main/{java → scala}/org/apache/spark/network/netty/FileClientHandler.{java → scala}

@@ -15,41 +15,36 @@
  * limitations under the License.
  */
-package org.apache.spark.network.netty;
+package org.apache.spark.network.netty

-import io.netty.buffer.ByteBuf;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.SimpleChannelInboundHandler;
+import io.netty.buffer.ByteBuf
+import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}

-import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.BlockId

-abstract class FileClientHandler extends SimpleChannelInboundHandler<ByteBuf> {
-  private FileHeader currentHeader = null;
+abstract class FileClientHandler extends SimpleChannelInboundHandler[ByteBuf] {

-  private volatile boolean handlerCalled = false;
+  private var currentHeader: FileHeader = null

-  public boolean isComplete() {
-    return handlerCalled;
-  }
+  @volatile
+  private var handlerCalled: Boolean = false

-  public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header);
-  public abstract void handleError(BlockId blockId);
+  def isComplete: Boolean = handlerCalled

-  @Override
-  public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) {
-    // get header
-    if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) {
-      currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE()));
+  def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader)
+
+  def handleError(blockId: BlockId)
+
+  override def channelRead0(ctx: ChannelHandlerContext, in: ByteBuf) {
+    if (currentHeader == null && in.readableBytes >= FileHeader.HEADER_SIZE) {
+      currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE))
     }
-    // get file
-    if(in.readableBytes() >= currentHeader.fileLen()) {
-      handle(ctx, in, currentHeader);
-      handlerCalled = true;
-      currentHeader = null;
-      ctx.close();
+    if (in.readableBytes >= currentHeader.fileLen) {
+      handle(ctx, in, currentHeader)
+      handlerCalled = true
+      currentHeader = null
+      ctx.close()
     }
   }
 }
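Since handle and handleError stay abstract, each caller supplies the receive logic. A minimal concrete subclass might look like this sketch (the class name and body are hypothetical):

import io.netty.buffer.ByteBuf
import io.netty.channel.ChannelHandlerContext
import org.apache.spark.storage.BlockId

class InMemoryClientHandler extends FileClientHandler {
  // channelRead0 only invokes this once header.fileLen bytes are readable.
  override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) {
    val body = new Array[Byte](header.fileLen)
    in.readBytes(body)
    // hand `body` off to whatever requested the block
  }

  override def handleError(blockId: BlockId) {
    // surface the failure for `blockId` to the caller
  }
}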

core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala

@@ -26,7 +26,7 @@ private[spark] class FileHeader (
     val fileLen: Int,
     val blockId: BlockId) extends Logging {

-  lazy val buffer = {
+  lazy val buffer: ByteBuf = {
     val buf = Unpooled.buffer()
     buf.capacity(FileHeader.HEADER_SIZE)
     buf.writeInt(fileLen)

@@ -62,11 +62,10 @@ private[spark] object FileHeader {
     new FileHeader(length, blockId)
   }

-  def main (args:Array[String]) {
+  def main(args:Array[String]) {
     val header = new FileHeader(25, TestBlockId("my_block"))
     val buf = header.buffer
     val newHeader = FileHeader.create(buf)
     System.out.println("id=" + newHeader.blockId + ",size=" + newHeader.fileLen)
   }
 }

core/src/main/scala/org/apache/spark/network/netty/FileServer.scala (new file)

@@ -0,0 +1,91 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.netty
import java.net.InetSocketAddress
import io.netty.bootstrap.ServerBootstrap
import io.netty.channel.{ChannelFuture, ChannelOption, EventLoopGroup}
import io.netty.channel.oio.OioEventLoopGroup
import io.netty.channel.socket.oio.OioServerSocketChannel
import org.apache.spark.Logging
/**
 * Server that accepts the path of a file and echoes back its content.
 */
class FileServer(pResolver: PathResolver, private var port: Int) extends Logging {

  private val addr: InetSocketAddress = new InetSocketAddress(port)
  private var bossGroup: EventLoopGroup = new OioEventLoopGroup
  private var workerGroup: EventLoopGroup = new OioEventLoopGroup

  private var channelFuture: ChannelFuture = {
    val bootstrap = new ServerBootstrap
    bootstrap.group(bossGroup, workerGroup)
      .channel(classOf[OioServerSocketChannel])
      .option(ChannelOption.SO_BACKLOG, java.lang.Integer.valueOf(100))
      .option(ChannelOption.SO_RCVBUF, java.lang.Integer.valueOf(1500))
      .childHandler(new FileServerChannelInitializer(pResolver))
    bootstrap.bind(addr)
  }

  try {
    val boundAddress = channelFuture.sync.channel.localAddress.asInstanceOf[InetSocketAddress]
    port = boundAddress.getPort
  } catch {
    case ie: InterruptedException =>
      port = 0
  }

  /** Start the file server asynchronously in a new thread. */
  def start(): Unit = {
    val blockingThread: Thread = new Thread {
      override def run(): Unit = {
        try {
          channelFuture.channel.closeFuture.sync
          logInfo("FileServer exiting")
        } catch {
          case e: InterruptedException =>
            logError("File server start got interrupted", e)
        }
        // NOTE: bootstrap is shutdown in stop()
      }
    }
    blockingThread.setDaemon(true)
    blockingThread.start()
  }

  def getPort: Int = port

  def stop(): Unit = {
    if (channelFuture != null) {
      channelFuture.channel().close().awaitUninterruptibly()
      channelFuture = null
    }
    if (bossGroup != null) {
      bossGroup.shutdownGracefully()
      bossGroup = null
    }
    if (workerGroup != null) {
      workerGroup.shutdownGracefully()
      workerGroup = null
    }
  }
}
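A short usage sketch: passing port 0 lets the OS pick a free port, and because the constructor syncs on the bind before reading localAddress back, getPort reports the port actually bound (the resolver is hypothetical):

val server = new FileServer(myResolver, 0)  // myResolver: some PathResolver
server.start()                              // daemon thread parks on the close future
println("file server bound to port " + server.getPort)
// ... serve block requests ...
server.stop()                               // close the channel, shut down both event loop groups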

core/src/main/{java → scala}/org/apache/spark/network/netty/FileServerChannelInitializer.{java → scala}

@@ -15,27 +15,20 @@
  * limitations under the License.
  */
-package org.apache.spark.network.netty;
+package org.apache.spark.network.netty

-import io.netty.channel.ChannelInitializer;
-import io.netty.channel.socket.SocketChannel;
-import io.netty.handler.codec.DelimiterBasedFrameDecoder;
-import io.netty.handler.codec.Delimiters;
-import io.netty.handler.codec.string.StringDecoder;
+import io.netty.channel.ChannelInitializer
+import io.netty.channel.socket.SocketChannel
+import io.netty.handler.codec.{DelimiterBasedFrameDecoder, Delimiters}
+import io.netty.handler.codec.string.StringDecoder

-class FileServerChannelInitializer extends ChannelInitializer<SocketChannel> {
+class FileServerChannelInitializer(pResolver: PathResolver)
+  extends ChannelInitializer[SocketChannel] {

-  private final PathResolver pResolver;
-
-  FileServerChannelInitializer(PathResolver pResolver) {
-    this.pResolver = pResolver;
-  }
-
-  @Override
-  public void initChannel(SocketChannel channel) {
-    channel.pipeline()
-      .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter()))
-      .addLast("stringDecoder", new StringDecoder())
-      .addLast("handler", new FileServerHandler(pResolver));
+  override def initChannel(channel: SocketChannel): Unit = {
+    channel.pipeline
+      .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter : _*))
+      .addLast("stringDecoder", new StringDecoder)
+      .addLast("handler", new FileServerHandler(pResolver))
   }
 }
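On the server side the pipeline order matters: the frame decoder splits input on line endings and strips the delimiter, the string decoder turns each frame into a String, and only then does FileServerHandler see the bare block id. A small sketch of that decoding with Netty's EmbeddedChannel test harness (hypothetical, not part of this commit):

import io.netty.buffer.Unpooled
import io.netty.channel.embedded.EmbeddedChannel
import io.netty.handler.codec.{DelimiterBasedFrameDecoder, Delimiters}
import io.netty.handler.codec.string.StringDecoder
import io.netty.util.CharsetUtil

val ch = new EmbeddedChannel(
  new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter : _*),
  new StringDecoder)
ch.writeInbound(Unpooled.copiedBuffer("test_my_block\r\n", CharsetUtil.UTF_8))
// The framer strips "\r\n", so the handler stage receives "test_my_block".
val blockIdString = ch.readInbound().asInstanceOf[String]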

core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala (new file)

@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.netty
import java.io.FileInputStream
import io.netty.channel.{DefaultFileRegion, ChannelHandlerContext, SimpleChannelInboundHandler}
import org.apache.spark.Logging
import org.apache.spark.storage.{BlockId, FileSegment}
class FileServerHandler(pResolver: PathResolver)
  extends SimpleChannelInboundHandler[String] with Logging {

  override def channelRead0(ctx: ChannelHandlerContext, blockIdString: String): Unit = {
    val blockId: BlockId = BlockId(blockIdString)
    val fileSegment: FileSegment = pResolver.getBlockLocation(blockId)
    if (fileSegment == null) {
      return
    }
    val file = fileSegment.file
    if (file.exists) {
      if (!file.isFile) {
        ctx.write(new FileHeader(0, blockId).buffer)
        ctx.flush()
        return
      }
      val length: Long = fileSegment.length
      if (length > Integer.MAX_VALUE || length <= 0) {
        ctx.write(new FileHeader(0, blockId).buffer)
        ctx.flush()
        return
      }
      ctx.write(new FileHeader(length.toInt, blockId).buffer)
      try {
        val channel = new FileInputStream(file).getChannel
        ctx.write(new DefaultFileRegion(channel, fileSegment.offset, fileSegment.length))
      } catch {
        case e: Exception =>
          logError("Exception: ", e)
      }
    } else {
      ctx.write(new FileHeader(0, blockId).buffer)
    }
    ctx.flush()
  }

  override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
    logError("Exception: ", cause)
    ctx.close()
  }
}

core/src/main/{java → scala}/org/apache/spark/network/netty/PathResolver.{java → scala}

@@ -15,12 +15,11 @@
  * limitations under the License.
  */
-package org.apache.spark.network.netty;
+package org.apache.spark.network.netty

-import org.apache.spark.storage.BlockId;
-import org.apache.spark.storage.FileSegment;
+import org.apache.spark.storage.{BlockId, FileSegment}

-public interface PathResolver {
+trait PathResolver {
   /** Get the file segment in which the given block resides. */
-  FileSegment getBlockLocation(BlockId blockId);
+  def getBlockLocation(blockId: BlockId): FileSegment
 }
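A single-method trait is now trivially implementable from Scala. A toy resolver sketch, assuming blocks live as flat files named after their block id in one directory (entirely hypothetical; it honors the null contract that FileServerHandler checks for):

import java.io.File
import org.apache.spark.storage.{BlockId, FileSegment}

class DirectoryPathResolver(dir: File) extends PathResolver {
  override def getBlockLocation(blockId: BlockId): FileSegment = {
    val file = new File(dir, blockId.name)
    // FileServerHandler treats null as an unknown block and simply returns.
    if (file.exists) new FileSegment(file, 0, file.length) else null
  }
}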

core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala

@@ -32,7 +32,7 @@ private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) ext
     server.stop()
   }

-  def port: Int = server.getPort()
+  def port: Int = server.getPort
 }
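Putting the pieces together, an end-to-end smoke test could look like the sketch below, reusing the hypothetical resolver and handler from above. It assumes ShuffleSender exposes the stop() that wraps server.stop() in the hunk above, and uses the test_* block id form that BlockId.apply parses:

import java.io.File

val sender = new ShuffleSender(0, new DirectoryPathResolver(new File("/tmp/blocks")))
val client = new FileClient(new InMemoryClientHandler, 60000)
client.init()
client.connect("localhost", sender.port)
client.sendRequest("test_my_block")
client.waitForClose()
client.close()
sender.stop()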