[SPARK-19911][STREAMING] Add builder interface for Kinesis DStreams

## What changes were proposed in this pull request?

- Add `KinesisInputDStream.Builder` class for constructing Kinesis DStreams
- Add `KinesisInputDStreamBuilderSuite` test suite
- Make `KinesisInputDStream` ctor args package private for testing
- Add `JavaKinesisInputDStreamBuilderSuite` test suite
- Add args to `KinesisInputDStream` and `KinesisReceiver` for optional
  service-specific auth (Kinesis, DynamoDB and CloudWatch)
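
A minimal usage sketch of the new builder, which replaces the now-deprecated `KinesisUtils.createStream` overloads (assuming an existing `StreamingContext` named `ssc`; the stream, app and region values are placeholders):

```scala
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream

import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.Seconds
import org.apache.spark.streaming.kinesis.KinesisInputDStream

// streamingContext, streamName and checkpointAppName are required; every other
// setter falls back to its documented default (us-east-1 endpoint/region, LATEST
// initial position, checkpointing at the batch interval, MEMORY_AND_DISK_2).
val kinesisStream = KinesisInputDStream.builder
  .streamingContext(ssc)
  .streamName("my-kinesis-stream")
  .checkpointAppName("my-kcl-app")
  .regionName("us-west-2")
  .initialPositionInStream(InitialPositionInStream.TRIM_HORIZON)
  .checkpointInterval(Seconds(60))
  .storageLevel(StorageLevel.MEMORY_AND_DISK_2)
  .build()                                   // KinesisInputDStream[Array[Byte]]
```
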
## How was this patch tested?

Added `KinesisInputDStreamBuilderSuite` and `JavaKinesisInputDStreamBuilderSuite` to verify that the builder classes work as expected.

Author: Adam Budde <budde@amazon.com>

Closes #17250 from budde/KinesisStreamBuilder.
Adam Budde 2017-03-24 12:40:29 -07:00 committed by Burak Yavuz
parent 9299d071f9
commit 707e501832
11 changed files with 749 additions and 149 deletions


@@ -82,8 +82,8 @@ class KinesisBackedBlockRDD[T: ClassTag](
@transient val arrayOfseqNumberRanges: Array[SequenceNumberRanges],
@transient private val isBlockIdValid: Array[Boolean] = Array.empty,
val retryTimeoutMs: Int = 10000,
val messageHandler: Record => T = KinesisUtils.defaultMessageHandler _,
val kinesisCredsProvider: SerializableCredentialsProvider = DefaultCredentialsProvider
val messageHandler: Record => T = KinesisInputDStream.defaultMessageHandler _,
val kinesisCreds: SparkAWSCredentials = DefaultCredentials
) extends BlockRDD[T](sc, _blockIds) {
require(_blockIds.length == arrayOfseqNumberRanges.length,
@@ -109,7 +109,7 @@ class KinesisBackedBlockRDD[T: ClassTag](
}
def getBlockFromKinesis(): Iterator[T] = {
val credentials = kinesisCredsProvider.provider.getCredentials
val credentials = kinesisCreds.provider.getCredentials
partition.seqNumberRanges.ranges.iterator.flatMap { range =>
new KinesisSequenceRangeIterator(credentials, endpointUrl, regionName,
range, retryTimeoutMs).map(messageHandler)


@@ -22,24 +22,28 @@ import scala.reflect.ClassTag
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import com.amazonaws.services.kinesis.model.Record
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.streaming.{Duration, StreamingContext, Time}
import org.apache.spark.streaming.api.java.JavaStreamingContext
import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.streaming.receiver.Receiver
import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
private[kinesis] class KinesisInputDStream[T: ClassTag](
_ssc: StreamingContext,
streamName: String,
endpointUrl: String,
regionName: String,
initialPositionInStream: InitialPositionInStream,
checkpointAppName: String,
checkpointInterval: Duration,
storageLevel: StorageLevel,
messageHandler: Record => T,
kinesisCredsProvider: SerializableCredentialsProvider
val streamName: String,
val endpointUrl: String,
val regionName: String,
val initialPositionInStream: InitialPositionInStream,
val checkpointAppName: String,
val checkpointInterval: Duration,
val _storageLevel: StorageLevel,
val messageHandler: Record => T,
val kinesisCreds: SparkAWSCredentials,
val dynamoDBCreds: Option[SparkAWSCredentials],
val cloudWatchCreds: Option[SparkAWSCredentials]
) extends ReceiverInputDStream[T](_ssc) {
private[streaming]
@@ -61,7 +65,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag](
isBlockIdValid = isBlockIdValid,
retryTimeoutMs = ssc.graph.batchDuration.milliseconds.toInt,
messageHandler = messageHandler,
kinesisCredsProvider = kinesisCredsProvider)
kinesisCreds = kinesisCreds)
} else {
logWarning("Kinesis sequence number information was not present with some block metadata," +
" it may not be possible to recover from failures")
@@ -71,7 +75,238 @@ private[kinesis] class KinesisInputDStream[T: ClassTag](
override def getReceiver(): Receiver[T] = {
new KinesisReceiver(streamName, endpointUrl, regionName, initialPositionInStream,
checkpointAppName, checkpointInterval, storageLevel, messageHandler,
kinesisCredsProvider)
checkpointAppName, checkpointInterval, _storageLevel, messageHandler,
kinesisCreds, dynamoDBCreds, cloudWatchCreds)
}
}
@InterfaceStability.Evolving
object KinesisInputDStream {
/**
* Builder for [[KinesisInputDStream]] instances.
*
* @since 2.2.0
*/
@InterfaceStability.Evolving
class Builder {
// Required params
private var streamingContext: Option[StreamingContext] = None
private var streamName: Option[String] = None
private var checkpointAppName: Option[String] = None
// Params with defaults
private var endpointUrl: Option[String] = None
private var regionName: Option[String] = None
private var initialPositionInStream: Option[InitialPositionInStream] = None
private var checkpointInterval: Option[Duration] = None
private var storageLevel: Option[StorageLevel] = None
private var kinesisCredsProvider: Option[SparkAWSCredentials] = None
private var dynamoDBCredsProvider: Option[SparkAWSCredentials] = None
private var cloudWatchCredsProvider: Option[SparkAWSCredentials] = None
/**
* Sets the StreamingContext that will be used to construct the Kinesis DStream. This is a
* required parameter.
*
* @param ssc [[StreamingContext]] used to construct Kinesis DStreams
* @return Reference to this [[KinesisInputDStream.Builder]]
*/
def streamingContext(ssc: StreamingContext): Builder = {
streamingContext = Option(ssc)
this
}
/**
* Sets the StreamingContext that will be used to construct the Kinesis DStream. This is a
* required parameter.
*
* @param jssc [[JavaStreamingContext]] used to construct Kinesis DStreams
* @return Reference to this [[KinesisInputDStream.Builder]]
*/
def streamingContext(jssc: JavaStreamingContext): Builder = {
streamingContext = Option(jssc.ssc)
this
}
/**
* Sets the name of the Kinesis stream that the DStream will read from. This is a required
* parameter.
*
* @param streamName Name of Kinesis stream that the DStream will read from
* @return Reference to this [[KinesisInputDStream.Builder]]
*/
def streamName(streamName: String): Builder = {
this.streamName = Option(streamName)
this
}
/**
* Sets the KCL application name to use when checkpointing state to DynamoDB. This is a
* required parameter.
*
* @param appName Value to use for the KCL app name (used when creating the DynamoDB checkpoint
* table and when writing metrics to CloudWatch)
* @return Reference to this [[KinesisInputDStream.Builder]]
*/
def checkpointAppName(appName: String): Builder = {
checkpointAppName = Option(appName)
this
}
/**
* Sets the AWS Kinesis endpoint URL. Defaults to "https://kinesis.us-east-1.amazonaws.com" if
* no custom value is specified
*
* @param url Kinesis endpoint URL to use
* @return Reference to this [[KinesisInputDStream.Builder]]
*/
def endpointUrl(url: String): Builder = {
endpointUrl = Option(url)
this
}
/**
* Sets the AWS region to construct clients for. Defaults to "us-east-1" if no custom value
* is specified.
*
* @param regionName Name of AWS region to use (e.g. "us-west-2")
* @return Reference to this [[KinesisInputDStream.Builder]]
*/
def regionName(regionName: String): Builder = {
this.regionName = Option(regionName)
this
}
/**
* Sets the initial position in the Kinesis stream from which data is read. Defaults to
* [[InitialPositionInStream.LATEST]] if no custom value is specified.
*
* @param initialPosition InitialPositionInStream value specifying where Spark Streaming
* will start reading records in the Kinesis stream from
* @return Reference to this [[KinesisInputDStream.Builder]]
*/
def initialPositionInStream(initialPosition: InitialPositionInStream): Builder = {
initialPositionInStream = Option(initialPosition)
this
}
/**
* Sets how often the KCL application state is checkpointed to DynamoDB. Defaults to the Spark
* Streaming batch interval if no custom value is specified.
*
* @param interval [[Duration]] specifying how often the KCL state should be checkpointed to
* DynamoDB.
* @return Reference to this [[KinesisInputDStream.Builder]]
*/
def checkpointInterval(interval: Duration): Builder = {
checkpointInterval = Option(interval)
this
}
/**
* Sets the storage level of the blocks for the DStream created. Defaults to
* [[StorageLevel.MEMORY_AND_DISK_2]] if no custom value is specified.
*
* @param storageLevel [[StorageLevel]] to use for the DStream data blocks
* @return Reference to this [[KinesisInputDStream.Builder]]
*/
def storageLevel(storageLevel: StorageLevel): Builder = {
this.storageLevel = Option(storageLevel)
this
}
/**
* Sets the [[SparkAWSCredentials]] to use for authenticating to the AWS Kinesis
* endpoint. Defaults to [[DefaultCredentials]] if no custom value is specified.
*
* @param credentials [[SparkAWSCredentials]] to use for Kinesis authentication
*/
def kinesisCredentials(credentials: SparkAWSCredentials): Builder = {
kinesisCredsProvider = Option(credentials)
this
}
/**
* Sets the [[SparkAWSCredentials]] to use for authenticating to the AWS DynamoDB
* endpoint. Will use the same credentials used for AWS Kinesis if no custom value is set.
*
* @param credentials [[SparkAWSCredentials]] to use for DynamoDB authentication
*/
def dynamoDBCredentials(credentials: SparkAWSCredentials): Builder = {
dynamoDBCredsProvider = Option(credentials)
this
}
/**
* Sets the [[SparkAWSCredentials]] to use for authenticating to the AWS CloudWatch
* endpoint. Will use the same credentials used for AWS Kinesis if no custom value is set.
*
* @param credentials [[SparkAWSCredentials]] to use for CloudWatch authentication
*/
def cloudWatchCredentials(credentials: SparkAWSCredentials): Builder = {
cloudWatchCredsProvider = Option(credentials)
this
}
/**
* Create a new instance of [[KinesisInputDStream]] with configured parameters and the provided
* message handler.
*
* @param handler Function converting [[Record]] instances read by the KCL to DStream type [[T]]
* @return Instance of [[KinesisInputDStream]] constructed with configured parameters
*/
def buildWithMessageHandler[T: ClassTag](
handler: Record => T): KinesisInputDStream[T] = {
val ssc = getRequiredParam(streamingContext, "streamingContext")
new KinesisInputDStream(
ssc,
getRequiredParam(streamName, "streamName"),
endpointUrl.getOrElse(DEFAULT_KINESIS_ENDPOINT_URL),
regionName.getOrElse(DEFAULT_KINESIS_REGION_NAME),
initialPositionInStream.getOrElse(DEFAULT_INITIAL_POSITION_IN_STREAM),
getRequiredParam(checkpointAppName, "checkpointAppName"),
checkpointInterval.getOrElse(ssc.graph.batchDuration),
storageLevel.getOrElse(DEFAULT_STORAGE_LEVEL),
handler,
kinesisCredsProvider.getOrElse(DefaultCredentials),
dynamoDBCredsProvider,
cloudWatchCredsProvider)
}
/**
* Create a new instance of [[KinesisInputDStream]] with configured parameters and using the
* default message handler, which returns [[Array[Byte]]].
*
* @return Instance of [[KinesisInputDStream]] constructed with configured parameters
*/
def build(): KinesisInputDStream[Array[Byte]] = buildWithMessageHandler(defaultMessageHandler)
private def getRequiredParam[T](param: Option[T], paramName: String): T = param.getOrElse {
throw new IllegalArgumentException(s"No value provided for required parameter $paramName")
}
}
/**
* Creates a [[KinesisInputDStream.Builder]] for constructing [[KinesisInputDStream]] instances.
*
* @since 2.2.0
*
* @return [[KinesisInputDStream.Builder]] instance
*/
def builder: Builder = new Builder
private[kinesis] def defaultMessageHandler(record: Record): Array[Byte] = {
if (record == null) return null
val byteBuffer = record.getData()
val byteArray = new Array[Byte](byteBuffer.remaining())
byteBuffer.get(byteArray)
byteArray
}
private[kinesis] val DEFAULT_KINESIS_ENDPOINT_URL: String =
"https://kinesis.us-east-1.amazonaws.com"
private[kinesis] val DEFAULT_KINESIS_REGION_NAME: String = "us-east-1"
private[kinesis] val DEFAULT_INITIAL_POSITION_IN_STREAM: InitialPositionInStream =
InitialPositionInStream.LATEST
private[kinesis] val DEFAULT_STORAGE_LEVEL: StorageLevel = StorageLevel.MEMORY_AND_DISK_2
}
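
For record types other than `Array[Byte]`, `buildWithMessageHandler` takes a `Record => T` conversion. A sketch under the same assumptions as above (`ssc` and the stream/app names are placeholders), decoding each record's payload as a UTF-8 string:

```scala
import java.nio.charset.StandardCharsets

import com.amazonaws.services.kinesis.model.Record

// Copy the record payload out of its ByteBuffer and decode it as UTF-8.
def recordToString(record: Record): String = {
  val buffer = record.getData()
  val bytes = new Array[Byte](buffer.remaining())
  buffer.get(bytes)
  new String(bytes, StandardCharsets.UTF_8)
}

val stringStream = KinesisInputDStream.builder
  .streamingContext(ssc)
  .streamName("my-kinesis-stream")
  .checkpointAppName("my-kcl-app")
  .buildWithMessageHandler(recordToString)   // KinesisInputDStream[String]
```
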


@@ -70,9 +70,14 @@ import org.apache.spark.util.Utils
* See the Kinesis Spark Streaming documentation for more
* details on the different types of checkpoints.
* @param storageLevel Storage level to use for storing the received objects
* @param kinesisCredsProvider SerializableCredentialsProvider instance that will be used to
* generate the AWSCredentialsProvider instance used for KCL
* authorization.
* @param kinesisCreds SparkAWSCredentials instance that will be used to generate the
* AWSCredentialsProvider passed to the KCL to authorize Kinesis API calls.
* @param cloudWatchCreds Optional SparkAWSCredentials instance that will be used to generate the
* AWSCredentialsProvider passed to the KCL to authorize CloudWatch API
* calls. Will use kinesisCreds if value is None.
* @param dynamoDBCreds Optional SparkAWSCredentials instance that will be used to generate the
* AWSCredentialsProvider passed to the KCL to authorize DynamoDB API calls.
* Will use kinesisCreds if value is None.
*/
private[kinesis] class KinesisReceiver[T](
val streamName: String,
@@ -83,7 +88,9 @@ private[kinesis] class KinesisReceiver[T](
checkpointInterval: Duration,
storageLevel: StorageLevel,
messageHandler: Record => T,
kinesisCredsProvider: SerializableCredentialsProvider)
kinesisCreds: SparkAWSCredentials,
dynamoDBCreds: Option[SparkAWSCredentials],
cloudWatchCreds: Option[SparkAWSCredentials])
extends Receiver[T](storageLevel) with Logging { receiver =>
/*
@@ -140,10 +147,13 @@ private[kinesis] class KinesisReceiver[T](
workerId = Utils.localHostName() + ":" + UUID.randomUUID()
kinesisCheckpointer = new KinesisCheckpointer(receiver, checkpointInterval, workerId)
val kinesisProvider = kinesisCreds.provider
val kinesisClientLibConfiguration = new KinesisClientLibConfiguration(
checkpointAppName,
streamName,
kinesisCredsProvider.provider,
kinesisProvider,
dynamoDBCreds.map(_.provider).getOrElse(kinesisProvider),
cloudWatchCreds.map(_.provider).getOrElse(kinesisProvider),
workerId)
.withKinesisEndpoint(endpointUrl)
.withInitialPositionInStream(initialPositionInStream)
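
With the new constructor args, the KCL's DynamoDB and CloudWatch clients can authenticate with credentials distinct from the Kinesis ones, falling back to `kinesisCreds` when unset. A sketch using the `SparkAWSCredentials` builder added later in this diff (role ARN, session, stream and app names are placeholders; `ssc` is assumed to exist):

```scala
// Checkpoint-table access through an assumed IAM role; Kinesis and CloudWatch keep
// using the default provider chain because no explicit credentials are set for them.
val dynamoDBCreds = SparkAWSCredentials.builder
  .stsCredentials("arn:aws:iam::123456789012:role/dynamodb-role", "my-session")
  .build()

val stream = KinesisInputDStream.builder
  .streamingContext(ssc)
  .streamName("my-kinesis-stream")
  .checkpointAppName("my-kcl-app")
  .dynamoDBCredentials(dynamoDBCreds)
  .build()
```
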


@@ -58,6 +58,7 @@ object KinesisUtils {
* on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain
* gets the AWS credentials.
*/
@deprecated("Use KinesisInputDStream.builder instead", "2.2.0")
def createStream[T: ClassTag](
ssc: StreamingContext,
kinesisAppName: String,
@@ -73,7 +74,7 @@ object KinesisUtils {
ssc.withNamedScope("kinesis stream") {
new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName),
initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
cleanedHandler, DefaultCredentialsProvider)
cleanedHandler, DefaultCredentials, None, None)
}
}
@@ -108,6 +109,7 @@ object KinesisUtils {
* is enabled. Make sure that your checkpoint directory is secure.
*/
// scalastyle:off
@deprecated("Use KinesisInputDStream.builder instead", "2.2.0")
def createStream[T: ClassTag](
ssc: StreamingContext,
kinesisAppName: String,
@@ -123,12 +125,12 @@ object KinesisUtils {
// scalastyle:on
val cleanedHandler = ssc.sc.clean(messageHandler)
ssc.withNamedScope("kinesis stream") {
val kinesisCredsProvider = BasicCredentialsProvider(
val kinesisCredsProvider = BasicCredentials(
awsAccessKeyId = awsAccessKeyId,
awsSecretKey = awsSecretKey)
new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName),
initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
cleanedHandler, kinesisCredsProvider)
cleanedHandler, kinesisCredsProvider, None, None)
}
}
@@ -169,6 +171,7 @@ object KinesisUtils {
* is enabled. Make sure that your checkpoint directory is secure.
*/
// scalastyle:off
@deprecated("Use KinesisInputDStream.builder instead", "2.2.0")
def createStream[T: ClassTag](
ssc: StreamingContext,
kinesisAppName: String,
@@ -187,16 +190,16 @@ object KinesisUtils {
// scalastyle:on
val cleanedHandler = ssc.sc.clean(messageHandler)
ssc.withNamedScope("kinesis stream") {
val kinesisCredsProvider = STSCredentialsProvider(
val kinesisCredsProvider = STSCredentials(
stsRoleArn = stsAssumeRoleArn,
stsSessionName = stsSessionName,
stsExternalId = Option(stsExternalId),
longLivedCredsProvider = BasicCredentialsProvider(
longLivedCreds = BasicCredentials(
awsAccessKeyId = awsAccessKeyId,
awsSecretKey = awsSecretKey))
new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName),
initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
cleanedHandler, kinesisCredsProvider)
cleanedHandler, kinesisCredsProvider, None, None)
}
}
@@ -227,6 +230,7 @@ object KinesisUtils {
* on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain
* gets the AWS credentials.
*/
@deprecated("Use KinesisInputDStream.builder instead", "2.2.0")
def createStream(
ssc: StreamingContext,
kinesisAppName: String,
@@ -240,7 +244,7 @@ object KinesisUtils {
ssc.withNamedScope("kinesis stream") {
new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName),
initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
defaultMessageHandler, DefaultCredentialsProvider)
KinesisInputDStream.defaultMessageHandler, DefaultCredentials, None, None)
}
}
@@ -272,6 +276,7 @@ object KinesisUtils {
* @note The given AWS credentials will get saved in DStream checkpoints if checkpointing
* is enabled. Make sure that your checkpoint directory is secure.
*/
@deprecated("Use KinesisInputDStream.builder instead", "2.2.0")
def createStream(
ssc: StreamingContext,
kinesisAppName: String,
@@ -284,12 +289,12 @@ object KinesisUtils {
awsAccessKeyId: String,
awsSecretKey: String): ReceiverInputDStream[Array[Byte]] = {
ssc.withNamedScope("kinesis stream") {
val kinesisCredsProvider = BasicCredentialsProvider(
val kinesisCredsProvider = BasicCredentials(
awsAccessKeyId = awsAccessKeyId,
awsSecretKey = awsSecretKey)
new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName),
initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
defaultMessageHandler, kinesisCredsProvider)
KinesisInputDStream.defaultMessageHandler, kinesisCredsProvider, None, None)
}
}
@@ -323,6 +328,7 @@ object KinesisUtils {
* on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain
* gets the AWS credentials.
*/
@deprecated("Use KinesisInputDStream.builder instead", "2.2.0")
def createStream[T](
jssc: JavaStreamingContext,
kinesisAppName: String,
@@ -372,6 +378,7 @@ object KinesisUtils {
* is enabled. Make sure that your checkpoint directory is secure.
*/
// scalastyle:off
@deprecated("Use KinesisInputDStream.builder instead", "2.2.0")
def createStream[T](
jssc: JavaStreamingContext,
kinesisAppName: String,
@@ -431,6 +438,7 @@ object KinesisUtils {
* is enabled. Make sure that your checkpoint directory is secure.
*/
// scalastyle:off
@deprecated("Use KinesisInputDStream.builder instead", "2.2.0")
def createStream[T](
jssc: JavaStreamingContext,
kinesisAppName: String,
@@ -482,6 +490,7 @@ object KinesisUtils {
* on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain
* gets the AWS credentials.
*/
@deprecated("Use KinesisInputDStream.builder instead", "2.2.0")
def createStream(
jssc: JavaStreamingContext,
kinesisAppName: String,
@@ -493,7 +502,8 @@ object KinesisUtils {
storageLevel: StorageLevel
): JavaReceiverInputDStream[Array[Byte]] = {
createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
initialPositionInStream, checkpointInterval, storageLevel, defaultMessageHandler(_))
initialPositionInStream, checkpointInterval, storageLevel,
KinesisInputDStream.defaultMessageHandler(_))
}
/**
@@ -524,6 +534,7 @@ object KinesisUtils {
* @note The given AWS credentials will get saved in DStream checkpoints if checkpointing
* is enabled. Make sure that your checkpoint directory is secure.
*/
@deprecated("Use KinesisInputDStream.builder instead", "2.2.0")
def createStream(
jssc: JavaStreamingContext,
kinesisAppName: String,
@@ -537,7 +548,7 @@ object KinesisUtils {
awsSecretKey: String): JavaReceiverInputDStream[Array[Byte]] = {
createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
initialPositionInStream, checkpointInterval, storageLevel,
defaultMessageHandler(_), awsAccessKeyId, awsSecretKey)
KinesisInputDStream.defaultMessageHandler(_), awsAccessKeyId, awsSecretKey)
}
private def validateRegion(regionName: String): String = {
@@ -545,14 +556,6 @@ object KinesisUtils {
throw new IllegalArgumentException(s"Region name '$regionName' is not valid")
}
}
private[kinesis] def defaultMessageHandler(record: Record): Array[Byte] = {
if (record == null) return null
val byteBuffer = record.getData()
val byteArray = new Array[Byte](byteBuffer.remaining())
byteBuffer.get(byteArray)
byteArray
}
}
/**
@@ -597,7 +600,7 @@ private class KinesisUtilsPythonHelper {
validateAwsCreds(awsAccessKeyId, awsSecretKey)
KinesisUtils.createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel,
KinesisUtils.defaultMessageHandler(_), awsAccessKeyId, awsSecretKey,
KinesisInputDStream.defaultMessageHandler(_), awsAccessKeyId, awsSecretKey,
stsAssumeRoleArn, stsSessionName, stsExternalId)
} else {
validateAwsCreds(awsAccessKeyId, awsSecretKey)


@@ -1,85 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.streaming.kinesis
import scala.collection.JavaConverters._
import com.amazonaws.auth._
import org.apache.spark.internal.Logging
/**
* Serializable interface providing a method executors can call to obtain an
* AWSCredentialsProvider instance for authenticating to AWS services.
*/
private[kinesis] sealed trait SerializableCredentialsProvider extends Serializable {
/**
* Return an AWSCredentialProvider instance that can be used by the Kinesis Client
* Library to authenticate to AWS services (Kinesis, CloudWatch and DynamoDB).
*/
def provider: AWSCredentialsProvider
}
/** Returns DefaultAWSCredentialsProviderChain for authentication. */
private[kinesis] final case object DefaultCredentialsProvider
extends SerializableCredentialsProvider {
def provider: AWSCredentialsProvider = new DefaultAWSCredentialsProviderChain
}
/**
* Returns AWSStaticCredentialsProvider constructed using basic AWS keypair. Falls back to using
* DefaultAWSCredentialsProviderChain if unable to construct a AWSCredentialsProviderChain
* instance with the provided arguments (e.g. if they are null).
*/
private[kinesis] final case class BasicCredentialsProvider(
awsAccessKeyId: String,
awsSecretKey: String) extends SerializableCredentialsProvider with Logging {
def provider: AWSCredentialsProvider = try {
new AWSStaticCredentialsProvider(new BasicAWSCredentials(awsAccessKeyId, awsSecretKey))
} catch {
case e: IllegalArgumentException =>
logWarning("Unable to construct AWSStaticCredentialsProvider with provided keypair; " +
"falling back to DefaultAWSCredentialsProviderChain.", e)
new DefaultAWSCredentialsProviderChain
}
}
/**
* Returns an STSAssumeRoleSessionCredentialsProvider instance which assumes an IAM
* role in order to authenticate against resources in an external account.
*/
private[kinesis] final case class STSCredentialsProvider(
stsRoleArn: String,
stsSessionName: String,
stsExternalId: Option[String] = None,
longLivedCredsProvider: SerializableCredentialsProvider = DefaultCredentialsProvider)
extends SerializableCredentialsProvider {
def provider: AWSCredentialsProvider = {
val builder = new STSAssumeRoleSessionCredentialsProvider.Builder(stsRoleArn, stsSessionName)
.withLongLivedCredentialsProvider(longLivedCredsProvider.provider)
stsExternalId match {
case Some(stsExternalId) =>
builder.withExternalId(stsExternalId)
.build()
case None =>
builder.build()
}
}
}


@@ -0,0 +1,182 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.streaming.kinesis
import scala.collection.JavaConverters._
import com.amazonaws.auth._
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.internal.Logging
/**
* Serializable interface providing a method executors can call to obtain an
* AWSCredentialsProvider instance for authenticating to AWS services.
*/
private[kinesis] sealed trait SparkAWSCredentials extends Serializable {
/**
* Return an AWSCredentialsProvider instance that can be used by the Kinesis Client
* Library to authenticate to AWS services (Kinesis, CloudWatch and DynamoDB).
*/
def provider: AWSCredentialsProvider
}
/** Returns DefaultAWSCredentialsProviderChain for authentication. */
private[kinesis] final case object DefaultCredentials extends SparkAWSCredentials {
def provider: AWSCredentialsProvider = new DefaultAWSCredentialsProviderChain
}
/**
* Returns an AWSStaticCredentialsProvider constructed using a basic AWS keypair. Falls back to
* the DefaultAWSCredentialsProviderChain if unable to construct an AWSStaticCredentialsProvider
* instance with the provided arguments (e.g. if they are null).
*/
private[kinesis] final case class BasicCredentials(
awsAccessKeyId: String,
awsSecretKey: String) extends SparkAWSCredentials with Logging {
def provider: AWSCredentialsProvider = try {
new AWSStaticCredentialsProvider(new BasicAWSCredentials(awsAccessKeyId, awsSecretKey))
} catch {
case e: IllegalArgumentException =>
logWarning("Unable to construct AWSStaticCredentialsProvider with provided keypair; " +
"falling back to DefaultCredentialsProviderChain.", e)
new DefaultAWSCredentialsProviderChain
}
}
/**
* Returns an STSAssumeRoleSessionCredentialsProvider instance which assumes an IAM
* role in order to authenticate against resources in an external account.
*/
private[kinesis] final case class STSCredentials(
stsRoleArn: String,
stsSessionName: String,
stsExternalId: Option[String] = None,
longLivedCreds: SparkAWSCredentials = DefaultCredentials)
extends SparkAWSCredentials {
def provider: AWSCredentialsProvider = {
val builder = new STSAssumeRoleSessionCredentialsProvider.Builder(stsRoleArn, stsSessionName)
.withLongLivedCredentialsProvider(longLivedCreds.provider)
stsExternalId match {
case Some(stsExternalId) =>
builder.withExternalId(stsExternalId)
.build()
case None =>
builder.build()
}
}
}
@InterfaceStability.Evolving
object SparkAWSCredentials {
/**
* Builder for [[SparkAWSCredentials]] instances.
*
* @since 2.2.0
*/
@InterfaceStability.Evolving
class Builder {
private var basicCreds: Option[BasicCredentials] = None
private var stsCreds: Option[STSCredentials] = None
// scalastyle:off
/**
* Use a basic AWS keypair for long-lived authorization.
*
* @note The given AWS keypair will be saved in DStream checkpoints if checkpointing is
* enabled. Make sure that your checkpoint directory is secure. Prefer using the
* [[http://docs.aws.amazon.com/sdk-for-java/v1/developer-guide/credentials.html#credentials-default default provider chain]]
* instead if possible.
*
* @param accessKeyId AWS access key ID
* @param secretKey AWS secret key
* @return Reference to this [[SparkAWSCredentials.Builder]]
*/
// scalastyle:on
def basicCredentials(accessKeyId: String, secretKey: String): Builder = {
basicCreds = Option(BasicCredentials(
awsAccessKeyId = accessKeyId,
awsSecretKey = secretKey))
this
}
/**
* Use STS to assume an IAM role for temporary session-based authentication. Will use configured
* long-lived credentials for authorizing to STS itself (either the default provider chain
* or a configured keypair).
*
* @param roleArn ARN of IAM role to assume via STS
* @param sessionName Name to use for the STS session
* @return Reference to this [[SparkAWSCredentials.Builder]]
*/
def stsCredentials(roleArn: String, sessionName: String): Builder = {
stsCreds = Option(STSCredentials(stsRoleArn = roleArn, stsSessionName = sessionName))
this
}
/**
* Use STS to assume an IAM role for temporary session-based authentication. Will use configured
* long-lived credentials for authorizing to STS itself (either the default provider chain
* or a configured keypair). STS will validate the provided external ID with the one defined
* in the trust policy of the IAM role to be assumed (if one is present).
*
* @param roleArn ARN of IAM role to assume via STS
* @param sessionName Name to use for the STS session
* @param externalId External ID to validate against assumed IAM role's trust policy
* @return Reference to this [[SparkAWSCredentials.Builder]]
*/
def stsCredentials(roleArn: String, sessionName: String, externalId: String): Builder = {
stsCreds = Option(STSCredentials(
stsRoleArn = roleArn,
stsSessionName = sessionName,
stsExternalId = Option(externalId)))
this
}
/**
* Returns the appropriate instance of [[SparkAWSCredentials]] given the configured
* parameters.
*
* - The long-lived credentials will either be [[DefaultCredentials]] or [[BasicCredentials]]
* if they were provided.
*
* - If STS credentials were provided, the configured long-lived credentials will be added to
* them and the result will be returned.
*
* - The long-lived credentials will be returned otherwise.
*
* @return [[SparkAWSCredentials]] to use for configured parameters
*/
def build(): SparkAWSCredentials =
stsCreds.map(_.copy(longLivedCreds = longLivedCreds)).getOrElse(longLivedCreds)
private def longLivedCreds: SparkAWSCredentials = basicCreds.getOrElse(DefaultCredentials)
}
/**
* Creates a [[SparkAWSCredentials.Builder]] for constructing
* [[SparkAWSCredentials]] instances.
*
* @since 2.2.0
*
* @return [[SparkAWSCredentials.Builder]] instance
*/
def builder: Builder = new Builder
}
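
Putting the two builders together: a sketch that keeps a long-lived keypair for authenticating to STS, assumes an IAM role, and hands the result to the DStream builder (all key, ARN and name values are placeholders; `ssc` is assumed to exist):

```scala
val creds = SparkAWSCredentials.builder
  .basicCredentials("ACCESS_KEY_ID", "SECRET_KEY")           // long-lived keypair
  .stsCredentials("arn:aws:iam::123456789012:role/my-role",  // role assumed via STS
    "my-session-name")
  .build()                                                    // STSCredentials wrapping BasicCredentials

val stream = KinesisInputDStream.builder
  .streamingContext(ssc)
  .streamName("my-kinesis-stream")
  .checkpointAppName("my-kcl-app")
  .kinesisCredentials(creds)
  .build()
```
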


@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.streaming.kinesis;
import org.junit.Test;
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.Seconds;
import org.apache.spark.streaming.LocalJavaStreamingContext;
import org.apache.spark.streaming.api.java.JavaDStream;
public class JavaKinesisInputDStreamBuilderSuite extends LocalJavaStreamingContext {
/**
* Basic test to ensure that the KinesisInputDStream.Builder interface is accessible from Java.
*/
@Test
public void testJavaKinesisDStreamBuilder() {
String streamName = "a-very-nice-stream-name";
String endpointUrl = "https://kinesis.us-west-2.amazonaws.com";
String region = "us-west-2";
InitialPositionInStream initialPosition = InitialPositionInStream.TRIM_HORIZON;
String appName = "a-very-nice-kinesis-app";
Duration checkpointInterval = Seconds.apply(30);
StorageLevel storageLevel = StorageLevel.MEMORY_ONLY();
KinesisInputDStream<byte[]> kinesisDStream = KinesisInputDStream.builder()
.streamingContext(ssc)
.streamName(streamName)
.endpointUrl(endpointUrl)
.regionName(region)
.initialPositionInStream(initialPosition)
.checkpointAppName(appName)
.checkpointInterval(checkpointInterval)
.storageLevel(storageLevel)
.build();
assert(kinesisDStream.streamName() == streamName);
assert(kinesisDStream.endpointUrl() == endpointUrl);
assert(kinesisDStream.regionName() == region);
assert(kinesisDStream.initialPositionInStream() == initialPosition);
assert(kinesisDStream.checkpointAppName() == appName);
assert(kinesisDStream.checkpointInterval() == checkpointInterval);
assert(kinesisDStream._storageLevel() == storageLevel);
ssc.stop();
}
}


@@ -0,0 +1,115 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.streaming.kinesis
import java.lang.IllegalArgumentException
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import org.scalatest.BeforeAndAfterEach
import org.scalatest.mock.MockitoSugar
import org.apache.spark.SparkFunSuite
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext, TestSuiteBase}
class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterEach
with MockitoSugar {
import KinesisInputDStream._
private val ssc = new StreamingContext(conf, batchDuration)
private val streamName = "a-very-nice-kinesis-stream-name"
private val checkpointAppName = "a-very-nice-kcl-app-name"
private def baseBuilder = KinesisInputDStream.builder
private def builder = baseBuilder.streamingContext(ssc)
.streamName(streamName)
.checkpointAppName(checkpointAppName)
override def afterAll(): Unit = {
ssc.stop()
}
test("should raise an exception if the StreamingContext is missing") {
intercept[IllegalArgumentException] {
baseBuilder.streamName(streamName).checkpointAppName(checkpointAppName).build()
}
}
test("should raise an exception if the stream name is missing") {
intercept[IllegalArgumentException] {
baseBuilder.streamingContext(ssc).checkpointAppName(checkpointAppName).build()
}
}
test("should raise an exception if the checkpoint app name is missing") {
intercept[IllegalArgumentException] {
baseBuilder.streamingContext(ssc).streamName(streamName).build()
}
}
test("should propagate required values to KinesisInputDStream") {
val dstream = builder.build()
assert(dstream.context == ssc)
assert(dstream.streamName == streamName)
assert(dstream.checkpointAppName == checkpointAppName)
}
test("should propagate default values to KinesisInputDStream") {
val dstream = builder.build()
assert(dstream.endpointUrl == DEFAULT_KINESIS_ENDPOINT_URL)
assert(dstream.regionName == DEFAULT_KINESIS_REGION_NAME)
assert(dstream.initialPositionInStream == DEFAULT_INITIAL_POSITION_IN_STREAM)
assert(dstream.checkpointInterval == batchDuration)
assert(dstream._storageLevel == DEFAULT_STORAGE_LEVEL)
assert(dstream.kinesisCreds == DefaultCredentials)
assert(dstream.dynamoDBCreds == None)
assert(dstream.cloudWatchCreds == None)
}
test("should propagate custom non-auth values to KinesisInputDStream") {
val customEndpointUrl = "https://kinesis.us-west-2.amazonaws.com"
val customRegion = "us-west-2"
val customInitialPosition = InitialPositionInStream.TRIM_HORIZON
val customAppName = "a-very-nice-kinesis-app"
val customCheckpointInterval = Seconds(30)
val customStorageLevel = StorageLevel.MEMORY_ONLY
val customKinesisCreds = mock[SparkAWSCredentials]
val customDynamoDBCreds = mock[SparkAWSCredentials]
val customCloudWatchCreds = mock[SparkAWSCredentials]
val dstream = builder
.endpointUrl(customEndpointUrl)
.regionName(customRegion)
.initialPositionInStream(customInitialPosition)
.checkpointAppName(customAppName)
.checkpointInterval(customCheckpointInterval)
.storageLevel(customStorageLevel)
.kinesisCredentials(customKinesisCreds)
.dynamoDBCredentials(customDynamoDBCreds)
.cloudWatchCredentials(customCloudWatchCreds)
.build()
assert(dstream.endpointUrl == customEndpointUrl)
assert(dstream.regionName == customRegion)
assert(dstream.initialPositionInStream == customInitialPosition)
assert(dstream.checkpointAppName == customAppName)
assert(dstream.checkpointInterval == customCheckpointInterval)
assert(dstream._storageLevel == customStorageLevel)
assert(dstream.kinesisCreds == customKinesisCreds)
assert(dstream.dynamoDBCreds == Option(customDynamoDBCreds))
assert(dstream.cloudWatchCreds == Option(customCloudWatchCreds))
}
}


@@ -31,7 +31,6 @@ import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.mock.MockitoSugar
import org.apache.spark.streaming.{Duration, TestSuiteBase}
import org.apache.spark.util.Utils
/**
* Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor
@@ -62,28 +61,6 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAfter
checkpointerMock = mock[IRecordProcessorCheckpointer]
}
test("check serializability of credential provider classes") {
Utils.deserialize[BasicCredentialsProvider](
Utils.serialize(BasicCredentialsProvider(
awsAccessKeyId = "x",
awsSecretKey = "y")))
Utils.deserialize[STSCredentialsProvider](
Utils.serialize(STSCredentialsProvider(
stsRoleArn = "fakeArn",
stsSessionName = "fakeSessionName",
stsExternalId = Some("fakeExternalId"))))
Utils.deserialize[STSCredentialsProvider](
Utils.serialize(STSCredentialsProvider(
stsRoleArn = "fakeArn",
stsSessionName = "fakeSessionName",
stsExternalId = Some("fakeExternalId"),
longLivedCredsProvider = BasicCredentialsProvider(
awsAccessKeyId = "x",
awsSecretKey = "y"))))
}
test("process records including store and set checkpointer") {
when(receiverMock.isStopped()).thenReturn(false)
when(receiverMock.getCurrentLimit).thenReturn(Int.MaxValue)


@@ -138,7 +138,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun
assert(kinesisRDD.regionName === dummyRegionName)
assert(kinesisRDD.endpointUrl === dummyEndpointUrl)
assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds)
assert(kinesisRDD.kinesisCredsProvider === BasicCredentialsProvider(
assert(kinesisRDD.kinesisCreds === BasicCredentials(
awsAccessKeyId = dummyAWSAccessKey,
awsSecretKey = dummyAWSSecretKey))
assert(nonEmptyRDD.partitions.size === blockInfos.size)


@@ -0,0 +1,100 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.streaming.kinesis
import org.apache.spark.streaming.TestSuiteBase
import org.apache.spark.util.Utils
class SparkAWSCredentialsBuilderSuite extends TestSuiteBase {
private def builder = SparkAWSCredentials.builder
private val basicCreds = BasicCredentials(
awsAccessKeyId = "a-very-nice-access-key",
awsSecretKey = "a-very-nice-secret-key")
private val stsCreds = STSCredentials(
stsRoleArn = "a-very-nice-role-arn",
stsSessionName = "a-very-nice-secret-key",
stsExternalId = Option("a-very-nice-external-id"),
longLivedCreds = basicCreds)
test("should build DefaultCredentials when given no params") {
assert(builder.build() == DefaultCredentials)
}
test("should build BasicCredentials") {
assertResult(basicCreds) {
builder.basicCredentials(basicCreds.awsAccessKeyId, basicCreds.awsSecretKey)
.build()
}
}
test("should build STSCredentials") {
// No external ID, default long-lived creds
assertResult(stsCreds.copy(stsExternalId = None, longLivedCreds = DefaultCredentials)) {
builder.stsCredentials(stsCreds.stsRoleArn, stsCreds.stsSessionName)
.build()
}
// Default long-lived creds
assertResult(stsCreds.copy(longLivedCreds = DefaultCredentials)) {
builder.stsCredentials(
stsCreds.stsRoleArn,
stsCreds.stsSessionName,
stsCreds.stsExternalId.get)
.build()
}
// No external ID, basic keypair for long-lived creds
assertResult(stsCreds.copy(stsExternalId = None)) {
builder.stsCredentials(stsCreds.stsRoleArn, stsCreds.stsSessionName)
.basicCredentials(basicCreds.awsAccessKeyId, basicCreds.awsSecretKey)
.build()
}
// Basic keypair for long-lived creds
assertResult(stsCreds) {
builder.stsCredentials(
stsCreds.stsRoleArn,
stsCreds.stsSessionName,
stsCreds.stsExternalId.get)
.basicCredentials(basicCreds.awsAccessKeyId, basicCreds.awsSecretKey)
.build()
}
// Order shouldn't matter
assertResult(stsCreds) {
builder.basicCredentials(basicCreds.awsAccessKeyId, basicCreds.awsSecretKey)
.stsCredentials(
stsCreds.stsRoleArn,
stsCreds.stsSessionName,
stsCreds.stsExternalId.get)
.build()
}
}
test("SparkAWSCredentials classes should be serializable") {
assertResult(basicCreds) {
Utils.deserialize[BasicCredentials](Utils.serialize(basicCreds))
}
assertResult(stsCreds) {
Utils.deserialize[STSCredentials](Utils.serialize(stsCreds))
}
// Will also test if DefaultCredentials can be serialized
val stsDefaultCreds = stsCreds.copy(longLivedCreds = DefaultCredentials)
assertResult(stsDefaultCreds) {
Utils.deserialize[STSCredentials](Utils.serialize(stsDefaultCreds))
}
}
}