[SPARK-7321][SQL] Add Column expression for conditional statements (when/otherwise)
This builds on https://github.com/apache/spark/pull/5932 and should close https://github.com/apache/spark/pull/5932 as well.
As an example:
```python
df.select(when(df['age'] == 2, 3).otherwise(4).alias("age")).collect()
```
Author: Reynold Xin <rxin@databricks.com>
Author: kaka1992 <kaka_1992@163.com>
Closes #6072 from rxin/when-expr and squashes the following commits:
8f49201 [Reynold Xin] Throw exception if otherwise is applied twice.
0455eda [Reynold Xin] Reset run-tests.
bfb9d9f [Reynold Xin] Updated documentation and test cases.
762f6a5 [Reynold Xin] Merge pull request #5932 from kaka1992/IFCASE
95724c6 [kaka1992] Update
8218d0a [kaka1992] Update
801009e [kaka1992] Update
76d6346 [kaka1992] [SPARK-7321][SQL] Add Column expression for conditional statements (if, case)
(cherry picked from commit 97dee313f2)
Signed-off-by: Reynold Xin <rxin@databricks.com>
This commit is contained in:
parent
bdd5db9f16
commit
219a9043ef
|
@ -32,6 +32,8 @@ Important classes of Spark SQL and DataFrames:
|
||||||
Aggregation methods, returned by :func:`DataFrame.groupBy`.
|
Aggregation methods, returned by :func:`DataFrame.groupBy`.
|
||||||
- L{DataFrameNaFunctions}
|
- L{DataFrameNaFunctions}
|
||||||
Methods for handling missing data (null values).
|
Methods for handling missing data (null values).
|
||||||
|
- L{DataFrameStatFunctions}
|
||||||
|
Methods for statistics functionality.
|
||||||
- L{functions}
|
- L{functions}
|
||||||
List of built-in functions available for :class:`DataFrame`.
|
List of built-in functions available for :class:`DataFrame`.
|
||||||
- L{types}
|
- L{types}
|
||||||
|
|
|
@ -1546,6 +1546,37 @@ class Column(object):
|
||||||
"""
|
"""
|
||||||
return (self >= lowerBound) & (self <= upperBound)
|
return (self >= lowerBound) & (self <= upperBound)
|
||||||
|
|
||||||
|
@ignore_unicode_prefix
def when(self, condition, value):
    """Appends another ``(condition, value)`` branch to this conditional expression.

    This column is expected to have been produced by
    :func:`pyspark.sql.functions.when` (optionally followed by further
    :func:`Column.when` calls).
    If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.

    See :func:`pyspark.sql.functions.when` for example usage.

    :param condition: a boolean :class:`Column` expression.
    :param value: a literal value, or a :class:`Column` expression.
    :raises TypeError: if ``condition`` is not a :class:`Column`.
    """
    if not isinstance(condition, Column):
        raise TypeError("condition should be a Column")
    # Hand the JVM the underlying Java column rather than the Python wrapper.
    v = value._jc if isinstance(value, Column) else value
    # Chain onto *this* column's Java expression. Calling functions.when() here
    # would start a fresh CASE expression and silently drop the earlier branches.
    jc = self._jc.when(condition._jc, v)
    return Column(jc)
|
||||||
|
|
||||||
|
@ignore_unicode_prefix
def otherwise(self, value):
    """Sets the result returned when none of the preceding conditions match.

    This column is expected to have been produced by
    :func:`pyspark.sql.functions.when` or :func:`Column.when`.
    If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.

    See :func:`pyspark.sql.functions.when` for example usage.

    :param value: a literal value, or a :class:`Column` expression.
    """
    # Unwrap a Python Column to its Java counterpart before crossing into the JVM.
    v = value._jc if isinstance(value, Column) else value
    # Pass the unwrapped value; the raw ``value`` may be a Python Column wrapper.
    jc = self._jc.otherwise(v)
    return Column(jc)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return 'Column<%s>' % self._jc.toString().encode('utf8')
|
return 'Column<%s>' % self._jc.toString().encode('utf8')
|
||||||
|
|
||||||
|
|
|
@ -32,13 +32,14 @@ from pyspark.sql.dataframe import Column, _to_java_column, _to_seq
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'approxCountDistinct',
|
'approxCountDistinct',
|
||||||
|
'coalesce',
|
||||||
'countDistinct',
|
'countDistinct',
|
||||||
'monotonicallyIncreasingId',
|
'monotonicallyIncreasingId',
|
||||||
'rand',
|
'rand',
|
||||||
'randn',
|
'randn',
|
||||||
'sparkPartitionId',
|
'sparkPartitionId',
|
||||||
'coalesce',
|
'udf',
|
||||||
'udf']
|
'when']
|
||||||
|
|
||||||
|
|
||||||
def _create_function(name, doc=""):
|
def _create_function(name, doc=""):
|
||||||
|
@ -291,6 +292,27 @@ def struct(*cols):
|
||||||
return Column(jc)
|
return Column(jc)
|
||||||
|
|
||||||
|
|
||||||
|
def when(condition, value):
    """Starts a conditional expression whose first branch yields ``value`` when
    ``condition`` holds. Further branches can be chained with :func:`Column.when`.
    If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.

    :param condition: a boolean :class:`Column` expression.
    :param value: a literal value, or a :class:`Column` expression.

    >>> df.select(when(df['age'] == 2, 3).otherwise(4).alias("age")).collect()
    [Row(age=3), Row(age=4)]

    >>> df.select(when(df.age == 2, df.age + 1).alias("age")).collect()
    [Row(age=3), Row(age=None)]
    """
    if not isinstance(condition, Column):
        raise TypeError("condition should be a Column")
    # Column values are passed to the JVM as their underlying Java column;
    # anything else is forwarded unchanged as a literal.
    if isinstance(value, Column):
        java_value = value._jc
    else:
        java_value = value
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    return Column(jvm_functions.when(condition._jc, java_value))
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedFunction(object):
|
class UserDefinedFunction(object):
|
||||||
"""
|
"""
|
||||||
User defined function in Python
|
User defined function in Python
|
||||||
|
|
|
@ -327,6 +327,67 @@ class Column(protected[sql] val expr: Expression) extends Logging {
|
||||||
*/
|
*/
|
||||||
def eqNullSafe(other: Any): Column = this <=> other
|
def eqNullSafe(other: Any): Column = this <=> other
|
||||||
|
|
||||||
|
/**
 * Appends another (condition, value) branch to a conditional expression that was started
 * with the `when` function.
 * If otherwise is not defined at the end, null is returned for unmatched conditions.
 *
 * {{{
 *   // Example: encoding gender string column into integer.
 *
 *   // Scala:
 *   people.select(when(people("gender") === "male", 0)
 *     .when(people("gender") === "female", 1)
 *     .otherwise(2))
 *
 *   // Java:
 *   people.select(when(col("gender").equalTo("male"), 0)
 *     .when(col("gender").equalTo("female"), 1)
 *     .otherwise(2))
 * }}}
 *
 * Throws IllegalArgumentException if this column was not generated by when(), or if
 * otherwise() has already been applied to it.
 *
 * @group expr_ops
 */
def when(condition: Column, value: Any): Column = this.expr match {
  case CaseWhen(branches: Seq[Expression]) =>
    // An odd number of entries means otherwise() already appended the trailing else-value;
    // adding a branch after it would make that value be read as a condition.
    if (branches.size % 2 != 0) {
      throw new IllegalArgumentException(
        "when() cannot be applied once otherwise() is applied")
    }
    CaseWhen(branches ++ Seq(lit(condition).expr, lit(value).expr))
  case _ =>
    throw new IllegalArgumentException(
      "when() can only be applied on a Column previously generated by when() function")
}
|
||||||
|
|
||||||
|
/**
 * Supplies the result used when none of the conditions of a `when` expression match.
 * If otherwise is not defined at the end, null is returned for unmatched conditions.
 *
 * {{{
 *   // Example: encoding gender string column into integer.
 *
 *   // Scala:
 *   people.select(when(people("gender") === "male", 0)
 *     .when(people("gender") === "female", 1)
 *     .otherwise(2))
 *
 *   // Java:
 *   people.select(when(col("gender").equalTo("male"), 0)
 *     .when(col("gender").equalTo("female"), 1)
 *     .otherwise(2))
 * }}}
 *
 * Throws IllegalArgumentException if this column was not generated by when(), or if
 * otherwise() has already been applied to it.
 *
 * @group expr_ops
 */
def otherwise(value: Any): Column = this.expr match {
  // An even number of entries means only (condition, value) pairs so far: no else-value yet.
  case CaseWhen(pairs: Seq[Expression]) if pairs.size % 2 == 0 =>
    CaseWhen(pairs :+ lit(value).expr)
  // Odd length: a trailing else-value is already present.
  case CaseWhen(_) =>
    throw new IllegalArgumentException(
      "otherwise() can only be applied once on a Column previously generated by when()")
  case _ =>
    throw new IllegalArgumentException(
      "otherwise() can only be applied on a Column previously generated by when()")
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* True if the current column is between the lower bound and upper bound, inclusive.
|
* True if the current column is between the lower bound and upper bound, inclusive.
|
||||||
*
|
*
|
||||||
|
|
|
@ -419,6 +419,30 @@ object functions {
|
||||||
*/
|
*/
|
||||||
def not(e: Column): Column = !e
|
def not(e: Column): Column = !e
|
||||||
|
|
||||||
|
/**
 * Starts a conditional expression with a single (condition, value) branch. More branches
 * can be chained with `Column.when`, and a default supplied with `Column.otherwise`.
 * If otherwise is not defined at the end, null is returned for unmatched conditions.
 *
 * {{{
 *   // Example: encoding gender string column into integer.
 *
 *   // Scala:
 *   people.select(when(people("gender") === "male", 0)
 *     .when(people("gender") === "female", 1)
 *     .otherwise(2))
 *
 *   // Java:
 *   people.select(when(col("gender").equalTo("male"), 0)
 *     .when(col("gender").equalTo("female"), 1)
 *     .otherwise(2))
 * }}}
 *
 * @group normal_funcs
 */
def when(condition: Column, value: Any): Column = {
  // The branch list is flat: condition expression followed by its result expression.
  val firstBranch = Seq(condition.expr, lit(value).expr)
  CaseWhen(firstBranch)
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate a random column with i.i.d. samples from U[0.0, 1.0].
|
* Generate a random column with i.i.d. samples from U[0.0, 1.0].
|
||||||
*
|
*
|
||||||
|
|
|
@ -255,6 +255,27 @@ class ColumnExpressionSuite extends QueryTest {
|
||||||
Row(false, true) :: Row(true, false) :: Row(true, true) :: Nil)
|
Row(false, true) :: Row(true, false) :: Row(true, true) :: Nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("SPARK-7321 when conditional statements") {
  // Three rows keyed 1..3, each paired with the key's string form.
  val keyed = (1 to 3).map(i => (i, i.toString)).toDF("key", "value")

  // Two chained branches plus an explicit default.
  val withDefault = keyed.select(when($"key" === 1, -1).when($"key" === 2, -2).otherwise(0))
  checkAnswer(withDefault, Row(-1) :: Row(-2) :: Row(0) :: Nil)

  // Without the ending otherwise, return null for unmatched conditions.
  // Also test putting a non-literal value in the expression.
  val withoutDefault = keyed.select(when($"key" === 1, lit(0) - $"key").when($"key" === 2, -2))
  checkAnswer(withoutDefault, Row(-1) :: Row(-2) :: Row(null) :: Nil)

  // Test error handling for invalid expressions.
  intercept[IllegalArgumentException] { $"key".when($"key" === 1, -1) }
  intercept[IllegalArgumentException] { $"key".otherwise(-1) }
  intercept[IllegalArgumentException] { when($"key" === 1, -1).otherwise(-1).otherwise(-1) }
}
|
||||||
|
|
||||||
test("sqrt") {
|
test("sqrt") {
|
||||||
checkAnswer(
|
checkAnswer(
|
||||||
testData.select(sqrt('key)).orderBy('key.asc),
|
testData.select(sqrt('key)).orderBy('key.asc),
|
||||||
|
|
Loading…
Reference in a new issue