[SPARK-7294][SQL] ADD BETWEEN
Author: 云峤 <chensong.cs@alibaba-inc.com>
Author: kaka1992 <kaka_1992@163.com>
Closes #5839 from kaka1992/master and squashes the following commits:
b15360d [kaka1992] Fix python unit test in sql/test. =_= I forget to commit this file last time.
f928816 [kaka1992] Fix python style in sql/test.
d2e7f72 [kaka1992] Fix python style in sql/test.
c54d904 [kaka1992] Fix empty map bug.
7e64d1e [云峤] Update
7b9b858 [云峤] undo
f080f8d [云峤] update pep8
76f0c51 [云峤] Merge remote-tracking branch 'remotes/upstream/master'
7d62368 [云峤] [SPARK-7294] ADD BETWEEN
baf839b [云峤] [SPARK-7294] ADD BETWEEN
d11d5b9 [云峤] [SPARK-7294] ADD BETWEEN
(cherry picked from commit 735bc3d042
)
Signed-off-by: Reynold Xin <rxin@databricks.com>
This commit is contained in:
parent
8109c9e105
commit
c68d0e2352
|
@ -1405,6 +1405,13 @@ class Column(object):
|
|||
raise TypeError("unexpected type: %s" % type(dataType))
|
||||
return Column(jc)
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def between(self, lowerBound, upperBound):
|
||||
""" A boolean expression that is evaluated to true if the value of this
|
||||
expression is between the given columns.
|
||||
"""
|
||||
return (self >= lowerBound) & (self <= upperBound)
|
||||
|
||||
def __repr__(self):
|
||||
return 'Column<%s>' % self._jc.toString().encode('utf8')
|
||||
|
||||
|
|
|
@ -453,6 +453,14 @@ class SQLTests(ReusedPySparkTestCase):
|
|||
for row in rndn:
|
||||
assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]
|
||||
|
||||
def test_between_function(self):
|
||||
df = self.sc.parallelize([
|
||||
Row(a=1, b=2, c=3),
|
||||
Row(a=2, b=1, c=3),
|
||||
Row(a=4, b=1, c=4)]).toDF()
|
||||
self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)],
|
||||
df.filter(df.a.between(df.b, df.c)).collect())
|
||||
|
||||
def test_save_and_load(self):
|
||||
df = self.df
|
||||
tmpPath = tempfile.mkdtemp()
|
||||
|
|
|
@ -295,6 +295,15 @@ class Column(protected[sql] val expr: Expression) extends Logging {
|
|||
*/
|
||||
def eqNullSafe(other: Any): Column = this <=> other
|
||||
|
||||
/**
|
||||
* True if the current column is between the lower bound and upper bound, inclusive.
|
||||
*
|
||||
* @group java_expr_ops
|
||||
*/
|
||||
def between(lowerBound: Any, upperBound: Any): Column = {
|
||||
(this >= lowerBound) && (this <= upperBound)
|
||||
}
|
||||
|
||||
/**
|
||||
* True if the current expression is null.
|
||||
*
|
||||
|
|
|
@ -208,6 +208,20 @@ class ColumnExpressionSuite extends QueryTest {
|
|||
testData2.collect().toSeq.filter(r => r.getInt(0) <= r.getInt(1)))
|
||||
}
|
||||
|
||||
test("between") {
|
||||
val testData = TestSQLContext.sparkContext.parallelize(
|
||||
(0, 1, 2) ::
|
||||
(1, 2, 3) ::
|
||||
(2, 1, 0) ::
|
||||
(2, 2, 4) ::
|
||||
(3, 1, 6) ::
|
||||
(3, 2, 0) :: Nil).toDF("a", "b", "c")
|
||||
val expectAnswer = testData.collect().toSeq.
|
||||
filter(r => r.getInt(0) >= r.getInt(1) && r.getInt(0) <= r.getInt(2))
|
||||
|
||||
checkAnswer(testData.filter($"a".between($"b", $"c")), expectAnswer)
|
||||
}
|
||||
|
||||
val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(
|
||||
Row(false, false) ::
|
||||
Row(false, true) ::
|
||||
|
|
Loading…
Reference in a new issue