[SPARK-36683][SQL] Add new built-in SQL functions: SEC and CSC

### What changes were proposed in this pull request?

Add new built-in SQL functions `sec` (secant) and `csc` (cosecant), and expose them as Scala and Python functions as well.

### Why are the changes needed?

Cotangent is already supported in Spark SQL, but secant and cosecant are missing, even though they are likely to be used as often as `cot`.
Related links: [SPARK-20751](https://github.com/apache/spark/pull/17999), [SPARK-36660](https://github.com/apache/spark/pull/33906)

### Does this PR introduce _any_ user-facing change?

Yes, users can now call `sec` and `csc` from SQL as well as from the Scala and Python function APIs.
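As a quick, illustrative sketch (assuming an active `SparkSession` named `spark`; the expected values are taken from the golden results in this PR):

```scala
// Sanity check of the new SQL functions; expected values come from
// this PR's golden result file.
spark.sql("SELECT sec(0), csc(1)").show()
// sec(0) -> 1.0
// csc(1) -> 1.1883951057781212
```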

### How was this patch tested?

Unit tests: Catalyst expression tests, DataFrame function tests, SQL query tests, and PySpark tests.

Closes #33988 from yutoacts/SPARK-36683.

Authored-by: Yuto Akutsu <yuto.akutsu@oss.nttdata.com>
Signed-off-by: Kousuke Saruta <sarutak@oss.nttdata.com>


@@ -394,6 +394,7 @@ Functions
covar_samp
crc32
create_map
csc
cume_dist
current_date
current_timestamp
@@ -511,6 +512,7 @@ Functions
rtrim
schema_of_csv
schema_of_json
sec
second
sentences
sequence


@@ -253,6 +253,8 @@ def product(col):
def acos(col):
"""
Computes inverse cosine of the input column.
.. versionadded:: 1.4.0
Returns
@@ -278,6 +280,8 @@ def acosh(col):
def asin(col):
"""
Computes inverse sine of the input column.
.. versionadded:: 1.3.0
@@ -304,6 +308,8 @@ def asinh(col):
def atan(col):
"""
Compute inverse tangent of the input column.
.. versionadded:: 1.4.0
Returns
@@ -345,6 +351,8 @@ def ceil(col):
def cos(col):
"""
Computes cosine of the input column.
.. versionadded:: 1.4.0
Parameters
@@ -362,6 +370,8 @@ def cos(col):
def cosh(col):
"""
Computes hyperbolic cosine of the input column.
.. versionadded:: 1.4.0
Parameters
@@ -379,6 +389,8 @@ def cosh(col):
def cot(col):
"""
Computes cotangent of the input column.
.. versionadded:: 3.3.0
Parameters
@@ -394,6 +406,25 @@ def cot(col):
return _invoke_function_over_column("cot", col)
def csc(col):
"""
Computes cosecant of the input column.
.. versionadded:: 3.3.0
Parameters
----------
col : :class:`~pyspark.sql.Column` or str
Angle in radians
Returns
-------
:class:`~pyspark.sql.Column`
Cosecant of the angle.
"""
return _invoke_function_over_column("csc", col)
@since(1.4)
def exp(col):
"""
@@ -451,6 +482,25 @@ def rint(col):
return _invoke_function_over_column("rint", col)
def sec(col):
"""
Computes secant of the input column.
.. versionadded:: 3.3.0
Parameters
----------
col : :class:`~pyspark.sql.Column` or str
Angle in radians
Returns
-------
:class:`~pyspark.sql.Column`
Secant of the angle.
"""
return _invoke_function_over_column("sec", col)
@since(1.4)
def signum(col):
"""
@@ -461,6 +511,8 @@ def signum(col):
def sin(col):
"""
Computes sine of the input column.
.. versionadded:: 1.4.0
Parameters
@@ -477,6 +529,8 @@ def sin(col):
def sinh(col):
"""
Computes hyperbolic sine of the input column.
.. versionadded:: 1.4.0
Parameters
@@ -495,6 +549,8 @@ def sinh(col):
def tan(col):
"""
Computes tangent of the input column.
.. versionadded:: 1.4.0
Parameters
@@ -512,6 +568,8 @@ def tan(col):
def tanh(col):
"""
Computes hyperbolic tangent of the input column.
.. versionadded:: 1.4.0
Parameters

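The new PySpark `csc` and `sec` above are thin wrappers over `_invoke_function_over_column`, so they delegate to the same JVM implementations as the Scala API. A minimal Scala sketch of the equivalent JVM-side calls (the column name `value` is illustrative; `spark` is an assumed `SparkSession`):

```scala
import org.apache.spark.sql.functions.{col, csc, sec}

// Equivalent JVM-side calls to what the Python wrappers invoke.
val df = spark.range(1, 4).toDF("value")
df.select(sec(col("value")), csc(col("value"))).show()
```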

@@ -302,6 +302,7 @@ def cos(col: ColumnOrName) -> Column: ...
def cosh(col: ColumnOrName) -> Column: ...
def cot(col: ColumnOrName) -> Column: ...
def count(col: ColumnOrName) -> Column: ...
def csc(col: ColumnOrName) -> Column: ...
def cume_dist() -> Column: ...
def degrees(col: ColumnOrName) -> Column: ...
def dense_rank() -> Column: ...
@@ -339,6 +340,7 @@ def rank() -> Column: ...
def rint(col: ColumnOrName) -> Column: ...
def row_number() -> Column: ...
def rtrim(col: ColumnOrName) -> Column: ...
def sec(col: ColumnOrName) -> Column: ...
def signum(col: ColumnOrName) -> Column: ...
def sin(col: ColumnOrName) -> Column: ...
def sinh(col: ColumnOrName) -> Column: ...


@@ -18,13 +18,15 @@
import datetime
from itertools import chain
import re
import math
from py4j.protocol import Py4JJavaError
from pyspark.sql import Row, Window
from pyspark.sql import Row, Window, types
from pyspark.sql.functions import udf, input_file_name, col, percentile_approx, \
lit, assert_true, sum_distinct, sumDistinct, shiftleft, shiftLeft, shiftRight, \
shiftright, shiftrightunsigned, shiftRightUnsigned, octet_length, bit_length
from pyspark.testing.sqlutils import ReusedSQLTestCase
shiftright, shiftrightunsigned, shiftRightUnsigned, octet_length, bit_length, \
sec, csc, cot
from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils
class FunctionsTests(ReusedSQLTestCase):
@@ -109,37 +111,28 @@ class FunctionsTests(ReusedSQLTestCase):
def test_math_functions(self):
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
from pyspark.sql import functions
import math
def get_values(l):
return [j[0] for j in l]
def assert_close(a, b):
c = get_values(b)
diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)]
return sum(diff) == len(a)
assert_close([math.cos(i) for i in range(10)],
SQLTestUtils.assert_close([math.cos(i) for i in range(10)],
df.select(functions.cos(df.a)).collect())
assert_close([math.cos(i) for i in range(10)],
SQLTestUtils.assert_close([math.cos(i) for i in range(10)],
df.select(functions.cos("a")).collect())
assert_close([math.sin(i) for i in range(10)],
SQLTestUtils.assert_close([math.sin(i) for i in range(10)],
df.select(functions.sin(df.a)).collect())
assert_close([math.sin(i) for i in range(10)],
SQLTestUtils.assert_close([math.sin(i) for i in range(10)],
df.select(functions.sin(df['a'])).collect())
assert_close([math.pow(i, 2 * i) for i in range(10)],
SQLTestUtils.assert_close([math.pow(i, 2 * i) for i in range(10)],
df.select(functions.pow(df.a, df.b)).collect())
assert_close([math.pow(i, 2) for i in range(10)],
SQLTestUtils.assert_close([math.pow(i, 2) for i in range(10)],
df.select(functions.pow(df.a, 2)).collect())
assert_close([math.pow(i, 2) for i in range(10)],
SQLTestUtils.assert_close([math.pow(i, 2) for i in range(10)],
df.select(functions.pow(df.a, 2.0)).collect())
assert_close([math.hypot(i, 2 * i) for i in range(10)],
SQLTestUtils.assert_close([math.hypot(i, 2 * i) for i in range(10)],
df.select(functions.hypot(df.a, df.b)).collect())
assert_close([math.hypot(i, 2 * i) for i in range(10)],
SQLTestUtils.assert_close([math.hypot(i, 2 * i) for i in range(10)],
df.select(functions.hypot("a", u"b")).collect())
assert_close([math.hypot(i, 2) for i in range(10)],
SQLTestUtils.assert_close([math.hypot(i, 2) for i in range(10)],
df.select(functions.hypot("a", 2)).collect())
assert_close([math.hypot(i, 2) for i in range(10)],
SQLTestUtils.assert_close([math.hypot(i, 2) for i in range(10)],
df.select(functions.hypot(df.a, 2)).collect())
def test_inverse_trig_functions(self):
@@ -157,6 +150,23 @@ class FunctionsTests(ReusedSQLTestCase):
for c in cols:
self.assertIn(f"{alias}(a)", repr(f(c)))
def test_reciprocal_trig_functions(self):
# SPARK-36683: Tests for reciprocal trig functions (SEC, CSC and COT)
lst = [0.0, math.pi / 6, math.pi / 4, math.pi / 3, math.pi / 2,
math.pi, 3 * math.pi / 2, 2 * math.pi]
df = self.spark.createDataFrame(lst, types.DoubleType())
def to_reciprocal_trig(func):
return [1.0 / func(i) if func(i) != 0 else math.inf for i in lst]
SQLTestUtils.assert_close(to_reciprocal_trig(math.cos),
df.select(sec(df.value)).collect())
SQLTestUtils.assert_close(to_reciprocal_trig(math.sin),
df.select(csc(df.value)).collect())
SQLTestUtils.assert_close(to_reciprocal_trig(math.tan),
df.select(cot(df.value)).collect())
def test_rand_functions(self):
df = self.df
from pyspark.sql import functions

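The `to_reciprocal_trig` helper builds expected values as `1 / f(x)`, mapping exact zeros of the base function to infinity so that points such as `csc(0)` can still be compared. A hedged Scala analogue of the same idea:

```scala
// Expected reciprocal values: exact zeros of the base function map to
// positive infinity, mirroring the Python helper above.
def toReciprocal(f: Double => Double)(xs: Seq[Double]): Seq[Double] =
  xs.map(x => if (f(x) == 0.0) Double.PositiveInfinity else 1.0 / f(x))

val angles = Seq(0.0, math.Pi / 6, math.Pi / 4, math.Pi / 2)
val expectedSec = toReciprocal(math.cos)(angles)
```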

@@ -16,6 +16,7 @@
#
import datetime
import math
import os
import shutil
import tempfile
@@ -243,6 +244,13 @@ class SQLTestUtils(object):
for f in functions:
self.spark.sql("DROP FUNCTION IF EXISTS %s" % f)
@staticmethod
def assert_close(a, b):
c = [j[0] for j in b]
diff = [abs(v - c[k]) < 1e-6 if math.isfinite(v) else v == c[k]
for k, v in enumerate(a)]
return sum(diff) == len(a)
class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils):
@classmethod

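`assert_close` compares element-wise within `1e-6` for finite expected values and falls back to exact equality for non-finite ones, which is what lets the reciprocal-trig test assert an `Infinity` result. An illustrative Scala analogue:

```scala
// Finite expectations compare within a 1e-6 tolerance; non-finite ones
// (e.g. the Infinity produced by csc(0)) must match exactly.
def assertClose(expected: Seq[Double], actual: Seq[Double]): Boolean =
  expected.zip(actual).forall { case (e, a) =>
    if (java.lang.Double.isFinite(e)) math.abs(e - a) < 1e-6 else e == a
  }
```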

@@ -361,6 +361,7 @@ object FunctionRegistry {
expression[Ceil]("ceil"),
expression[Ceil]("ceiling", true),
expression[Cos]("cos"),
expression[Sec]("sec"),
expression[Cosh]("cosh"),
expression[Conv]("conv"),
expression[ToDegrees]("degrees"),
@@ -392,6 +393,7 @@ object FunctionRegistry {
expression[Signum]("sign", true),
expression[Signum]("signum"),
expression[Sin]("sin"),
expression[Csc]("csc"),
expression[Sinh]("sinh"),
expression[StringToMap]("str_to_map"),
expression[Sqrt]("sqrt"),

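Registering `expression[Sec]("sec")` and `expression[Csc]("csc")` maps the SQL names to the new expression constructors so the analyzer can resolve them in queries. One quick way to confirm the registration through the public Catalog API (assuming a `SparkSession` named `spark`):

```scala
// Both names should now resolve as built-in functions.
assert(spark.catalog.functionExists("sec"))
assert(spark.catalog.functionExists("csc"))
```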

@@ -299,6 +299,29 @@ case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") {
override protected def withNewChildInternal(newChild: Expression): Cos = copy(child = newChild)
}
@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns the secant of `expr`, as if computed by `1/java.lang.Math.cos`.
""",
arguments = """
Arguments:
* expr - angle in radians
""",
examples = """
Examples:
> SELECT _FUNC_(0);
1.0
""",
since = "3.3.0",
group = "math_funcs")
case class Sec(child: Expression)
extends UnaryMathExpression((x: Double) => 1 / math.cos(x), "SEC") {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, c => s"${ev.value} = 1 / java.lang.Math.cos($c);")
}
override protected def withNewChildInternal(newChild: Expression): Sec = copy(child = newChild)
}
@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns the hyperbolic cosine of `expr`, as if computed by
@@ -655,6 +678,29 @@ case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN") {
override protected def withNewChildInternal(newChild: Expression): Sin = copy(child = newChild)
}
@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns the cosecant of `expr`, as if computed by `1/java.lang.Math.sin`.
""",
arguments = """
Arguments:
* expr - angle in radians
""",
examples = """
Examples:
> SELECT _FUNC_(1);
1.1883951057781212
""",
since = "3.3.0",
group = "math_funcs")
case class Csc(child: Expression)
extends UnaryMathExpression((x: Double) => 1 / math.sin(x), "CSC") {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, c => s"${ev.value} = 1 / java.lang.Math.sin($c);")
}
override protected def withNewChildInternal(newChild: Expression): Csc = copy(child = newChild)
}
@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns hyperbolic sine of `expr`, as if computed by `java.lang.Math._FUNC_`.

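Both expressions compute `1 / cos(x)` and `1 / sin(x)` on JVM doubles, in the interpreted path and in the generated code alike. Since IEEE-754 double division by zero yields an infinity rather than throwing, `csc(0)` evaluates to `Infinity` in the golden results below; a plain-Scala sketch of that behavior:

```scala
// JVM double semantics behind the Infinity result in the golden file:
// dividing a nonzero double by 0.0 yields an infinity, not an error.
val cscZero = 1 / java.lang.Math.sin(0.0) // Double.PositiveInfinity
val secZero = 1 / java.lang.Math.cos(0.0) // 1.0
assert(cscZero.isPosInfinity && secZero == 1.0)
```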

@@ -187,6 +187,20 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkConsistencyBetweenInterpretedAndCodegen(Sin, DoubleType)
}
test("csc") {
def f: (Double) => Double = (x: Double) => 1 / math.sin(x)
testUnary(Csc, f)
checkConsistencyBetweenInterpretedAndCodegen(Csc, DoubleType)
val nullLit = Literal.create(null, NullType)
val intNullLit = Literal.create(null, IntegerType)
val intLit = Literal.create(1, IntegerType)
checkEvaluation(checkDataTypeAndCast(Csc(nullLit)), null, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Csc(intNullLit)), null, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Csc(intLit)), 1 / math.sin(1), EmptyRow)
checkEvaluation(checkDataTypeAndCast(Csc(-intLit)), 1 / math.sin(-1), EmptyRow)
checkEvaluation(checkDataTypeAndCast(Csc(0)), 1 / math.sin(0), EmptyRow)
}
test("asin") {
testUnary(Asin, math.asin, (-10 to 10).map(_ * 0.1))
testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNaN = true)
@@ -215,6 +229,20 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkConsistencyBetweenInterpretedAndCodegen(Cos, DoubleType)
}
test("sec") {
def f: (Double) => Double = (x: Double) => 1 / math.cos(x)
testUnary(Sec, f)
checkConsistencyBetweenInterpretedAndCodegen(Sec, DoubleType)
val nullLit = Literal.create(null, NullType)
val intNullLit = Literal.create(null, IntegerType)
val intLit = Literal.create(1, IntegerType)
checkEvaluation(checkDataTypeAndCast(Sec(nullLit)), null, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Sec(intNullLit)), null, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Sec(intLit)), 1 / math.cos(1), EmptyRow)
checkEvaluation(checkDataTypeAndCast(Sec(-intLit)), 1 / math.cos(-1), EmptyRow)
checkEvaluation(checkDataTypeAndCast(Sec(0)), 1 / math.cos(0), EmptyRow)
}
test("acos") {
testUnary(Acos, math.acos, (-10 to 10).map(_ * 0.1))
testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNaN = true)


@@ -1809,6 +1809,15 @@ object functions {
*/
def cot(e: Column): Column = withExpr { Cot(e.expr) }
/**
* @param e angle in radians
* @return cosecant of the angle
*
* @group math_funcs
* @since 3.3.0
*/
def csc(e: Column): Column = withExpr { Csc(e.expr) }
/**
* Computes the exponential of the given value.
*
@@ -2197,6 +2206,15 @@ object functions {
*/
def bround(e: Column, scale: Int): Column = withExpr { BRound(e.expr, Literal(scale)) }
/**
* @param e angle in radians
* @return secant of the angle
*
* @group math_funcs
* @since 3.3.0
*/
def sec(e: Column): Column = withExpr { Sec(e.expr) }
/**
* Shift the given value numBits left. If the given value is a long value, this function
* will return a long value else it will return an integer value.

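On the Scala side, the new `csc` and `sec` column functions can be used directly in DataFrame expressions; a minimal sketch (assumes a `SparkSession` named `spark` and its implicits):

```scala
import org.apache.spark.sql.functions.{csc, sec}
import spark.implicits._

// DataFrame-side usage of the new Scala functions.
val angles = Seq(0.0, math.Pi / 4, math.Pi / 3).toDF("rad")
angles.select(sec($"rad"), csc($"rad")).show()
```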

@@ -1,6 +1,6 @@
<!-- Automatically generated by ExpressionsSchemaSuite -->
## Summary
- Number of queries: 362
- Number of queries: 364
- Number of expressions that missing example: 12
- Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint
## Schema of Built-in Functions
@@ -82,6 +82,7 @@
| org.apache.spark.sql.catalyst.expressions.CreateMap | map | SELECT map(1.0, '2', 3.0, '4') | struct<map(1.0, 2, 3.0, 4):map<decimal(2,1),string>> |
| org.apache.spark.sql.catalyst.expressions.CreateNamedStruct | named_struct | SELECT named_struct("a", 1, "b", 2, "c", 3) | struct<named_struct(a, 1, b, 2, c, 3):struct<a:int,b:int,c:int>> |
| org.apache.spark.sql.catalyst.expressions.CreateNamedStruct | struct | SELECT struct(1, 2, 3) | struct<struct(1, 2, 3):struct<col1:int,col2:int,col3:int>> |
| org.apache.spark.sql.catalyst.expressions.Csc | csc | SELECT csc(1) | struct<CSC(1):double> |
| org.apache.spark.sql.catalyst.expressions.CsvToStructs | from_csv | SELECT from_csv('1, 0.8', 'a INT, b DOUBLE') | struct<from_csv(1, 0.8):struct<a:int,b:double>> |
| org.apache.spark.sql.catalyst.expressions.CumeDist | cume_dist | SELECT a, b, cume_dist() OVER (PARTITION BY a ORDER BY b) FROM VALUES ('A1', 2), ('A1', 1), ('A2', 3), ('A1', 1) tab(a, b) | struct<a:string,b:int,cume_dist() OVER (PARTITION BY a ORDER BY b ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):double> |
| org.apache.spark.sql.catalyst.expressions.CurrentCatalog | current_catalog | SELECT current_catalog() | struct<current_catalog():string> |
@@ -241,6 +242,7 @@
| org.apache.spark.sql.catalyst.expressions.RowNumber | row_number | SELECT a, b, row_number() OVER (PARTITION BY a ORDER BY b) FROM VALUES ('A1', 2), ('A1', 1), ('A2', 3), ('A1', 1) tab(a, b) | struct<a:string,b:int,row_number() OVER (PARTITION BY a ORDER BY b ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):int> |
| org.apache.spark.sql.catalyst.expressions.SchemaOfCsv | schema_of_csv | SELECT schema_of_csv('1,abc') | struct<schema_of_csv(1,abc):string> |
| org.apache.spark.sql.catalyst.expressions.SchemaOfJson | schema_of_json | SELECT schema_of_json('[{"col":0}]') | struct<schema_of_json([{"col":0}]):string> |
| org.apache.spark.sql.catalyst.expressions.Sec | sec | SELECT sec(0) | struct<SEC(0):double> |
| org.apache.spark.sql.catalyst.expressions.Second | second | SELECT second('2009-07-30 12:58:59') | struct<second(2009-07-30 12:58:59):int> |
| org.apache.spark.sql.catalyst.expressions.SecondsToTimestamp | timestamp_seconds | SELECT timestamp_seconds(1230219000) | struct<timestamp_seconds(1230219000):timestamp> |
| org.apache.spark.sql.catalyst.expressions.Sentences | sentences | SELECT sentences('Hi there! Good morning.') | struct<sentences(Hi there! Good morning., , ):array<array<string>>> |


@@ -40,6 +40,14 @@ select 5 % 3;
select pmod(-7, 3);
-- math functions
select sec(1);
select sec(null);
select sec(0);
select sec(-1);
select csc(1);
select csc(null);
select csc(0);
select csc(-1);
select cot(1);
select cot(null);
select cot(0);


@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 69
-- Number of queries: 77
-- !query
@@ -258,6 +258,70 @@ struct<pmod(-7, 3):int>
2
-- !query
select sec(1)
-- !query schema
struct<SEC(1):double>
-- !query output
1.8508157176809255
-- !query
select sec(null)
-- !query schema
struct<SEC(NULL):double>
-- !query output
NULL
-- !query
select sec(0)
-- !query schema
struct<SEC(0):double>
-- !query output
1.0
-- !query
select sec(-1)
-- !query schema
struct<SEC(-1):double>
-- !query output
1.8508157176809255
-- !query
select csc(1)
-- !query schema
struct<CSC(1):double>
-- !query output
1.1883951057781212
-- !query
select csc(null)
-- !query schema
struct<CSC(NULL):double>
-- !query output
NULL
-- !query
select csc(0)
-- !query schema
struct<CSC(0):double>
-- !query output
Infinity
-- !query
select csc(-1)
-- !query schema
struct<CSC(-1):double>
-- !query output
-1.1883951057781212
-- !query
select cot(1)
-- !query schema


@@ -117,6 +117,11 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
testOneToOneMathFunction(sin, math.sin)
}
test("csc") {
testOneToOneMathFunction(csc,
(x: Double) => (1 / math.sin(x)) )
}
test("asin") {
testOneToOneMathFunction(asin, math.asin)
}
@@ -134,6 +139,11 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
testOneToOneMathFunction(cos, math.cos)
}
test("sec") {
testOneToOneMathFunction(sec,
(x: Double) => (1 / math.cos(x)) )
}
test("acos") {
testOneToOneMathFunction(acos, math.acos)
}
@@ -151,6 +161,11 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
testOneToOneMathFunction(tan, math.tan)
}
test("cot") {
testOneToOneMathFunction(cot,
(x: Double) => (1 / math.tan(x)) )
}
test("atan") {
testOneToOneMathFunction(atan, math.atan)
}