[SPARK-6876] [PySpark] [SQL] add DataFrame na.replace in pyspark
Author: Daoyuan Wang <daoyuan.wang@intel.com>
Closes #6003 from adrian-wang/pynareplace and squashes the following commits:
672efba [Daoyuan Wang] remove py2.7 feature
4a148f7 [Daoyuan Wang] to_replace support dict, value support single value, and add full tests
9e232e7 [Daoyuan Wang] rename scala map
af0268a [Daoyuan Wang] remove na
63ac579 [Daoyuan Wang] add na.replace in pyspark
(cherry picked from commit d86ce84584)
Signed-off-by: Reynold Xin <rxin@databricks.com>
This commit is contained in:
parent
2bbb685f4c
commit
653db0a1bd
|
@ -53,4 +53,11 @@ private[spark] object PythonUtils {
|
||||||
/** Wrap a Java list as a Scala Seq (relies on the implicit Java-to-Scala
  * collection conversions imported at the top of this file). */
def toSeq[T](cols: JList[T]): Seq[T] = cols.toList.toSeq
|
||||||
|
|
||||||
|
/**
 * Convert a java.util.Map of K, V into an immutable Scala Map of K, V
 * (for calling JVM APIs that expect scala.collection.Map from Python).
 */
def toScalaMap[K, V](jm: java.util.Map[K, V]): Map[K, V] = jm.toMap
|
||||||
}
|
}
|
||||||
|
|
|
@ -578,6 +578,10 @@ class DataFrame(object):
|
||||||
"""Return a JVM Seq of Columns from a list of Column or names"""
|
"""Return a JVM Seq of Columns from a list of Column or names"""
|
||||||
return _to_seq(self.sql_ctx._sc, cols, converter)
|
return _to_seq(self.sql_ctx._sc, cols, converter)
|
||||||
|
|
||||||
|
def _jmap(self, jm):
    """Convert a Python dict into a JVM Scala Map via the gateway helper."""
    sc = self.sql_ctx._sc
    return _to_scala_map(sc, jm)
|
||||||
|
|
||||||
def _jcols(self, *cols):
|
def _jcols(self, *cols):
|
||||||
"""Return a JVM Seq of Columns from a list of Column or column names
|
"""Return a JVM Seq of Columns from a list of Column or column names
|
||||||
|
|
||||||
|
@ -924,6 +928,80 @@ class DataFrame(object):
|
||||||
|
|
||||||
return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)
|
return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)
|
||||||
|
|
||||||
|
def replace(self, to_replace, value=None, subset=None):
    """Returns a new :class:`DataFrame` replacing a value with another value.

    :param to_replace: int, long, float, string, list, tuple, or dict.
        Value to be replaced.
        If the value is a dict, then `value` is ignored (and may be omitted) and
        `to_replace` must be a mapping from column name (string) to replacement
        value. The value to be replaced must be an int, long, float, or string.
    :param value: int, long, float, string, or list.
        Value to use to replace holes.
        The replacement value must be an int, long, float, or string. If `value` is a
        list or tuple, `value` should be of the same length with `to_replace`.
    :param subset: optional list of column names to consider.
        Columns specified in subset that do not have matching data type are ignored.
        For example, if `value` is a string, and subset contains a non-string column,
        then the non-string column is simply ignored.

    >>> df4.replace(10, 20).show()
    +----+------+-----+
    | age|height| name|
    +----+------+-----+
    |  20|    80|Alice|
    |   5|  null|  Bob|
    |null|  null|  Tom|
    |null|  null| null|
    +----+------+-----+

    >>> df4.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show()
    +----+------+----+
    | age|height|name|
    +----+------+----+
    |  10|    80|   A|
    |   5|  null|   B|
    |null|  null| Tom|
    |null|  null|null|
    +----+------+----+
    """
    if not isinstance(to_replace, (float, int, long, basestring, list, tuple, dict)):
        raise ValueError(
            "to_replace should be a float, int, long, string, list, tuple, or dict")

    # `value` is ignored in the dict form (the mapping carries the replacements),
    # so only validate it for the scalar/list forms; this also lets callers omit
    # it entirely when passing a dict, as the docstring promises.
    if not isinstance(to_replace, dict) and \
            not isinstance(value, (float, int, long, basestring, list, tuple)):
        raise ValueError("value should be a float, int, long, string, list, or tuple")

    # Normalize every accepted input shape: scalars become one-element lists,
    # tuples become lists, so the logic below only deals with list vs. dict.
    if isinstance(to_replace, (float, int, long, basestring)):
        to_replace = [to_replace]
    if isinstance(to_replace, tuple):
        to_replace = list(to_replace)
    if isinstance(value, tuple):
        value = list(value)

    # Build the single old-value -> new-value mapping handed to the JVM side.
    rep_dict = dict()
    if isinstance(to_replace, dict):
        rep_dict = to_replace
    elif isinstance(value, list):
        if len(to_replace) != len(value):
            raise ValueError("to_replace and value lists should be of the same length")
        rep_dict = dict(zip(to_replace, value))
    else:
        # Single replacement value applied to every entry of to_replace.
        rep_dict = dict((tr, value) for tr in to_replace)

    # '*' tells DataFrameNaFunctions.replace to consider all columns of a
    # matching data type.
    if subset is None:
        return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx)
    elif isinstance(subset, basestring):
        subset = [subset]

    if not isinstance(subset, (list, tuple)):
        raise ValueError("subset should be a list or tuple of column names")

    return DataFrame(
        self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx)
||||||
|
|
||||||
def corr(self, col1, col2, method=None):
|
def corr(self, col1, col2, method=None):
|
||||||
"""
|
"""
|
||||||
Calculates the correlation of two columns of a DataFrame as a double value. Currently only
|
Calculates the correlation of two columns of a DataFrame as a double value. Currently only
|
||||||
|
@ -1226,6 +1304,13 @@ def _to_seq(sc, cols, converter=None):
|
||||||
return sc._jvm.PythonUtils.toSeq(cols)
|
return sc._jvm.PythonUtils.toSeq(cols)
|
||||||
|
|
||||||
|
|
||||||
|
def _to_scala_map(sc, jm):
|
||||||
|
"""
|
||||||
|
Convert a dict into a JVM Map.
|
||||||
|
"""
|
||||||
|
return sc._jvm.PythonUtils.toScalaMap(jm)
|
||||||
|
|
||||||
|
|
||||||
def _unary_op(name, doc="unary operator"):
|
def _unary_op(name, doc="unary operator"):
|
||||||
""" Create a method for given unary operator """
|
""" Create a method for given unary operator """
|
||||||
def _(self):
|
def _(self):
|
||||||
|
|
|
@ -665,6 +665,54 @@ class SQLTests(ReusedPySparkTestCase):
|
||||||
result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict()
|
result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict()
|
||||||
self.assertEqual(~75, result['~b'])
|
self.assertEqual(~75, result['~b'])
|
||||||
|
|
||||||
|
def test_replace(self):
    """Exercise DataFrame.replace for int/double/string values and for the
    subset argument in both its string and list forms."""
    schema = StructType([
        StructField("name", StringType(), True),
        StructField("age", IntegerType(), True),
        StructField("height", DoubleType(), True)])

    def replaced_row(data, *args, **kwargs):
        # Build a one-row frame, apply replace(), return the resulting row.
        return self.sqlCtx.createDataFrame(data, schema) \
            .replace(*args, **kwargs).first()

    # replace with int
    row = replaced_row([(u'Alice', 10, 10.0)], 10, 20)
    self.assertEqual(row.age, 20)
    self.assertEqual(row.height, 20.0)

    # replace with double
    row = replaced_row([(u'Alice', 80, 80.0)], 80.0, 82.1)
    self.assertEqual(row.age, 82)
    self.assertEqual(row.height, 82.1)

    # replace with string
    row = replaced_row([(u'Alice', 10, 80.1)], u'Alice', u'Ann')
    self.assertEqual(row.name, u"Ann")
    self.assertEqual(row.age, 10)

    # replace with subset specified by a string of a column name w/ actual change
    row = replaced_row([(u'Alice', 10, 80.1)], 10, 20, subset='age')
    self.assertEqual(row.age, 20)

    # replace with subset specified by a string of a column name w/o actual change
    row = replaced_row([(u'Alice', 10, 80.1)], 10, 20, subset='height')
    self.assertEqual(row.age, 10)

    # replace with subset specified with one column replaced, another column
    # not in subset stays unchanged.
    row = replaced_row([(u'Alice', 10, 10.0)], 10, 20, subset=['name', 'age'])
    self.assertEqual(row.name, u'Alice')
    self.assertEqual(row.age, 20)
    self.assertEqual(row.height, 10.0)

    # replace with subset specified but no column will be replaced
    row = replaced_row([(u'Alice', 10, None)], 10, 20, subset=['name', 'height'])
    self.assertEqual(row.name, u'Alice')
    self.assertEqual(row.age, 10)
    self.assertEqual(row.height, None)
|
||||||
|
|
||||||
|
|
||||||
class HiveContextSQLTests(ReusedPySparkTestCase):
|
class HiveContextSQLTests(ReusedPySparkTestCase):
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue