[SPARK-32835][PYTHON] Add withField method to the pyspark Column class

### What changes were proposed in this pull request?

This PR adds a `withField` method on the pyspark Column class to call the Scala API method added in https://github.com/apache/spark/pull/27066.

### Why are the changes needed?

To update the Python API to match a new feature in the Scala API.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New unit test

Closes #29699 from Kimahriman/feature/pyspark-with-field.

Authored-by: Adam Binford <adam.binford@radiantsolutions.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
This commit is contained in:
Adam Binford 2020-09-16 20:18:36 +09:00 committed by HyukjinKwon
parent 550c1c9cfb
commit e884290587
2 changed files with 45 additions and 0 deletions

View file

@@ -329,6 +329,35 @@ class Column(object):
                          DeprecationWarning)
        return self[name]
@since(3.1)
def withField(self, fieldName, col):
    """
    An expression that adds/replaces a field in :class:`StructType` by name.

    >>> from pyspark.sql import Row
    >>> from pyspark.sql.functions import lit
    >>> df = spark.createDataFrame([Row(a=Row(b=1, c=2))])
    >>> df.withColumn('a', df['a'].withField('b', lit(3))).select('a.b').show()
    +---+
    |  b|
    +---+
    |  3|
    +---+
    >>> df.withColumn('a', df['a'].withField('d', lit(4))).select('a.d').show()
    +---+
    |  d|
    +---+
    |  4|
    +---+
    """
    # Validate both arguments eagerly on the Python side so callers get a
    # clear TypeError instead of an opaque Py4J failure. fieldName is
    # checked first, matching the documented error precedence.
    checks = (
        (fieldName, str, "fieldName should be a string"),
        (col, Column, "col should be a Column"),
    )
    for value, expected_type, message in checks:
        if not isinstance(value, expected_type):
            raise TypeError(message)
    # Delegate to the JVM-side Column.withField and wrap the result.
    jc = self._jc.withField(fieldName, col._jc)
    return Column(jc)
    def __getattr__(self, item):
        if item.startswith("__"):
            raise AttributeError(item)

View file

@@ -139,6 +139,22 @@ class ColumnTests(ReusedSQLTestCase):
        result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict()
        self.assertEqual(~75, result['~b'])
def test_with_field(self):
    """withField should return a Column, support both adding a new field
    and replacing an existing one, and reject wrongly-typed arguments."""
    from pyspark.sql.functions import lit, col

    frame = self.spark.createDataFrame([Row(a=Row(b=1, c=2))])
    struct_col = frame['a']

    # Both replace-existing and add-new produce Column expressions.
    self.assertIsInstance(struct_col.withField('b', lit(3)), Column)
    self.assertIsInstance(struct_col.withField('d', lit(3)), Column)

    # Adding a brand-new field 'd' to the struct.
    row = frame.withColumn('a', struct_col.withField('d', lit(3))).collect()[0].asDict()
    self.assertEqual(3, row['a']['d'])

    # Replacing the existing field 'b'.
    row = frame.withColumn('a', struct_col.withField('b', lit(3))).collect()[0].asDict()
    self.assertEqual(3, row['a']['b'])

    # Argument-type validation surfaces as Python-side TypeErrors.
    self.assertRaisesRegex(TypeError,
                           'col should be a Column',
                           lambda: struct_col.withField('b', 3))
    self.assertRaisesRegex(TypeError,
                           'fieldName should be a string',
                           lambda: struct_col.withField(col('b'), lit(3)))
if __name__ == "__main__":
    import unittest