[SPARK-25408] Move to more idiomatic Java 8

While working on another PR, I noticed that there is quite a bit of legacy Java in the codebase that can be cleaned up by using Java 8 features, such as (illustrative sketches of both patterns appear below):
- Collection APIs: lambdas and `Iterable.forEach` instead of anonymous classes and unnecessary `stream()` calls
- Try-with-resources blocks
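
As a rough, self-contained sketch of the first bullet (modeled on the `LevelDBSuite` and `ReadAheadInputStream` hunks below, but not copied from them), the cleanup replaces anonymous classes and redundant `stream()` calls with lambdas:

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;

class CollectionStyleSketch {
  // Pre-Java-8 style: an explicit stream() just to call forEach, plus an anonymous Consumer.
  static void legacy(List<Integer> values) {
    values.stream().forEach(new Consumer<Integer>() {
      @Override
      public void accept(Integer i) {
        System.out.println(i);
      }
    });
  }

  // Java 8 style: Iterable.forEach with a lambda does the same thing.
  static void idiomatic(List<Integer> values) {
    values.forEach(i -> System.out.println(i));
  }

  public static void main(String[] args) {
    List<Integer> values = Arrays.asList(-100, -50, 0, 50, 100);
    legacy(values);
    idiomatic(values);
  }
}
```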

No logic has been changed. I think it is important to have a solid codebase that sets a good example and inspires future PRs to follow the same best practices.

What are your thoughts on this?

This makes the code easier to read, and using try-with-resources makes it much less likely that a stream or connection is accidentally left unclosed.
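
As a minimal sketch of that second point (modeled on the `KVStoreSerializer` hunk below; the method names here are illustrative, not Spark APIs), try-with-resources lets the compiler emit the `close()` call, including on the exception path:

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.zip.GZIPOutputStream;

class TryWithResourcesSketch {
  // Legacy style: close() must be called manually in a finally block.
  static byte[] gzipLegacy(byte[] data) throws IOException {
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    GZIPOutputStream out = new GZIPOutputStream(bytes);
    try {
      out.write(data);
    } finally {
      out.close();
    }
    return bytes.toByteArray();
  }

  // Try-with-resources: the stream is closed automatically when the block exits.
  static byte[] gzipIdiomatic(byte[] data) throws IOException {
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    try (GZIPOutputStream out = new GZIPOutputStream(bytes)) {
      out.write(data);
    }
    return bytes.toByteArray();
  }
}
```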

## What changes were proposed in this pull request?

No changes to the logic of Spark, only to the style and readability of the code.

## How was this patch tested?

Using the existing unit tests. Since no logic is changed, the existing unit tests should pass.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Closes #22637 from Fokko/SPARK-25408.

Authored-by: Fokko Driesprong <fokkodriesprong@godatadriven.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
Authored by Fokko Driesprong on 2018-10-08 09:58:52 -05:00, committed by Sean Owen
parent a853a80202
commit 1a28625355
17 changed files with 247 additions and 291 deletions

@@ -54,11 +54,8 @@ public class KVStoreSerializer {
       return ((String) o).getBytes(UTF_8);
     } else {
       ByteArrayOutputStream bytes = new ByteArrayOutputStream();
-      GZIPOutputStream out = new GZIPOutputStream(bytes);
-      try {
+      try (GZIPOutputStream out = new GZIPOutputStream(bytes)) {
         mapper.writeValue(out, o);
-      } finally {
-        out.close();
       }
       return bytes.toByteArray();
     }
@@ -69,11 +66,8 @@ public class KVStoreSerializer {
     if (klass.equals(String.class)) {
       return (T) new String(data, UTF_8);
     } else {
-      GZIPInputStream in = new GZIPInputStream(new ByteArrayInputStream(data));
-      try {
+      try (GZIPInputStream in = new GZIPInputStream(new ByteArrayInputStream(data))) {
         return mapper.readValue(in, klass);
-      } finally {
-        in.close();
       }
     }
   }

@@ -217,7 +217,7 @@ public class LevelDBSuite {
   public void testNegativeIndexValues() throws Exception {
     List<Integer> expected = Arrays.asList(-100, -50, 0, 50, 100);
-    expected.stream().forEach(i -> {
+    expected.forEach(i -> {
       try {
         db.write(createCustomType1(i));
       } catch (Exception e) {

@@ -143,37 +143,39 @@ public class ChunkFetchIntegrationSuite {
   }
   private FetchResult fetchChunks(List<Integer> chunkIndices) throws Exception {
-    TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
-    final Semaphore sem = new Semaphore(0);
     final FetchResult res = new FetchResult();
-    res.successChunks = Collections.synchronizedSet(new HashSet<Integer>());
-    res.failedChunks = Collections.synchronizedSet(new HashSet<Integer>());
-    res.buffers = Collections.synchronizedList(new LinkedList<ManagedBuffer>());
-    ChunkReceivedCallback callback = new ChunkReceivedCallback() {
-      @Override
-      public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
-        buffer.retain();
-        res.successChunks.add(chunkIndex);
-        res.buffers.add(buffer);
-        sem.release();
-      }
-      @Override
-      public void onFailure(int chunkIndex, Throwable e) {
-        res.failedChunks.add(chunkIndex);
-        sem.release();
-      }
-    };
-    for (int chunkIndex : chunkIndices) {
-      client.fetchChunk(STREAM_ID, chunkIndex, callback);
-    }
-    if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) {
-      fail("Timeout getting response from the server");
-    }
-    client.close();
+    try (TransportClient client =
+        clientFactory.createClient(TestUtils.getLocalHost(), server.getPort())) {
+      final Semaphore sem = new Semaphore(0);
+      res.successChunks = Collections.synchronizedSet(new HashSet<Integer>());
+      res.failedChunks = Collections.synchronizedSet(new HashSet<Integer>());
+      res.buffers = Collections.synchronizedList(new LinkedList<ManagedBuffer>());
+      ChunkReceivedCallback callback = new ChunkReceivedCallback() {
+        @Override
+        public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+          buffer.retain();
+          res.successChunks.add(chunkIndex);
+          res.buffers.add(buffer);
+          sem.release();
+        }
+        @Override
+        public void onFailure(int chunkIndex, Throwable e) {
+          res.failedChunks.add(chunkIndex);
+          sem.release();
+        }
+      };
+      for (int chunkIndex : chunkIndices) {
+        client.fetchChunk(STREAM_ID, chunkIndex, callback);
+      }
+      if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) {
+        fail("Timeout getting response from the server");
+      }
+    }
     return res;
   }

@@ -37,14 +37,8 @@ public class ShuffleIndexInformation {
     size = (int)indexFile.length();
     ByteBuffer buffer = ByteBuffer.allocate(size);
     offsets = buffer.asLongBuffer();
-    DataInputStream dis = null;
-    try {
-      dis = new DataInputStream(Files.newInputStream(indexFile.toPath()));
+    try (DataInputStream dis = new DataInputStream(Files.newInputStream(indexFile.toPath()))) {
       dis.readFully(buffer.array());
-    } finally {
-      if (dis != null) {
-        dis.close();
-      }
     }
   }

@@ -98,19 +98,19 @@ public class ExternalShuffleBlockResolverSuite {
     resolver.registerExecutor("app0", "exec0",
       dataContext.createExecutorInfo(SORT_MANAGER));
-    InputStream block0Stream =
-      resolver.getBlockData("app0", "exec0", 0, 0, 0).createInputStream();
-    String block0 = CharStreams.toString(
-      new InputStreamReader(block0Stream, StandardCharsets.UTF_8));
-    block0Stream.close();
-    assertEquals(sortBlock0, block0);
+    try (InputStream block0Stream = resolver.getBlockData(
+        "app0", "exec0", 0, 0, 0).createInputStream()) {
+      String block0 =
+        CharStreams.toString(new InputStreamReader(block0Stream, StandardCharsets.UTF_8));
+      assertEquals(sortBlock0, block0);
+    }
-    InputStream block1Stream =
-      resolver.getBlockData("app0", "exec0", 0, 0, 1).createInputStream();
-    String block1 = CharStreams.toString(
-      new InputStreamReader(block1Stream, StandardCharsets.UTF_8));
-    block1Stream.close();
-    assertEquals(sortBlock1, block1);
+    try (InputStream block1Stream = resolver.getBlockData(
+        "app0", "exec0", 0, 0, 1).createInputStream()) {
+      String block1 =
+        CharStreams.toString(new InputStreamReader(block1Stream, StandardCharsets.UTF_8));
+      assertEquals(sortBlock1, block1);
+    }
   }
   @Test
@@ -149,7 +149,7 @@ public class ExternalShuffleBlockResolverSuite {
   private void assertPathsMatch(String p1, String p2, String p3, String expectedPathname) {
     String normPathname =
-      ExternalShuffleBlockResolver.createNormalizedInternedPathname(p1, p2, p3);
+        ExternalShuffleBlockResolver.createNormalizedInternedPathname(p1, p2, p3);
     assertEquals(expectedPathname, normPathname);
     File file = new File(normPathname);
     String returnedPath = file.getPath();

@@ -133,37 +133,37 @@ public class ExternalShuffleIntegrationSuite {
     final Semaphore requestsRemaining = new Semaphore(0);
-    ExternalShuffleClient client = new ExternalShuffleClient(clientConf, null, false, 5000);
-    client.init(APP_ID);
-    client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds,
-      new BlockFetchingListener() {
-        @Override
-        public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
-          synchronized (this) {
-            if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) {
-              data.retain();
-              res.successBlocks.add(blockId);
-              res.buffers.add(data);
-              requestsRemaining.release();
-            }
-          }
-        }
-        @Override
-        public void onBlockFetchFailure(String blockId, Throwable exception) {
-          synchronized (this) {
-            if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) {
-              res.failedBlocks.add(blockId);
-              requestsRemaining.release();
-            }
-          }
-        }
-      }, null);
-    if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) {
-      fail("Timeout getting response from the server");
-    }
-    client.close();
+    try (ExternalShuffleClient client = new ExternalShuffleClient(clientConf, null, false, 5000)) {
+      client.init(APP_ID);
+      client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds,
+        new BlockFetchingListener() {
+          @Override
+          public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
+            synchronized (this) {
+              if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) {
+                data.retain();
+                res.successBlocks.add(blockId);
+                res.buffers.add(data);
+                requestsRemaining.release();
+              }
+            }
+          }
+          @Override
+          public void onBlockFetchFailure(String blockId, Throwable exception) {
+            synchronized (this) {
+              if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) {
+                res.failedBlocks.add(blockId);
+                requestsRemaining.release();
+              }
+            }
+          }
+        }, null);
+      if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) {
+        fail("Timeout getting response from the server");
+      }
+    }
     return res;
   }

@@ -96,14 +96,16 @@ public class ExternalShuffleSecuritySuite {
         ImmutableMap.of("spark.authenticate.enableSaslEncryption", "true")));
     }
-    ExternalShuffleClient client =
-      new ExternalShuffleClient(testConf, new TestSecretKeyHolder(appId, secretKey), true, 5000);
-    client.init(appId);
-    // Registration either succeeds or throws an exception.
-    client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0",
-      new ExecutorShuffleInfo(new String[0], 0,
-        "org.apache.spark.shuffle.sort.SortShuffleManager"));
-    client.close();
+    try (ExternalShuffleClient client =
+        new ExternalShuffleClient(
+          testConf, new TestSecretKeyHolder(appId, secretKey), true, 5000)) {
+      client.init(appId);
+      // Registration either succeeds or throws an exception.
+      client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0",
+        new ExecutorShuffleInfo(
+          new String[0], 0, "org.apache.spark.shuffle.sort.SortShuffleManager")
+      );
+    }
   }
   /** Provides a secret key holder which always returns the given secret key, for a single appId. */

@@ -191,10 +191,9 @@ public abstract class CountMinSketch {
    * Reads in a {@link CountMinSketch} from a byte array.
    */
   public static CountMinSketch readFrom(byte[] bytes) throws IOException {
-    InputStream in = new ByteArrayInputStream(bytes);
-    CountMinSketch cms = readFrom(in);
-    in.close();
-    return cms;
+    try (InputStream in = new ByteArrayInputStream(bytes)) {
+      return readFrom(in);
+    }
   }
   /**

@@ -322,10 +322,10 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
   @Override
   public byte[] toByteArray() throws IOException {
-    ByteArrayOutputStream out = new ByteArrayOutputStream();
-    writeTo(out);
-    out.close();
-    return out.toByteArray();
+    try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
+      writeTo(out);
+      return out.toByteArray();
+    }
   }
   public static CountMinSketchImpl readFrom(InputStream in) throws IOException {

@@ -135,62 +135,58 @@ public class ReadAheadInputStream extends InputStream {
     } finally {
       stateChangeLock.unlock();
     }
-    executorService.execute(new Runnable() {
-      @Override
-      public void run() {
-        stateChangeLock.lock();
-        try {
-          if (isClosed) {
-            readInProgress = false;
-            return;
-          }
-          // Flip this so that the close method will not close the underlying input stream when we
-          // are reading.
-          isReading = true;
-        } finally {
-          stateChangeLock.unlock();
-        }
-        // Please note that it is safe to release the lock and read into the read ahead buffer
-        // because either of following two conditions will hold - 1. The active buffer has
-        // data available to read so the reader will not read from the read ahead buffer.
-        // 2. This is the first time read is called or the active buffer is exhausted,
-        // in that case the reader waits for this async read to complete.
-        // So there is no race condition in both the situations.
-        int read = 0;
-        int off = 0, len = arr.length;
-        Throwable exception = null;
-        try {
-          // try to fill the read ahead buffer.
-          // if a reader is waiting, possibly return early.
-          do {
-            read = underlyingInputStream.read(arr, off, len);
-            if (read <= 0) break;
-            off += read;
-            len -= read;
-          } while (len > 0 && !isWaiting.get());
-        } catch (Throwable ex) {
-          exception = ex;
-          if (ex instanceof Error) {
-            // `readException` may not be reported to the user. Rethrow Error to make sure at least
-            // The user can see Error in UncaughtExceptionHandler.
-            throw (Error) ex;
-          }
-        } finally {
-          stateChangeLock.lock();
-          readAheadBuffer.limit(off);
-          if (read < 0 || (exception instanceof EOFException)) {
-            endOfStream = true;
-          } else if (exception != null) {
-            readAborted = true;
-            readException = exception;
-          }
-          readInProgress = false;
-          signalAsyncReadComplete();
-          stateChangeLock.unlock();
-          closeUnderlyingInputStreamIfNecessary();
-        }
-      }
+    executorService.execute(() -> {
+      stateChangeLock.lock();
+      try {
+        if (isClosed) {
+          readInProgress = false;
+          return;
+        }
+        // Flip this so that the close method will not close the underlying input stream when we
+        // are reading.
+        isReading = true;
+      } finally {
+        stateChangeLock.unlock();
+      }
+      // Please note that it is safe to release the lock and read into the read ahead buffer
+      // because either of following two conditions will hold - 1. The active buffer has
+      // data available to read so the reader will not read from the read ahead buffer.
+      // 2. This is the first time read is called or the active buffer is exhausted,
+      // in that case the reader waits for this async read to complete.
+      // So there is no race condition in both the situations.
+      int read = 0;
+      int off = 0, len = arr.length;
+      Throwable exception = null;
+      try {
+        // try to fill the read ahead buffer.
+        // if a reader is waiting, possibly return early.
+        do {
+          read = underlyingInputStream.read(arr, off, len);
+          if (read <= 0) break;
+          off += read;
+          len -= read;
+        } while (len > 0 && !isWaiting.get());
+      } catch (Throwable ex) {
+        exception = ex;
+        if (ex instanceof Error) {
+          // `readException` may not be reported to the user. Rethrow Error to make sure at least
+          // The user can see Error in UncaughtExceptionHandler.
+          throw (Error) ex;
+        }
+      } finally {
+        stateChangeLock.lock();
+        readAheadBuffer.limit(off);
+        if (read < 0 || (exception instanceof EOFException)) {
+          endOfStream = true;
+        } else if (exception != null) {
+          readAborted = true;
+          readException = exception;
+        }
+        readInProgress = false;
+        signalAsyncReadComplete();
+        stateChangeLock.unlock();
+        closeUnderlyingInputStreamIfNecessary();
+      }
     });
   }

@@ -152,9 +152,9 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     }
     for (int i = 0; i < numPartitions; i++) {
-      final DiskBlockObjectWriter writer = partitionWriters[i];
-      partitionWriterSegments[i] = writer.commitAndGet();
-      writer.close();
+      try (DiskBlockObjectWriter writer = partitionWriters[i]) {
+        partitionWriterSegments[i] = writer.commitAndGet();
+      }
     }
     File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);

@@ -181,42 +181,43 @@ final class ShuffleExternalSorter extends MemoryConsumer {
     // around this, we pass a dummy no-op serializer.
     final SerializerInstance ser = DummySerializerInstance.INSTANCE;
-    final DiskBlockObjectWriter writer =
-      blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
     int currentPartition = -1;
-    final int uaoSize = UnsafeAlignedOffset.getUaoSize();
-    while (sortedRecords.hasNext()) {
-      sortedRecords.loadNext();
-      final int partition = sortedRecords.packedRecordPointer.getPartitionId();
-      assert (partition >= currentPartition);
-      if (partition != currentPartition) {
-        // Switch to the new partition
-        if (currentPartition != -1) {
-          final FileSegment fileSegment = writer.commitAndGet();
-          spillInfo.partitionLengths[currentPartition] = fileSegment.length();
-        }
-        currentPartition = partition;
-      }
-      final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
-      final Object recordPage = taskMemoryManager.getPage(recordPointer);
-      final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer);
-      int dataRemaining = UnsafeAlignedOffset.getSize(recordPage, recordOffsetInPage);
-      long recordReadPosition = recordOffsetInPage + uaoSize; // skip over record length
-      while (dataRemaining > 0) {
-        final int toTransfer = Math.min(diskWriteBufferSize, dataRemaining);
-        Platform.copyMemory(
-          recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer);
-        writer.write(writeBuffer, 0, toTransfer);
-        recordReadPosition += toTransfer;
-        dataRemaining -= toTransfer;
-      }
-      writer.recordWritten();
-    }
-    final FileSegment committedSegment = writer.commitAndGet();
-    writer.close();
+    final FileSegment committedSegment;
+    try (DiskBlockObjectWriter writer =
+        blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse)) {
+      final int uaoSize = UnsafeAlignedOffset.getUaoSize();
+      while (sortedRecords.hasNext()) {
+        sortedRecords.loadNext();
+        final int partition = sortedRecords.packedRecordPointer.getPartitionId();
+        assert (partition >= currentPartition);
+        if (partition != currentPartition) {
+          // Switch to the new partition
+          if (currentPartition != -1) {
+            final FileSegment fileSegment = writer.commitAndGet();
+            spillInfo.partitionLengths[currentPartition] = fileSegment.length();
+          }
+          currentPartition = partition;
+        }
+        final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
+        final Object recordPage = taskMemoryManager.getPage(recordPointer);
+        final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer);
+        int dataRemaining = UnsafeAlignedOffset.getSize(recordPage, recordOffsetInPage);
+        long recordReadPosition = recordOffsetInPage + uaoSize; // skip over record length
+        while (dataRemaining > 0) {
+          final int toTransfer = Math.min(diskWriteBufferSize, dataRemaining);
+          Platform.copyMemory(
+            recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer);
+          writer.write(writeBuffer, 0, toTransfer);
+          recordReadPosition += toTransfer;
+          dataRemaining -= toTransfer;
+        }
+        writer.recordWritten();
+      }
+      committedSegment = writer.commitAndGet();
+    }
     // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted,
     // then the file might be empty. Note that it might be better to avoid calling
     // writeSortedFile() in that case.

@@ -39,30 +39,28 @@ public class JavaJdbcRDDSuite implements Serializable {
     sc = new JavaSparkContext("local", "JavaAPISuite");
     Class.forName("org.apache.derby.jdbc.EmbeddedDriver");
-    Connection connection =
-      DriverManager.getConnection("jdbc:derby:target/JavaJdbcRDDSuiteDb;create=true");
-    try {
-      Statement create = connection.createStatement();
-      create.execute(
-        "CREATE TABLE FOO(" +
-        "ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1)," +
-        "DATA INTEGER)");
-      create.close();
-      PreparedStatement insert = connection.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)");
-      for (int i = 1; i <= 100; i++) {
-        insert.setInt(1, i * 2);
-        insert.executeUpdate();
+    try (Connection connection = DriverManager.getConnection(
+        "jdbc:derby:target/JavaJdbcRDDSuiteDb;create=true")) {
+      try (Statement create = connection.createStatement()) {
+        create.execute(
+          "CREATE TABLE FOO(ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY" +
+          " (START WITH 1, INCREMENT BY 1), DATA INTEGER)");
+      }
+      try (PreparedStatement insert = connection.prepareStatement(
+          "INSERT INTO FOO(DATA) VALUES(?)")) {
+        for (int i = 1; i <= 100; i++) {
+          insert.setInt(1, i * 2);
+          insert.executeUpdate();
+        }
       }
-      insert.close();
     } catch (SQLException e) {
       // If table doesn't exist...
       if (e.getSQLState().compareTo("X0Y32") != 0) {
        throw e;
      }
-    } finally {
-      connection.close();
     }
   }

@@ -186,14 +186,14 @@ public class UnsafeShuffleWriterSuite {
       if (conf.getBoolean("spark.shuffle.compress", true)) {
         in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in);
       }
-      DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in);
-      Iterator<Tuple2<Object, Object>> records = recordsStream.asKeyValueIterator();
-      while (records.hasNext()) {
-        Tuple2<Object, Object> record = records.next();
-        assertEquals(i, hashPartitioner.getPartition(record._1()));
-        recordsList.add(record);
+      try (DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in)) {
+        Iterator<Tuple2<Object, Object>> records = recordsStream.asKeyValueIterator();
+        while (records.hasNext()) {
+          Tuple2<Object, Object> record = records.next();
+          assertEquals(i, hashPartitioner.getPartition(record._1()));
+          recordsList.add(record);
+        }
       }
-      recordsStream.close();
       startOffset += partitionSize;
     }
   }

@@ -997,10 +997,10 @@ public class JavaAPISuite implements Serializable {
     FileOutputStream fos1 = new FileOutputStream(file1);
-    FileChannel channel1 = fos1.getChannel();
-    ByteBuffer bbuf = ByteBuffer.wrap(content1);
-    channel1.write(bbuf);
-    channel1.close();
+    try (FileChannel channel1 = fos1.getChannel()) {
+      ByteBuffer bbuf = ByteBuffer.wrap(content1);
+      channel1.write(bbuf);
+    }
     JavaPairRDD<String, PortableDataStream> readRDD = sc.binaryFiles(tempDirName, 3);
     List<Tuple2<String, PortableDataStream>> result = readRDD.collect();
     for (Tuple2<String, PortableDataStream> res : result) {
@@ -1018,10 +1018,10 @@ public class JavaAPISuite implements Serializable {
     FileOutputStream fos1 = new FileOutputStream(file1);
-    FileChannel channel1 = fos1.getChannel();
-    ByteBuffer bbuf = ByteBuffer.wrap(content1);
-    channel1.write(bbuf);
-    channel1.close();
+    try (FileChannel channel1 = fos1.getChannel()) {
+      ByteBuffer bbuf = ByteBuffer.wrap(content1);
+      channel1.write(bbuf);
+    }
     JavaPairRDD<String, PortableDataStream> readRDD = sc.binaryFiles(tempDirName).cache();
     readRDD.foreach(pair -> pair._2().toArray()); // force the file to read
@@ -1042,13 +1042,12 @@ public class JavaAPISuite implements Serializable {
     FileOutputStream fos1 = new FileOutputStream(file1);
-    FileChannel channel1 = fos1.getChannel();
-    for (int i = 0; i < numOfCopies; i++) {
-      ByteBuffer bbuf = ByteBuffer.wrap(content1);
-      channel1.write(bbuf);
+    try (FileChannel channel1 = fos1.getChannel()) {
+      for (int i = 0; i < numOfCopies; i++) {
+        ByteBuffer bbuf = ByteBuffer.wrap(content1);
+        channel1.write(bbuf);
+      }
     }
-    channel1.close();
     JavaRDD<byte[]> readRDD = sc.binaryRecords(tempDirName, content1.length);
     assertEquals(numOfCopies,readRDD.count());

@@ -16,6 +16,7 @@
  */
 package org.apache.spark.sql.catalyst.expressions;
+import java.io.Closeable;
 import java.io.IOException;
 import org.apache.spark.memory.MemoryConsumer;
@@ -45,7 +46,7 @@ import org.slf4j.LoggerFactory;
  * page requires an average size for key value pairs to be larger than 1024 bytes.
  *
  */
-public abstract class RowBasedKeyValueBatch extends MemoryConsumer {
+public abstract class RowBasedKeyValueBatch extends MemoryConsumer implements Closeable {
   protected final Logger logger = LoggerFactory.getLogger(RowBasedKeyValueBatch.class);
   private static final int DEFAULT_CAPACITY = 1 << 16;

@@ -123,9 +123,8 @@ public class RowBasedKeyValueBatchSuite {
   @Test
   public void emptyBatch() throws Exception {
-    RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
-      valueSchema, taskMemoryManager, DEFAULT_CAPACITY);
-    try {
+    try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
+        valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
       Assert.assertEquals(0, batch.numRows());
       try {
         batch.getKeyRow(-1);
@@ -152,31 +151,24 @@ public class RowBasedKeyValueBatchSuite {
         // Expected exception; do nothing.
       }
       Assert.assertFalse(batch.rowIterator().next());
-    } finally {
-      batch.close();
     }
   }
   @Test
-  public void batchType() throws Exception {
-    RowBasedKeyValueBatch batch1 = RowBasedKeyValueBatch.allocate(keySchema,
-      valueSchema, taskMemoryManager, DEFAULT_CAPACITY);
-    RowBasedKeyValueBatch batch2 = RowBasedKeyValueBatch.allocate(fixedKeySchema,
-      valueSchema, taskMemoryManager, DEFAULT_CAPACITY);
-    try {
+  public void batchType() {
+    try (RowBasedKeyValueBatch batch1 = RowBasedKeyValueBatch.allocate(keySchema,
+        valueSchema, taskMemoryManager, DEFAULT_CAPACITY);
+        RowBasedKeyValueBatch batch2 = RowBasedKeyValueBatch.allocate(fixedKeySchema,
+        valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
       Assert.assertEquals(batch1.getClass(), VariableLengthRowBasedKeyValueBatch.class);
      Assert.assertEquals(batch2.getClass(), FixedLengthRowBasedKeyValueBatch.class);
-    } finally {
-      batch1.close();
-      batch2.close();
     }
   }
   @Test
   public void setAndRetrieve() {
-    RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
-      valueSchema, taskMemoryManager, DEFAULT_CAPACITY);
-    try {
+    try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
+        valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
       UnsafeRow ret1 = appendRow(batch, makeKeyRow(1, "A"), makeValueRow(1, 1));
       Assert.assertTrue(checkValue(ret1, 1, 1));
       UnsafeRow ret2 = appendRow(batch, makeKeyRow(2, "B"), makeValueRow(2, 2));
@@ -204,33 +196,27 @@ public class RowBasedKeyValueBatchSuite {
       } catch (AssertionError e) {
         // Expected exception; do nothing.
       }
-    } finally {
-      batch.close();
     }
   }
   @Test
   public void setUpdateAndRetrieve() {
-    RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
-      valueSchema, taskMemoryManager, DEFAULT_CAPACITY);
-    try {
+    try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
+        valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
       appendRow(batch, makeKeyRow(1, "A"), makeValueRow(1, 1));
       Assert.assertEquals(1, batch.numRows());
       UnsafeRow retrievedValue = batch.getValueRow(0);
       updateValueRow(retrievedValue, 2, 2);
       UnsafeRow retrievedValue2 = batch.getValueRow(0);
       Assert.assertTrue(checkValue(retrievedValue2, 2, 2));
-    } finally {
-      batch.close();
     }
   }
   @Test
   public void iteratorTest() throws Exception {
-    RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
-      valueSchema, taskMemoryManager, DEFAULT_CAPACITY);
-    try {
+    try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
+        valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
       appendRow(batch, makeKeyRow(1, "A"), makeValueRow(1, 1));
       appendRow(batch, makeKeyRow(2, "B"), makeValueRow(2, 2));
       appendRow(batch, makeKeyRow(3, "C"), makeValueRow(3, 3));
@@ -253,16 +239,13 @@ public class RowBasedKeyValueBatchSuite {
       Assert.assertTrue(checkKey(key3, 3, "C"));
       Assert.assertTrue(checkValue(value3, 3, 3));
       Assert.assertFalse(iterator.next());
-    } finally {
-      batch.close();
     }
   }
   @Test
   public void fixedLengthTest() throws Exception {
-    RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(fixedKeySchema,
-      valueSchema, taskMemoryManager, DEFAULT_CAPACITY);
-    try {
+    try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(fixedKeySchema,
+        valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
      appendRow(batch, makeKeyRow(11, 11), makeValueRow(1, 1));
      appendRow(batch, makeKeyRow(22, 22), makeValueRow(2, 2));
      appendRow(batch, makeKeyRow(33, 33), makeValueRow(3, 3));
@@ -293,16 +276,13 @@ public class RowBasedKeyValueBatchSuite {
       Assert.assertTrue(checkKey(key3, 33, 33));
       Assert.assertTrue(checkValue(value3, 3, 3));
       Assert.assertFalse(iterator.next());
-    } finally {
-      batch.close();
     }
   }
   @Test
   public void appendRowUntilExceedingCapacity() throws Exception {
-    RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
-      valueSchema, taskMemoryManager, 10);
-    try {
+    try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
+        valueSchema, taskMemoryManager, 10)) {
       UnsafeRow key = makeKeyRow(1, "A");
       UnsafeRow value = makeValueRow(1, 1);
       for (int i = 0; i < 10; i++) {
@@ -321,8 +301,6 @@ public class RowBasedKeyValueBatchSuite {
         Assert.assertTrue(checkValue(value1, 1, 1));
       }
       Assert.assertFalse(iterator.next());
-    } finally {
-      batch.close();
     }
   }
@@ -330,9 +308,8 @@ public class RowBasedKeyValueBatchSuite {
   public void appendRowUntilExceedingPageSize() throws Exception {
     // Use default size or spark.buffer.pageSize if specified
     int pageSizeToUse = (int) memoryManager.pageSizeBytes();
-    RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
-      valueSchema, taskMemoryManager, pageSizeToUse); //enough capacity
-    try {
+    try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
+        valueSchema, taskMemoryManager, pageSizeToUse)) {
       UnsafeRow key = makeKeyRow(1, "A");
       UnsafeRow value = makeValueRow(1, 1);
       int recordLength = 8 + key.getSizeInBytes() + value.getSizeInBytes() + 8;
@@ -356,49 +333,44 @@ public class RowBasedKeyValueBatchSuite {
         Assert.assertTrue(checkValue(value1, 1, 1));
       }
       Assert.assertFalse(iterator.next());
-    } finally {
-      batch.close();
     }
   }
   @Test
   public void failureToAllocateFirstPage() throws Exception {
     memoryManager.limit(1024);
-    RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
-      valueSchema, taskMemoryManager, DEFAULT_CAPACITY);
-    try {
+    try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
+        valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
       UnsafeRow key = makeKeyRow(1, "A");
       UnsafeRow value = makeValueRow(11, 11);
       UnsafeRow ret = appendRow(batch, key, value);
       Assert.assertNull(ret);
       Assert.assertFalse(batch.rowIterator().next());
-    } finally {
-      batch.close();
     }
   }
   @Test
   public void randomizedTest() {
-    RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
-      valueSchema, taskMemoryManager, DEFAULT_CAPACITY);
-    int numEntry = 100;
-    long[] expectedK1 = new long[numEntry];
-    String[] expectedK2 = new String[numEntry];
-    long[] expectedV1 = new long[numEntry];
-    long[] expectedV2 = new long[numEntry];
-    for (int i = 0; i < numEntry; i++) {
-      long k1 = rand.nextLong();
-      String k2 = getRandomString(rand.nextInt(256));
-      long v1 = rand.nextLong();
-      long v2 = rand.nextLong();
-      appendRow(batch, makeKeyRow(k1, k2), makeValueRow(v1, v2));
-      expectedK1[i] = k1;
-      expectedK2[i] = k2;
-      expectedV1[i] = v1;
-      expectedV2[i] = v2;
-    }
-    try {
+    try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
+        valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
+      int numEntry = 100;
+      long[] expectedK1 = new long[numEntry];
+      String[] expectedK2 = new String[numEntry];
+      long[] expectedV1 = new long[numEntry];
+      long[] expectedV2 = new long[numEntry];
+      for (int i = 0; i < numEntry; i++) {
+        long k1 = rand.nextLong();
+        String k2 = getRandomString(rand.nextInt(256));
+        long v1 = rand.nextLong();
+        long v2 = rand.nextLong();
+        appendRow(batch, makeKeyRow(k1, k2), makeValueRow(v1, v2));
+        expectedK1[i] = k1;
+        expectedK2[i] = k2;
+        expectedV1[i] = v1;
+        expectedV2[i] = v2;
+      }
      for (int j = 0; j < 10000; j++) {
        int rowId = rand.nextInt(numEntry);
        if (rand.nextBoolean()) {
@@ -410,8 +382,6 @@ public class RowBasedKeyValueBatchSuite {
           Assert.assertTrue(checkValue(value, expectedV1[rowId], expectedV2[rowId]));
         }
       }
-    } finally {
-      batch.close();
     }
   }
 }