[SPARK-32835][PYTHON] Add withField method to the pyspark Column class
### What changes were proposed in this pull request?

This PR adds a `withField` method on the pyspark Column class to call the Scala API method added in https://github.com/apache/spark/pull/27066.

### Why are the changes needed?

To update the Python API to match a new feature in the Scala API.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New unit test

Closes #29699 from Kimahriman/feature/pyspark-with-field.

Authored-by: Adam Binford <adam.binford@radiantsolutions.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
This commit is contained in:
parent
550c1c9cfb
commit
e884290587
|
@ -329,6 +329,35 @@ class Column(object):
|
||||||
DeprecationWarning)
|
DeprecationWarning)
|
||||||
return self[name]
|
return self[name]
|
||||||
|
|
||||||
|
@since(3.1)
|
||||||
|
def withField(self, fieldName, col):
|
||||||
|
"""
|
||||||
|
An expression that adds/replaces a field in :class:`StructType` by name.
|
||||||
|
|
||||||
|
>>> from pyspark.sql import Row
|
||||||
|
>>> from pyspark.sql.functions import lit
|
||||||
|
>>> df = spark.createDataFrame([Row(a=Row(b=1, c=2))])
|
||||||
|
>>> df.withColumn('a', df['a'].withField('b', lit(3))).select('a.b').show()
|
||||||
|
+---+
|
||||||
|
| b|
|
||||||
|
+---+
|
||||||
|
| 3|
|
||||||
|
+---+
|
||||||
|
>>> df.withColumn('a', df['a'].withField('d', lit(4))).select('a.d').show()
|
||||||
|
+---+
|
||||||
|
| d|
|
||||||
|
+---+
|
||||||
|
| 4|
|
||||||
|
+---+
|
||||||
|
"""
|
||||||
|
if not isinstance(fieldName, str):
|
||||||
|
raise TypeError("fieldName should be a string")
|
||||||
|
|
||||||
|
if not isinstance(col, Column):
|
||||||
|
raise TypeError("col should be a Column")
|
||||||
|
|
||||||
|
return Column(self._jc.withField(fieldName, col._jc))
|
||||||
|
|
||||||
def __getattr__(self, item):
|
def __getattr__(self, item):
|
||||||
if item.startswith("__"):
|
if item.startswith("__"):
|
||||||
raise AttributeError(item)
|
raise AttributeError(item)
|
||||||
|
|
|
@ -139,6 +139,22 @@ class ColumnTests(ReusedSQLTestCase):
|
||||||
result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict()
|
result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict()
|
||||||
self.assertEqual(~75, result['~b'])
|
self.assertEqual(~75, result['~b'])
|
||||||
|
|
||||||
|
def test_with_field(self):
    """withField builds a Column that adds or replaces a struct field,
    and rejects wrongly-typed arguments with a TypeError."""
    from pyspark.sql.functions import lit, col

    df = self.spark.createDataFrame([Row(a=Row(b=1, c=2))])
    struct = df['a']

    # withField is lazy: it only constructs a new Column expression.
    self.assertIsInstance(struct.withField('b', lit(3)), Column)
    self.assertIsInstance(struct.withField('d', lit(3)), Column)

    # Adding a brand-new field 'd'.
    row = df.withColumn('a', struct.withField('d', lit(3))).collect()[0].asDict()
    self.assertEqual(3, row['a']['d'])

    # Replacing the existing field 'b'.
    row = df.withColumn('a', struct.withField('b', lit(3))).collect()[0].asDict()
    self.assertEqual(3, row['a']['b'])

    # Argument validation happens eagerly on the Python side.
    with self.assertRaisesRegex(TypeError, 'col should be a Column'):
        struct.withField('b', 3)
    with self.assertRaisesRegex(TypeError, 'fieldName should be a string'):
        struct.withField(col('b'), lit(3))
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import unittest
|
import unittest
|
||||||
|
|
Loading…
Reference in a new issue