From e88429058723572b95502fd369f7c2c609c561e6 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Wed, 16 Sep 2020 20:18:36 +0900 Subject: [PATCH] [SPARK-32835][PYTHON] Add withField method to the pyspark Column class ### What changes were proposed in this pull request? This PR adds a `withField` method on the pyspark Column class to call the Scala API method added in https://github.com/apache/spark/pull/27066. ### Why are the changes needed? To update the Python API to match a new feature in the Scala API. ### Does this PR introduce _any_ user-facing change? Yes, this adds a new `withField` method to the pyspark `Column` class. ### How was this patch tested? New unit test Closes #29699 from Kimahriman/feature/pyspark-with-field. Authored-by: Adam Binford Signed-off-by: HyukjinKwon --- python/pyspark/sql/column.py | 29 +++++++++++++++++++++++++ python/pyspark/sql/tests/test_column.py | 16 ++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 8c08d5cfa6..0e073d2a5d 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -329,6 +329,35 @@ class Column(object): DeprecationWarning) return self[name] + @since(3.1) + def withField(self, fieldName, col): + """ + An expression that adds/replaces a field in :class:`StructType` by name. 
+ + >>> from pyspark.sql import Row + >>> from pyspark.sql.functions import lit + >>> df = spark.createDataFrame([Row(a=Row(b=1, c=2))]) + >>> df.withColumn('a', df['a'].withField('b', lit(3))).select('a.b').show() + +---+ + | b| + +---+ + | 3| + +---+ + >>> df.withColumn('a', df['a'].withField('d', lit(4))).select('a.d').show() + +---+ + | d| + +---+ + | 4| + +---+ + """ + if not isinstance(fieldName, str): + raise TypeError("fieldName should be a string") + + if not isinstance(col, Column): + raise TypeError("col should be a Column") + + return Column(self._jc.withField(fieldName, col._jc)) + def __getattr__(self, item): if item.startswith("__"): raise AttributeError(item) diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index 5e05a8b63b..8a89e6e9d5 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -139,6 +139,22 @@ class ColumnTests(ReusedSQLTestCase): result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict() self.assertEqual(~75, result['~b']) + def test_with_field(self): + from pyspark.sql.functions import lit, col + df = self.spark.createDataFrame([Row(a=Row(b=1, c=2))]) + self.assertIsInstance(df['a'].withField('b', lit(3)), Column) + self.assertIsInstance(df['a'].withField('d', lit(3)), Column) + result = df.withColumn('a', df['a'].withField('d', lit(3))).collect()[0].asDict() + self.assertEqual(3, result['a']['d']) + result = df.withColumn('a', df['a'].withField('b', lit(3))).collect()[0].asDict() + self.assertEqual(3, result['a']['b']) + + self.assertRaisesRegex(TypeError, + 'col should be a Column', + lambda: df['a'].withField('b', 3)) + self.assertRaisesRegex(TypeError, + 'fieldName should be a string', + lambda: df['a'].withField(col('b'), lit(3))) if __name__ == "__main__": import unittest