[SPARK-10891][STREAMING][KINESIS] Add MessageHandler to KinesisUtils.createStream similar to Direct Kafka

This PR allows users to map a Kinesis `Record` to a generic type `T` when creating a Kinesis stream. This is particularly useful if you would like to do extra work with Kinesis metadata such as the sequence number and partition key.
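For example, the new Scala overload lets a handler keep the partition key and sequence number alongside the payload. The sketch below is only illustrative: the app name, stream name, endpoint, region, and the `KinesisPayload`/`extractPayload` helpers are placeholders, not part of this PR.

```scala
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import com.amazonaws.services.kinesis.model.Record

import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kinesis.KinesisUtils

object KinesisHandlerExample {
  // Keep the Kinesis metadata next to the decoded payload instead of just the bytes.
  case class KinesisPayload(partitionKey: String, sequenceNumber: String, data: Array[Byte])

  def extractPayload(record: Record): KinesisPayload = {
    // Copy the bytes out of the record's ByteBuffer, as defaultMessageHandler does.
    val buffer = record.getData()
    val bytes = new Array[Byte](buffer.remaining())
    buffer.get(bytes)
    KinesisPayload(record.getPartitionKey, record.getSequenceNumber, bytes)
  }

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("KinesisHandlerExample") // placeholder app name
    val ssc = new StreamingContext(conf, Seconds(2))

    // New overload added by this PR: the last argument is the Record => T handler.
    val stream = KinesisUtils.createStream(
      ssc,
      "myKinesisApp",                            // placeholder KCL application name
      "mySparkStream",                           // placeholder stream name
      "https://kinesis.us-west-2.amazonaws.com",
      "us-west-2",
      InitialPositionInStream.LATEST,
      Seconds(2),
      StorageLevel.MEMORY_AND_DISK_2,
      extractPayload _)

    stream.map(p => s"${p.partitionKey}/${p.sequenceNumber}: ${p.data.length} bytes").print()

    ssc.start()
    ssc.awaitTermination()
  }
}
```

Note that the handler is applied both in the receiver (`addRecords`) and when blocks are re-read in `KinesisBackedBlockRDD`, so it must be serializable (it is closure-cleaned via `ssc.sc.clean`).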

TODO:
 - [x] add tests

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #8954 from brkyvz/kinesis-handler.
Authored by Burak Yavuz on 2015-10-25 21:18:35 -07:00; committed by Tathagata Das
parent 80279ac187
commit 63accc7962
9 changed files with 337 additions and 75 deletions


@@ -18,6 +18,7 @@
package org.apache.spark.streaming.kinesis
import scala.collection.JavaConverters._
import scala.reflect.ClassTag
import scala.util.control.NonFatal
import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain}
@@ -67,7 +68,7 @@ class KinesisBackedBlockRDDPartition(
* sequence numbers of the corresponding blocks.
*/
private[kinesis]
class KinesisBackedBlockRDD(
class KinesisBackedBlockRDD[T: ClassTag](
@transient sc: SparkContext,
val regionName: String,
val endpointUrl: String,
@@ -75,8 +76,9 @@ class KinesisBackedBlockRDD(
@transient val arrayOfseqNumberRanges: Array[SequenceNumberRanges],
@transient isBlockIdValid: Array[Boolean] = Array.empty,
val retryTimeoutMs: Int = 10000,
val messageHandler: Record => T = KinesisUtils.defaultMessageHandler _,
val awsCredentialsOption: Option[SerializableAWSCredentials] = None
) extends BlockRDD[Array[Byte]](sc, blockIds) {
) extends BlockRDD[T](sc, blockIds) {
require(blockIds.length == arrayOfseqNumberRanges.length,
"Number of blockIds is not equal to the number of sequence number ranges")
@@ -90,23 +92,23 @@ class KinesisBackedBlockRDD(
}
}
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
override def compute(split: Partition, context: TaskContext): Iterator[T] = {
val blockManager = SparkEnv.get.blockManager
val partition = split.asInstanceOf[KinesisBackedBlockRDDPartition]
val blockId = partition.blockId
def getBlockFromBlockManager(): Option[Iterator[Array[Byte]]] = {
def getBlockFromBlockManager(): Option[Iterator[T]] = {
logDebug(s"Read partition data of $this from block manager, block $blockId")
blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[Array[Byte]]])
blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[T]])
}
def getBlockFromKinesis(): Iterator[Array[Byte]] = {
val credenentials = awsCredentialsOption.getOrElse {
def getBlockFromKinesis(): Iterator[T] = {
val credentials = awsCredentialsOption.getOrElse {
new DefaultAWSCredentialsProviderChain().getCredentials()
}
partition.seqNumberRanges.ranges.iterator.flatMap { range =>
new KinesisSequenceRangeIterator(
credenentials, endpointUrl, regionName, range, retryTimeoutMs)
new KinesisSequenceRangeIterator(credentials, endpointUrl, regionName,
range, retryTimeoutMs).map(messageHandler)
}
}
if (partition.isBlockIdValid) {
@@ -129,8 +131,7 @@ class KinesisSequenceRangeIterator(
endpointUrl: String,
regionId: String,
range: SequenceNumberRange,
retryTimeoutMs: Int
) extends NextIterator[Array[Byte]] with Logging {
retryTimeoutMs: Int) extends NextIterator[Record] with Logging {
private val client = new AmazonKinesisClient(credentials)
private val streamName = range.streamName
@@ -142,8 +143,8 @@ class KinesisSequenceRangeIterator(
client.setEndpoint(endpointUrl, "kinesis", regionId)
override protected def getNext(): Array[Byte] = {
var nextBytes: Array[Byte] = null
override protected def getNext(): Record = {
var nextRecord: Record = null
if (toSeqNumberReceived) {
finished = true
} else {
@@ -170,10 +171,7 @@ class KinesisSequenceRangeIterator(
} else {
// Get the record, copy the data into a byte array and remember its sequence number
val nextRecord: Record = internalIterator.next()
val byteBuffer = nextRecord.getData()
nextBytes = new Array[Byte](byteBuffer.remaining())
byteBuffer.get(nextBytes)
nextRecord = internalIterator.next()
lastSeqNumber = nextRecord.getSequenceNumber()
// If the this record's sequence number matches the stopping sequence number, then make sure
@@ -182,9 +180,8 @@ class KinesisSequenceRangeIterator(
toSeqNumberReceived = true
}
}
}
nextBytes
nextRecord
}
override protected def close(): Unit = {


@@ -17,7 +17,10 @@
package org.apache.spark.streaming.kinesis
import scala.reflect.ClassTag
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import com.amazonaws.services.kinesis.model.Record
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.{BlockId, StorageLevel}
@@ -26,7 +29,7 @@ import org.apache.spark.streaming.receiver.Receiver
import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
import org.apache.spark.streaming.{Duration, StreamingContext, Time}
private[kinesis] class KinesisInputDStream(
private[kinesis] class KinesisInputDStream[T: ClassTag](
@transient _ssc: StreamingContext,
streamName: String,
endpointUrl: String,
@@ -35,11 +38,12 @@ private[kinesis] class KinesisInputDStream(
checkpointAppName: String,
checkpointInterval: Duration,
storageLevel: StorageLevel,
messageHandler: Record => T,
awsCredentialsOption: Option[SerializableAWSCredentials]
) extends ReceiverInputDStream[Array[Byte]](_ssc) {
) extends ReceiverInputDStream[T](_ssc) {
private[streaming]
override def createBlockRDD(time: Time, blockInfos: Seq[ReceivedBlockInfo]): RDD[Array[Byte]] = {
override def createBlockRDD(time: Time, blockInfos: Seq[ReceivedBlockInfo]): RDD[T] = {
// This returns true even for when blockInfos is empty
val allBlocksHaveRanges = blockInfos.map { _.metadataOption }.forall(_.nonEmpty)
@@ -56,6 +60,7 @@ private[kinesis] class KinesisInputDStream(
context.sc, regionName, endpointUrl, blockIds, seqNumRanges,
isBlockIdValid = isBlockIdValid,
retryTimeoutMs = ssc.graph.batchDuration.milliseconds.toInt,
messageHandler = messageHandler,
awsCredentialsOption = awsCredentialsOption)
} else {
logWarning("Kinesis sequence number information was not present with some block metadata," +
@@ -64,8 +69,8 @@ private[kinesis] class KinesisInputDStream(
}
}
override def getReceiver(): Receiver[Array[Byte]] = {
override def getReceiver(): Receiver[T] = {
new KinesisReceiver(streamName, endpointUrl, regionName, initialPositionInStream,
checkpointAppName, checkpointInterval, storageLevel, awsCredentialsOption)
checkpointAppName, checkpointInterval, storageLevel, messageHandler, awsCredentialsOption)
}
}


@@ -80,7 +80,7 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String)
* @param awsCredentialsOption Optional AWS credentials, used when user directly specifies
* the credentials
*/
private[kinesis] class KinesisReceiver(
private[kinesis] class KinesisReceiver[T](
val streamName: String,
endpointUrl: String,
regionName: String,
@@ -88,8 +88,9 @@ private[kinesis] class KinesisReceiver(
checkpointAppName: String,
checkpointInterval: Duration,
storageLevel: StorageLevel,
awsCredentialsOption: Option[SerializableAWSCredentials]
) extends Receiver[Array[Byte]](storageLevel) with Logging { receiver =>
messageHandler: Record => T,
awsCredentialsOption: Option[SerializableAWSCredentials])
extends Receiver[T](storageLevel) with Logging { receiver =>
/*
* =================================================================================
@@ -202,12 +203,7 @@ private[kinesis] class KinesisReceiver(
/** Add records of the given shard to the current block being generated */
private[kinesis] def addRecords(shardId: String, records: java.util.List[Record]): Unit = {
if (records.size > 0) {
val dataIterator = records.iterator().asScala.map { record =>
val byteBuffer = record.getData()
val byteArray = new Array[Byte](byteBuffer.remaining())
byteBuffer.get(byteArray)
byteArray
}
val dataIterator = records.iterator().asScala.map(messageHandler)
val metadata = SequenceNumberRange(streamName, shardId,
records.get(0).getSequenceNumber(), records.get(records.size() - 1).getSequenceNumber())
blockGenerator.addMultipleDataWithCallback(dataIterator, metadata)
@@ -240,7 +236,7 @@ private[kinesis] class KinesisReceiver(
/** Store the block along with its associated ranges */
private def storeBlockWithRanges(
blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[Array[Byte]]): Unit = {
blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[T]): Unit = {
val rangesToReportOption = blockIdToSeqNumRanges.remove(blockId)
if (rangesToReportOption.isEmpty) {
stop("Error while storing block into Spark, could not find sequence number ranges " +
@@ -325,7 +321,7 @@ private[kinesis] class KinesisReceiver(
/** Callback method called when a block is ready to be pushed / stored. */
def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = {
storeBlockWithRanges(blockId,
arrayBuffer.asInstanceOf[mutable.ArrayBuffer[Array[Byte]]])
arrayBuffer.asInstanceOf[mutable.ArrayBuffer[T]])
}
/** Callback called in case of any error in internal of the BlockGenerator */


@@ -41,8 +41,8 @@ import org.apache.spark.Logging
* @param checkpointState represents the checkpoint state including the next checkpoint time.
* It's injected here for mocking purposes.
*/
private[kinesis] class KinesisRecordProcessor(
receiver: KinesisReceiver,
private[kinesis] class KinesisRecordProcessor[T](
receiver: KinesisReceiver[T],
workerId: String,
checkpointState: KinesisCheckpointState) extends IRecordProcessor with Logging {


@@ -16,16 +16,120 @@
*/
package org.apache.spark.streaming.kinesis
import scala.reflect.ClassTag
import com.amazonaws.regions.RegionUtils
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import com.amazonaws.services.kinesis.model.Record
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext}
import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.streaming.{Duration, StreamingContext}
object KinesisUtils {
/**
* Create an input stream that pulls messages from a Kinesis stream.
* This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
*
* Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain
* on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain
* gets the AWS credentials.
*
* @param ssc StreamingContext object
* @param kinesisAppName Kinesis application name used by the Kinesis Client Library
* (KCL) to update DynamoDB
* @param streamName Kinesis stream name
* @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
* @param regionName Name of region used by the Kinesis Client Library (KCL) to update
* DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
* @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the
* worker's initial starting position in the stream.
* The values are either the beginning of the stream
* per Kinesis' limit of 24 hours
* (InitialPositionInStream.TRIM_HORIZON) or
* the tip of the stream (InitialPositionInStream.LATEST).
* @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
* See the Kinesis Spark Streaming documentation for more
* details on the different types of checkpoints.
* @param storageLevel Storage level to use for storing the received objects.
* StorageLevel.MEMORY_AND_DISK_2 is recommended.
* @param messageHandler A custom message handler that can generate a generic output from a
* Kinesis `Record`, which contains both message data, and metadata.
*/
def createStream[T: ClassTag](
ssc: StreamingContext,
kinesisAppName: String,
streamName: String,
endpointUrl: String,
regionName: String,
initialPositionInStream: InitialPositionInStream,
checkpointInterval: Duration,
storageLevel: StorageLevel,
messageHandler: Record => T): ReceiverInputDStream[T] = {
val cleanedHandler = ssc.sc.clean(messageHandler)
// Setting scope to override receiver stream's scope of "receiver stream"
ssc.withNamedScope("kinesis stream") {
new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName),
initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
cleanedHandler, None)
}
}
/**
* Create an input stream that pulls messages from a Kinesis stream.
* This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
*
* Note:
* The given AWS credentials will get saved in DStream checkpoints if checkpointing
* is enabled. Make sure that your checkpoint directory is secure.
*
* @param ssc StreamingContext object
* @param kinesisAppName Kinesis application name used by the Kinesis Client Library
* (KCL) to update DynamoDB
* @param streamName Kinesis stream name
* @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
* @param regionName Name of region used by the Kinesis Client Library (KCL) to update
* DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
* @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the
* worker's initial starting position in the stream.
* The values are either the beginning of the stream
* per Kinesis' limit of 24 hours
* (InitialPositionInStream.TRIM_HORIZON) or
* the tip of the stream (InitialPositionInStream.LATEST).
* @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
* See the Kinesis Spark Streaming documentation for more
* details on the different types of checkpoints.
* @param storageLevel Storage level to use for storing the received objects.
* StorageLevel.MEMORY_AND_DISK_2 is recommended.
* @param messageHandler A custom message handler that can generate a generic output from a
* Kinesis `Record`, which contains both message data, and metadata.
* @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain)
* @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain)
*/
// scalastyle:off
def createStream[T: ClassTag](
ssc: StreamingContext,
kinesisAppName: String,
streamName: String,
endpointUrl: String,
regionName: String,
initialPositionInStream: InitialPositionInStream,
checkpointInterval: Duration,
storageLevel: StorageLevel,
messageHandler: Record => T,
awsAccessKeyId: String,
awsSecretKey: String): ReceiverInputDStream[T] = {
// scalastyle:on
val cleanedHandler = ssc.sc.clean(messageHandler)
ssc.withNamedScope("kinesis stream") {
new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName),
initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
cleanedHandler, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey)))
}
}
/**
* Create an input stream that pulls messages from a Kinesis stream.
* This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
@@ -61,12 +165,12 @@ object KinesisUtils {
regionName: String,
initialPositionInStream: InitialPositionInStream,
checkpointInterval: Duration,
storageLevel: StorageLevel
): ReceiverInputDStream[Array[Byte]] = {
storageLevel: StorageLevel): ReceiverInputDStream[Array[Byte]] = {
// Setting scope to override receiver stream's scope of "receiver stream"
ssc.withNamedScope("kinesis stream") {
new KinesisInputDStream(ssc, streamName, endpointUrl, validateRegion(regionName),
initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, None)
new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName),
initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
defaultMessageHandler, None)
}
}
@@ -109,12 +213,11 @@ object KinesisUtils {
checkpointInterval: Duration,
storageLevel: StorageLevel,
awsAccessKeyId: String,
awsSecretKey: String
): ReceiverInputDStream[Array[Byte]] = {
awsSecretKey: String): ReceiverInputDStream[Array[Byte]] = {
ssc.withNamedScope("kinesis stream") {
new KinesisInputDStream(ssc, streamName, endpointUrl, validateRegion(regionName),
new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName),
initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey)))
defaultMessageHandler, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey)))
}
}
@@ -156,11 +259,113 @@ object KinesisUtils {
storageLevel: StorageLevel
): ReceiverInputDStream[Array[Byte]] = {
ssc.withNamedScope("kinesis stream") {
new KinesisInputDStream(ssc, streamName, endpointUrl, getRegionByEndpoint(endpointUrl),
initialPositionInStream, ssc.sc.appName, checkpointInterval, storageLevel, None)
new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl,
getRegionByEndpoint(endpointUrl), initialPositionInStream, ssc.sc.appName,
checkpointInterval, storageLevel, defaultMessageHandler, None)
}
}
/**
* Create an input stream that pulls messages from a Kinesis stream.
* This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
*
* Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain
* on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain
* gets the AWS credentials.
*
* @param jssc Java StreamingContext object
* @param kinesisAppName Kinesis application name used by the Kinesis Client Library
* (KCL) to update DynamoDB
* @param streamName Kinesis stream name
* @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
* @param regionName Name of region used by the Kinesis Client Library (KCL) to update
* DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
* @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the
* worker's initial starting position in the stream.
* The values are either the beginning of the stream
* per Kinesis' limit of 24 hours
* (InitialPositionInStream.TRIM_HORIZON) or
* the tip of the stream (InitialPositionInStream.LATEST).
* @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
* See the Kinesis Spark Streaming documentation for more
* details on the different types of checkpoints.
* @param storageLevel Storage level to use for storing the received objects.
* StorageLevel.MEMORY_AND_DISK_2 is recommended.
* @param messageHandler A custom message handler that can generate a generic output from a
* Kinesis `Record`, which contains both message data, and metadata.
* @param recordClass Class of the records in DStream
*/
def createStream[T](
jssc: JavaStreamingContext,
kinesisAppName: String,
streamName: String,
endpointUrl: String,
regionName: String,
initialPositionInStream: InitialPositionInStream,
checkpointInterval: Duration,
storageLevel: StorageLevel,
messageHandler: JFunction[Record, T],
recordClass: Class[T]): JavaReceiverInputDStream[T] = {
implicit val recordCmt: ClassTag[T] = ClassTag(recordClass)
val cleanedHandler = jssc.sparkContext.clean(messageHandler.call(_))
createStream[T](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
initialPositionInStream, checkpointInterval, storageLevel, cleanedHandler)
}
/**
* Create an input stream that pulls messages from a Kinesis stream.
* This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
*
* Note:
* The given AWS credentials will get saved in DStream checkpoints if checkpointing
* is enabled. Make sure that your checkpoint directory is secure.
*
* @param jssc Java StreamingContext object
* @param kinesisAppName Kinesis application name used by the Kinesis Client Library
* (KCL) to update DynamoDB
* @param streamName Kinesis stream name
* @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
* @param regionName Name of region used by the Kinesis Client Library (KCL) to update
* DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
* @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the
* worker's initial starting position in the stream.
* The values are either the beginning of the stream
* per Kinesis' limit of 24 hours
* (InitialPositionInStream.TRIM_HORIZON) or
* the tip of the stream (InitialPositionInStream.LATEST).
* @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
* See the Kinesis Spark Streaming documentation for more
* details on the different types of checkpoints.
* @param storageLevel Storage level to use for storing the received objects.
* StorageLevel.MEMORY_AND_DISK_2 is recommended.
* @param messageHandler A custom message handler that can generate a generic output from a
* Kinesis `Record`, which contains both message data, and metadata.
* @param recordClass Class of the records in DStream
* @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain)
* @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain)
*/
// scalastyle:off
def createStream[T](
jssc: JavaStreamingContext,
kinesisAppName: String,
streamName: String,
endpointUrl: String,
regionName: String,
initialPositionInStream: InitialPositionInStream,
checkpointInterval: Duration,
storageLevel: StorageLevel,
messageHandler: JFunction[Record, T],
recordClass: Class[T],
awsAccessKeyId: String,
awsSecretKey: String): JavaReceiverInputDStream[T] = {
// scalastyle:on
implicit val recordCmt: ClassTag[T] = ClassTag(recordClass)
val cleanedHandler = jssc.sparkContext.clean(messageHandler.call(_))
createStream[T](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
initialPositionInStream, checkpointInterval, storageLevel, cleanedHandler,
awsAccessKeyId, awsSecretKey)
}
/**
* Create an input stream that pulls messages from a Kinesis stream.
* This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
@@ -198,8 +403,8 @@ object KinesisUtils {
checkpointInterval: Duration,
storageLevel: StorageLevel
): JavaReceiverInputDStream[Array[Byte]] = {
createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
initialPositionInStream, checkpointInterval, storageLevel)
createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
initialPositionInStream, checkpointInterval, storageLevel, defaultMessageHandler(_))
}
/**
@@ -241,10 +446,10 @@ object KinesisUtils {
checkpointInterval: Duration,
storageLevel: StorageLevel,
awsAccessKeyId: String,
awsSecretKey: String
): JavaReceiverInputDStream[Array[Byte]] = {
createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
initialPositionInStream, checkpointInterval, storageLevel, awsAccessKeyId, awsSecretKey)
awsSecretKey: String): JavaReceiverInputDStream[Array[Byte]] = {
createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
initialPositionInStream, checkpointInterval, storageLevel,
defaultMessageHandler(_), awsAccessKeyId, awsSecretKey)
}
/**
@@ -297,6 +502,14 @@ object KinesisUtils {
throw new IllegalArgumentException(s"Region name '$regionName' is not valid")
}
}
private[kinesis] def defaultMessageHandler(record: Record): Array[Byte] = {
if (record == null) return null
val byteBuffer = record.getData()
val byteArray = new Array[Byte](byteBuffer.remaining())
byteBuffer.get(byteArray)
byteArray
}
}
/**


@@ -17,14 +17,19 @@
package org.apache.spark.streaming.kinesis;
import com.amazonaws.services.kinesis.model.Record;
import org.junit.Test;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.LocalJavaStreamingContext;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.junit.Test;
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream;
import java.nio.ByteBuffer;
/**
* Demonstrate the use of the KinesisUtils Java API
*/
@@ -33,9 +38,27 @@ public class JavaKinesisStreamSuite extends LocalJavaStreamingContext {
public void testKinesisStream() {
// Tests the API, does not actually test data receiving
JavaDStream<byte[]> kinesisStream = KinesisUtils.createStream(ssc, "mySparkStream",
"https://kinesis.us-west-2.amazonaws.com", new Duration(2000),
"https://kinesis.us-west-2.amazonaws.com", new Duration(2000),
InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2());
ssc.stop();
}
private static Function<Record, String> handler = new Function<Record, String>() {
@Override
public String call(Record record) {
return record.getPartitionKey() + "-" + record.getSequenceNumber();
}
};
@Test
public void testCustomHandler() {
// Tests the API, does not actually test data receiving
JavaDStream<String> kinesisStream = KinesisUtils.createStream(ssc, "testApp", "mySparkStream",
"https://kinesis.us-west-2.amazonaws.com", "us-west-2", InitialPositionInStream.LATEST,
new Duration(2000), StorageLevel.MEMORY_AND_DISK_2(), handler, String.class);
ssc.stop();
}
}


@@ -73,22 +73,22 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll
testIfEnabled("Basic reading from Kinesis") {
// Verify all data using multiple ranges in a single RDD partition
val receivedData1 = new KinesisBackedBlockRDD(sc, testUtils.regionName, testUtils.endpointUrl,
fakeBlockIds(1),
val receivedData1 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName,
testUtils.endpointUrl, fakeBlockIds(1),
Array(SequenceNumberRanges(allRanges.toArray))
).map { bytes => new String(bytes).toInt }.collect()
assert(receivedData1.toSet === testData.toSet)
// Verify all data using one range in each of the multiple RDD partitions
val receivedData2 = new KinesisBackedBlockRDD(sc, testUtils.regionName, testUtils.endpointUrl,
fakeBlockIds(allRanges.size),
val receivedData2 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName,
testUtils.endpointUrl, fakeBlockIds(allRanges.size),
allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray
).map { bytes => new String(bytes).toInt }.collect()
assert(receivedData2.toSet === testData.toSet)
// Verify ordering within each partition
val receivedData3 = new KinesisBackedBlockRDD(sc, testUtils.regionName, testUtils.endpointUrl,
fakeBlockIds(allRanges.size),
val receivedData3 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName,
testUtils.endpointUrl, fakeBlockIds(allRanges.size),
allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray
).map { bytes => new String(bytes).toInt }.collectPartitions()
assert(receivedData3.length === allRanges.size)
@@ -209,7 +209,7 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll
}, "Incorrect configuration of RDD, unexpected ranges set"
)
val rdd = new KinesisBackedBlockRDD(
val rdd = new KinesisBackedBlockRDD[Array[Byte]](
sc, testUtils.regionName, testUtils.endpointUrl, blockIds, ranges)
val collectedData = rdd.map { bytes =>
new String(bytes).toInt
@@ -223,7 +223,7 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll
if (testIsBlockValid) {
require(numPartitionsInBM === numPartitions, "All partitions must be in BlockManager")
require(numPartitionsInKinesis === 0, "No partitions must be in Kinesis")
val rdd2 = new KinesisBackedBlockRDD(
val rdd2 = new KinesisBackedBlockRDD[Array[Byte]](
sc, testUtils.regionName, testUtils.endpointUrl, blockIds.toArray, ranges,
isBlockIdValid = Array.fill(blockIds.length)(false))
intercept[SparkException] {


@@ -52,14 +52,14 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft
record2.setData(ByteBuffer.wrap("Learning Spark".getBytes(StandardCharsets.UTF_8)))
val batch = Arrays.asList(record1, record2)
var receiverMock: KinesisReceiver = _
var receiverMock: KinesisReceiver[Array[Byte]] = _
var checkpointerMock: IRecordProcessorCheckpointer = _
var checkpointClockMock: ManualClock = _
var checkpointStateMock: KinesisCheckpointState = _
var currentClockMock: Clock = _
override def beforeFunction(): Unit = {
receiverMock = mock[KinesisReceiver]
receiverMock = mock[KinesisReceiver[Array[Byte]]]
checkpointerMock = mock[IRecordProcessorCheckpointer]
checkpointClockMock = mock[ManualClock]
checkpointStateMock = mock[KinesisCheckpointState]


@@ -24,6 +24,7 @@ import scala.util.Random
import com.amazonaws.regions.RegionUtils
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import com.amazonaws.services.kinesis.model.Record
import org.scalatest.Matchers._
import org.scalatest.concurrent.Eventually
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
@@ -31,6 +32,7 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.{StorageLevel, StreamBlockId}
import org.apache.spark.streaming._
import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.streaming.kinesis.KinesisTestUtils._
import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult
import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
@@ -113,9 +115,9 @@ class KinesisStreamSuite extends KinesisFunSuite
val inputStream = KinesisUtils.createStream(ssc, appName, "dummyStream",
dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, Seconds(2),
StorageLevel.MEMORY_AND_DISK_2, dummyAWSAccessKey, dummyAWSSecretKey)
assert(inputStream.isInstanceOf[KinesisInputDStream])
assert(inputStream.isInstanceOf[KinesisInputDStream[Array[Byte]]])
val kinesisStream = inputStream.asInstanceOf[KinesisInputDStream]
val kinesisStream = inputStream.asInstanceOf[KinesisInputDStream[Array[Byte]]]
val time = Time(1000)
// Generate block info data for testing
@@ -134,8 +136,8 @@ class KinesisStreamSuite extends KinesisFunSuite
// Verify that the generated KinesisBackedBlockRDD has the all the right information
val blockInfos = Seq(blockInfo1, blockInfo2)
val nonEmptyRDD = kinesisStream.createBlockRDD(time, blockInfos)
nonEmptyRDD shouldBe a [KinesisBackedBlockRDD]
val kinesisRDD = nonEmptyRDD.asInstanceOf[KinesisBackedBlockRDD]
nonEmptyRDD shouldBe a [KinesisBackedBlockRDD[Array[Byte]]]
val kinesisRDD = nonEmptyRDD.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]]
assert(kinesisRDD.regionName === dummyRegionName)
assert(kinesisRDD.endpointUrl === dummyEndpointUrl)
assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds)
@@ -151,7 +153,7 @@ class KinesisStreamSuite extends KinesisFunSuite
// Verify that KinesisBackedBlockRDD is generated even when there are no blocks
val emptyRDD = kinesisStream.createBlockRDD(time, Seq.empty)
emptyRDD shouldBe a [KinesisBackedBlockRDD]
emptyRDD shouldBe a [KinesisBackedBlockRDD[Array[Byte]]]
emptyRDD.partitions shouldBe empty
// Verify that the KinesisBackedBlockRDD has isBlockValid = false when blocks are invalid
@@ -192,6 +194,32 @@ class KinesisStreamSuite extends KinesisFunSuite
ssc.stop(stopSparkContext = false)
}
testIfEnabled("custom message handling") {
val awsCredentials = KinesisTestUtils.getAWSCredentials()
def addFive(r: Record): Int = new String(r.getData.array()).toInt + 5
val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName,
testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST,
Seconds(10), StorageLevel.MEMORY_ONLY, addFive,
awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey)
stream shouldBe a [ReceiverInputDStream[Int]]
val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int]
stream.foreachRDD { rdd =>
collected ++= rdd.collect()
logInfo("Collected = " + rdd.collect().toSeq.mkString(", "))
}
ssc.start()
val testData = 1 to 10
eventually(timeout(120 seconds), interval(10 second)) {
testUtils.pushData(testData)
val modData = testData.map(_ + 5)
assert(collected === modData.toSet, "\nData received does not match data sent")
}
ssc.stop(stopSparkContext = false)
}
testIfEnabled("failure recovery") {
val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName)
val checkpointDir = Utils.createTempDir().getAbsolutePath
@@ -210,7 +238,7 @@ class KinesisStreamSuite extends KinesisFunSuite
// Verify that the generated RDDs are KinesisBackedBlockRDDs, and collect the data in each batch
kinesisStream.foreachRDD((rdd: RDD[Array[Byte]], time: Time) => {
val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD]
val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]]
val data = rdd.map { bytes => new String(bytes).toInt }.collect().toSeq
collectedData(time) = (kRdd.arrayOfseqNumberRanges, data)
})
@@ -243,10 +271,10 @@ class KinesisStreamSuite extends KinesisFunSuite
times.foreach { time =>
val (arrayOfSeqNumRanges, data) = collectedData(time)
val rdd = recoveredKinesisStream.getOrCompute(time).get.asInstanceOf[RDD[Array[Byte]]]
rdd shouldBe a [KinesisBackedBlockRDD]
rdd shouldBe a [KinesisBackedBlockRDD[Array[Byte]]]
// Verify the recovered sequence ranges
val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD]
val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]]
assert(kRdd.arrayOfseqNumberRanges.size === arrayOfSeqNumRanges.size)
arrayOfSeqNumRanges.zip(kRdd.arrayOfseqNumberRanges).foreach { case (expected, found) =>
assert(expected.ranges.toSeq === found.ranges.toSeq)