6b0f61563d
Thanks ogirardot, closes #6580
cc rxin JoshRosen
Author: Davies Liu <davies@databricks.com>
Closes #6590 from davies/when and squashes the following commits:
c0f2069 [Davies Liu] fix Column.when() and otherwise()
(cherry picked from commit 605ddbb27c)
Signed-off-by: Reynold Xin <rxin@databricks.com>
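
A minimal sketch of the chained when()/otherwise() usage this fix restores (it
mirrors the doctests in the file below; `df` has `age` and `name` columns):

>>> from pyspark.sql import functions as F
>>> df.select(F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).collect()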
426 lines
13 KiB
Python
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys

if sys.version >= '3':
    basestring = str
    long = int

from pyspark.context import SparkContext
from pyspark.rdd import ignore_unicode_prefix
from pyspark.sql import since
from pyspark.sql.types import *

__all__ = ["DataFrame", "Column", "SchemaRDD", "DataFrameNaFunctions",
           "DataFrameStatFunctions"]

def _create_column_from_literal(literal):
    """Wrap a Python literal as a JVM Column via functions.lit()."""
    sc = SparkContext._active_spark_context
    return sc._jvm.functions.lit(literal)


def _create_column_from_name(name):
    """Resolve a column name into a JVM Column via functions.col()."""
    sc = SparkContext._active_spark_context
    return sc._jvm.functions.col(name)


def _to_java_column(col):
    """Accept a Column or a column name and return the underlying JVM Column."""
    if isinstance(col, Column):
        jcol = col._jc
    else:
        jcol = _create_column_from_name(col)
    return jcol


def _to_seq(sc, cols, converter=None):
    """
    Convert a list of Column (or names) into a JVM Seq of Column.

    An optional `converter` could be used to convert items in `cols`
    into JVM Column objects.
    """
    if converter:
        cols = [converter(c) for c in cols]
    return sc._jvm.PythonUtils.toSeq(cols)


def _unary_op(name, doc="unary operator"):
    """ Create a method for given unary operator """
    def _(self):
        jc = getattr(self._jc, name)()
        return Column(jc)
    _.__doc__ = doc
    return _


def _func_op(name, doc=''):
    def _(self):
        sc = SparkContext._active_spark_context
        jc = getattr(sc._jvm.functions, name)(self._jc)
        return Column(jc)
    _.__doc__ = doc
    return _


def _bin_op(name, doc="binary operator"):
    """ Create a method for given binary operator
    """
    def _(self, other):
        jc = other._jc if isinstance(other, Column) else other
        njc = getattr(self._jc, name)(jc)
        return Column(njc)
    _.__doc__ = doc
    return _


def _reverse_op(name, doc="binary operator"):
    """ Create a method for binary operator (this object is on right side)
    """
    def _(self, other):
        jother = _create_column_from_literal(other)
        jc = getattr(jother, name)(self._jc)
        return Column(jc)
    _.__doc__ = doc
    return _

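# NOTE: the factories above generate Column's operator methods by delegating
# to the wrapped Py4J Column handle (self._jc); e.g. `df.age + 1` becomes a
# call to the JVM Column's `plus` method.
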
class Column(object):

    """
    A column in a DataFrame.

    :class:`Column` instances can be created by::

        # 1. Select a column out of a DataFrame

        df.colName
        df["colName"]

        # 2. Create from an expression
        df.colName + 1
        1 / df.colName

    .. note:: Experimental

    .. versionadded:: 1.3
    """

    def __init__(self, jc):
        self._jc = jc

    # arithmetic operators
    __neg__ = _func_op("negate")
    __add__ = _bin_op("plus")
    __sub__ = _bin_op("minus")
    __mul__ = _bin_op("multiply")
    __div__ = _bin_op("divide")
    __truediv__ = _bin_op("divide")
    __mod__ = _bin_op("mod")
    __radd__ = _bin_op("plus")
    __rsub__ = _reverse_op("minus")
    __rmul__ = _bin_op("multiply")
    __rdiv__ = _reverse_op("divide")
    __rtruediv__ = _reverse_op("divide")
    __rmod__ = _reverse_op("mod")

    # comparison operators
    __eq__ = _bin_op("equalTo")
    __ne__ = _bin_op("notEqual")
    __lt__ = _bin_op("lt")
    __le__ = _bin_op("leq")
    __ge__ = _bin_op("geq")
    __gt__ = _bin_op("gt")

    # `and`, `or`, `not` cannot be overloaded in Python,
    # so use bitwise operators as boolean operators
    __and__ = _bin_op('and')
    __or__ = _bin_op('or')
    __invert__ = _func_op('not')
    __rand__ = _bin_op("and")
    __ror__ = _bin_op("or")

    # container operators
    __contains__ = _bin_op("contains")
    __getitem__ = _bin_op("apply")

    # bitwise operators
    bitwiseOR = _bin_op("bitwiseOR")
    bitwiseAND = _bin_op("bitwiseAND")
    bitwiseXOR = _bin_op("bitwiseXOR")

    @since(1.3)
    def getItem(self, key):
        """
        An expression that gets an item at position ``ordinal`` out of a list,
        or gets an item by key out of a dict.

        >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"])
        >>> df.select(df.l.getItem(0), df.d.getItem("key")).show()
        +----+------+
        |l[0]|d[key]|
        +----+------+
        |   1| value|
        +----+------+
        >>> df.select(df.l[0], df.d["key"]).show()
        +----+------+
        |l[0]|d[key]|
        +----+------+
        |   1| value|
        +----+------+
        """
        return self[key]

    @since(1.3)
    def getField(self, name):
        """
        An expression that gets a field by name in a :class:`StructType`.

        >>> from pyspark.sql import Row
        >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
        >>> df.select(df.r.getField("b")).show()
        +----+
        |r[b]|
        +----+
        |   b|
        +----+
        >>> df.select(df.r.a).show()
        +----+
        |r[a]|
        +----+
        |   1|
        +----+
        """
        return self[name]

    def __getattr__(self, item):
        if item.startswith("__"):
            raise AttributeError(item)
        return self.getField(item)

    # string methods
    rlike = _bin_op("rlike")
    like = _bin_op("like")
    startswith = _bin_op("startsWith")
    endswith = _bin_op("endsWith")

    @ignore_unicode_prefix
    @since(1.3)
    def substr(self, startPos, length):
        """
        Return a :class:`Column` which is a substring of the column.

        :param startPos: start position (int or Column)
        :param length: length of the substring (int or Column)

        >>> df.select(df.name.substr(1, 3).alias("col")).collect()
        [Row(col=u'Ali'), Row(col=u'Bob')]
        """
        if type(startPos) != type(length):
            raise TypeError("startPos and length must be the same type")
        if isinstance(startPos, (int, long)):
            jc = self._jc.substr(startPos, length)
        elif isinstance(startPos, Column):
            jc = self._jc.substr(startPos._jc, length._jc)
        else:
            raise TypeError("Unexpected type: %s" % type(startPos))
        return Column(jc)

    # slice syntax (Python 2 only): df.name[1:3] is equivalent to df.name.substr(1, 3)
    __getslice__ = substr

    @ignore_unicode_prefix
    @since(1.3)
    def inSet(self, *cols):
        """
        A boolean expression that is evaluated to true if the value of this
        expression is contained by the evaluated values of the arguments.

        >>> df[df.name.inSet("Bob", "Mike")].collect()
        [Row(age=5, name=u'Bob')]
        >>> df[df.age.inSet([1, 2, 3])].collect()
        [Row(age=2, name=u'Alice')]
        """
        if len(cols) == 1 and isinstance(cols[0], (list, set)):
            cols = cols[0]
        cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols]
        sc = SparkContext._active_spark_context
        jc = getattr(self._jc, "in")(_to_seq(sc, cols))
        return Column(jc)

    # order
    asc = _unary_op("asc", "Returns a sort expression based on the"
                           " ascending order of the given column name.")
    desc = _unary_op("desc", "Returns a sort expression based on the"
                             " descending order of the given column name.")

    isNull = _unary_op("isNull", "True if the current expression is null.")
    isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")

    @since(1.3)
    def alias(self, *alias):
        """
        Returns this column aliased with a new name or names (in the case of expressions that
        return more than one column, such as explode).

        >>> df.select(df.age.alias("age2")).collect()
        [Row(age2=2), Row(age2=5)]
        """

        if len(alias) == 1:
            return Column(getattr(self._jc, "as")(alias[0]))
        else:
            sc = SparkContext._active_spark_context
            return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias))))

    @ignore_unicode_prefix
    @since(1.3)
    def cast(self, dataType):
        """ Convert the column into type ``dataType``.

        >>> df.select(df.age.cast("string").alias('ages')).collect()
        [Row(ages=u'2'), Row(ages=u'5')]
        >>> df.select(df.age.cast(StringType()).alias('ages')).collect()
        [Row(ages=u'2'), Row(ages=u'5')]
        """
        if isinstance(dataType, basestring):
            jc = self._jc.cast(dataType)
        elif isinstance(dataType, DataType):
            sc = SparkContext._active_spark_context
            ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
            jdt = ssql_ctx.parseDataType(dataType.json())
            jc = self._jc.cast(jdt)
        else:
            raise TypeError("unexpected type: %s" % type(dataType))
        return Column(jc)

    astype = cast

    @since(1.3)
    def between(self, lowerBound, upperBound):
        """
        A boolean expression that is evaluated to true if the value of this
        expression is between the given columns.

        >>> df.select(df.name, df.age.between(2, 4)).show()
        +-----+--------------------------+
        | name|((age >= 2) && (age <= 4))|
        +-----+--------------------------+
        |Alice|                      true|
        |  Bob|                     false|
        +-----+--------------------------+
        """
        return (self >= lowerBound) & (self <= upperBound)

    @since(1.4)
    def when(self, condition, value):
        """
        Evaluates a list of conditions and returns one of multiple possible result expressions.
        If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.

        See :func:`pyspark.sql.functions.when` for example usage.

        :param condition: a boolean :class:`Column` expression.
        :param value: a literal value, or a :class:`Column` expression.

        >>> from pyspark.sql import functions as F
        >>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show()
        +-----+--------------------------------------------------------+
        | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0|
        +-----+--------------------------------------------------------+
        |Alice|                                                      -1|
        |  Bob|                                                       1|
        +-----+--------------------------------------------------------+
        """
        if not isinstance(condition, Column):
            raise TypeError("condition should be a Column")
        v = value._jc if isinstance(value, Column) else value
        jc = self._jc.when(condition._jc, v)
        return Column(jc)

    @since(1.4)
    def otherwise(self, value):
        """
        Evaluates a list of conditions and returns one of multiple possible result expressions.
        If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.

        See :func:`pyspark.sql.functions.when` for example usage.

        :param value: a literal value, or a :class:`Column` expression.

        >>> from pyspark.sql import functions as F
        >>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show()
        +-----+---------------------------------+
        | name|CASE WHEN (age > 3) THEN 1 ELSE 0|
        +-----+---------------------------------+
        |Alice|                                0|
        |  Bob|                                1|
        +-----+---------------------------------+
        """
        v = value._jc if isinstance(value, Column) else value
        jc = self._jc.otherwise(v)
        return Column(jc)

    @since(1.4)
    def over(self, window):
        """
        Define a windowing column.

        :param window: a :class:`WindowSpec`
        :return: a Column

        >>> from pyspark.sql import Window
        >>> window = Window.partitionBy("name").orderBy("age").rowsBetween(-1, 1)
        >>> from pyspark.sql.functions import rank, min
        >>> # df.select(rank().over(window), min('age').over(window))

        .. note:: Window functions are only supported with HiveContext in 1.4
        """
        from pyspark.sql.window import WindowSpec
        if not isinstance(window, WindowSpec):
            raise TypeError("window should be WindowSpec")
        jc = self._jc.over(window._jspec)
        return Column(jc)

    def __repr__(self):
        return 'Column<%s>' % self._jc.toString().encode('utf8')


def _test():
    import doctest
    from pyspark.context import SparkContext
    from pyspark.sql import SQLContext
    import pyspark.sql.column
    globs = pyspark.sql.column.__dict__.copy()
    sc = SparkContext('local[4]', 'PythonTest')
    globs['sc'] = sc
    globs['sqlContext'] = SQLContext(sc)
    globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
        .toDF(StructType([StructField('age', IntegerType()),
                          StructField('name', StringType())]))

    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.column, globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()