# Patch summary (explanatory preamble; text before the first "diff --git"
# header is not part of any hunk):
#
# python/pyspark/context.py
#   In the SparkContext constructor path, right after self._callsite is set
#   and before SparkContext._ensure_initialized(...) is called, reject a
#   caller-supplied gateway whose gateway_parameters.auth_token is None by
#   raising ValueError ("You are trying to pass an insecure Py4j gateway to
#   Spark. ..."). A gateway of None skips the check entirely.
#
# python/pyspark/tests/test_context.py
#   Import namedtuple and add ContextTests.test_forbid_insecure_gateway.
#   The test builds two namedtuple stand-ins (an object exposing
#   gateway_parameters, whose auth_token is None), passes the outer one as
#   SparkContext(gateway=...), and asserts that ValueError is raised and that
#   its message contains "insecure Py4j gateway".
#
# NOTE(review): this patch arrived with every newline collapsed onto one
# line. The line breaks below follow the hunk headers' line counts; the
# leading indentation of the Python lines is reconstructed from the code's
# syntax -- confirm against the target files before applying.
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 6137ed25a0..64178eb7b5 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -115,6 +115,11 @@ class SparkContext(object):
         ValueError:...
         """
         self._callsite = first_spark_call() or CallSite(None, None, None)
+        if gateway is not None and gateway.gateway_parameters.auth_token is None:
+            raise ValueError(
+                "You are trying to pass an insecure Py4j gateway to Spark. This"
+                " is not allowed as it is a security risk.")
+
         SparkContext._ensure_initialized(self, gateway=gateway, conf=conf)
         try:
             self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
diff --git a/python/pyspark/tests/test_context.py b/python/pyspark/tests/test_context.py
index 201baf4203..18d9cd40be 100644
--- a/python/pyspark/tests/test_context.py
+++ b/python/pyspark/tests/test_context.py
@@ -20,6 +20,7 @@ import tempfile
 import threading
 import time
 import unittest
+from collections import namedtuple
 
 from pyspark import SparkFiles, SparkContext
 from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest, SPARK_HOME
@@ -246,6 +247,15 @@ class ContextTests(unittest.TestCase):
         with SparkContext() as sc:
             self.assertGreater(sc.startTime, 0)
 
+    def test_forbid_insecure_gateway(self):
+        # Fail immediately if you try to create a SparkContext
+        # with an insecure gateway
+        parameters = namedtuple('MockGatewayParameters', 'auth_token')(None)
+        mock_insecure_gateway = namedtuple('MockJavaGateway', 'gateway_parameters')(parameters)
+        with self.assertRaises(ValueError) as context:
+            SparkContext(gateway=mock_insecure_gateway)
+        self.assertIn("insecure Py4j gateway", str(context.exception))
+
 
 if __name__ == "__main__":
     from pyspark.tests.test_context import *