[SPARK-23706][PYTHON] spark.conf.get(value, default=None) should produce None in PySpark

## What changes were proposed in this pull request?

Scala:

```
scala> spark.conf.get("hey", null)
res1: String = null
```

```
scala> spark.conf.get("spark.sql.sources.partitionOverwriteMode", null)
res2: String = null
```

Python:

**Before**

```
>>> spark.conf.get("hey", None)
...
py4j.protocol.Py4JJavaError: An error occurred while calling o30.get.
: java.util.NoSuchElementException: hey
...
```

```
>>> spark.conf.get("spark.sql.sources.partitionOverwriteMode", None)
u'STATIC'
```

**After**

```
>>> spark.conf.get("hey", None) is None
True
```

```
>>> spark.conf.get("spark.sql.sources.partitionOverwriteMode", None) is None
True
```

Note that this PR preserves the case below:

```
>>> spark.conf.get("spark.sql.sources.partitionOverwriteMode")
u'STATIC'
```

## How was this patch tested?

Manually tested and unit tests were added.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #20841 from HyukjinKwon/spark-conf-get.
This commit is contained in:
hyukjinkwon 2018-03-18 20:24:14 +09:00
parent 8a1efe3076
commit 61487b308b
3 changed files with 20 additions and 8 deletions

View file

@ -17,7 +17,7 @@
import sys
from pyspark import since
from pyspark import since, _NoValue
from pyspark.rdd import ignore_unicode_prefix
@ -39,15 +39,16 @@ class RuntimeConfig(object):
@ignore_unicode_prefix
@since(2.0)
def get(self, key, default=None):
def get(self, key, default=_NoValue):
"""Returns the value of Spark runtime configuration property for the given key,
assuming it is set.
"""
self._checkType(key, "key")
if default is None:
if default is _NoValue:
return self._jconf.get(key)
else:
self._checkType(default, "default")
if default is not None:
self._checkType(default, "default")
return self._jconf.get(key, default)
@ignore_unicode_prefix

View file

@ -22,7 +22,7 @@ import warnings
if sys.version >= '3':
basestring = unicode = str
from pyspark import since
from pyspark import since, _NoValue
from pyspark.rdd import ignore_unicode_prefix
from pyspark.sql.session import _monkey_patch_RDD, SparkSession
from pyspark.sql.dataframe import DataFrame
@ -124,11 +124,11 @@ class SQLContext(object):
@ignore_unicode_prefix
@since(1.3)
def getConf(self, key, defaultValue=None):
def getConf(self, key, defaultValue=_NoValue):
"""Returns the value of Spark SQL configuration property for the given key.
If the key is not set and defaultValue is not None, return
defaultValue. If the key is not set and defaultValue is None, return
If the key is not set and defaultValue is set, return
defaultValue. If the key is not set and defaultValue is not set, return
the system default value.
>>> sqlContext.getConf("spark.sql.shuffle.partitions")

View file

@ -2504,6 +2504,17 @@ class SQLTests(ReusedSQLTestCase):
spark.conf.unset("bogo")
self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia")
self.assertEqual(spark.conf.get("hyukjin", None), None)
# This returns 'STATIC' because it's the default value of
# 'spark.sql.sources.partitionOverwriteMode', and `defaultValue` in
# `spark.conf.get` is unset.
self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode"), "STATIC")
# This returns None because 'spark.sql.sources.partitionOverwriteMode' is unset, but
# `defaultValue` in `spark.conf.get` is set to None.
self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode", None), None)
def test_current_database(self):
spark = self.spark
spark.catalog._reset()