2616d5cc1d
### What changes were proposed in this pull request?

Sets up the `mypy` configuration to enable the `disallow_untyped_defs` check for the pandas APIs on Spark module.

### Why are the changes needed?

Currently, many functions in the main codes of the pandas APIs on Spark module are still missing type annotations, and the `mypy` check `disallow_untyped_defs` is disabled for them. We should add more type annotations and enable the mypy check.

### Does this PR introduce _any_ user-facing change?

Yes. This PR adds more type annotations in the pandas APIs on Spark module, which can impact interaction with development tools for users.

### How was this patch tested?

The mypy check with a new configuration and existing tests should pass.

Closes #32614 from ueshin/issues/SPARK-35465/disallow_untyped_defs.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
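For context, `disallow_untyped_defs` makes `mypy` report any function definition whose signature is not fully annotated. A minimal, hypothetical illustration of the behavior the check enforces (not code from this PR):

```python
# Under disallow_untyped_defs = True, mypy rejects the first definition
# ("Function is missing a type annotation") and accepts the second one.
def to_upper(name):  # error: signature is not annotated
    return name.upper()


def to_upper_typed(name: str) -> str:  # passes the check
    return name.upper()
```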
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import _string  # type: ignore
from typing import Any, Dict, Optional  # noqa: F401 (SPARK-34943)
import inspect
import pandas as pd

from pyspark.sql import SparkSession, DataFrame as SDataFrame  # noqa: F401 (SPARK-34943)

from pyspark import pandas as ps  # For running doctests and reference resolution in PyCharm.
from pyspark.pandas.utils import default_session
from pyspark.pandas.frame import DataFrame
from pyspark.pandas.series import Series


__all__ = ["sql"]

from builtins import globals as builtin_globals
from builtins import locals as builtin_locals
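# `globals` and `locals` are re-imported under these aliases because the parameters
# of `sql()` below shadow the built-in functions of the same name.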

def sql(
    query: str,
    globals: Optional[Dict[str, Any]] = None,
    locals: Optional[Dict[str, Any]] = None,
    **kwargs: Any
) -> DataFrame:
    """
    Execute a SQL query and return the result as a pandas-on-Spark DataFrame.

    This function also supports embedding Python variables (locals, globals, and parameters)
    in the SQL statement by wrapping them in curly braces. See examples section for details.

    In addition to the locals, globals and parameters, the function will also attempt
    to determine if the program currently runs in an IPython (or Jupyter) environment
    and to import the variables from this environment. The variables have the same
    precedence as globals.

    The following variable types are supported:

    * string
    * int
    * float
    * list, tuple, range of above types
    * pandas-on-Spark DataFrame
    * pandas-on-Spark Series
    * pandas DataFrame

    Parameters
    ----------
    query : str
        the SQL query
    globals : dict, optional
        the dictionary of global variables, if explicitly set by the user
    locals : dict, optional
        the dictionary of local variables, if explicitly set by the user
    kwargs
        other variables that the user may want to set manually that can be referenced in the query

    Returns
    -------
    pandas-on-Spark DataFrame

    Examples
    --------

    Calling a built-in SQL function.

    >>> ps.sql("select * from range(10) where id > 7")
       id
    0   8
    1   9

    A query can also reference a local variable or parameter by wrapping them in curly braces:

    >>> bound1 = 7
    >>> ps.sql("select * from range(10) where id > {bound1} and id < {bound2}", bound2=9)
       id
    0   8

    You can also wrap a DataFrame with curly braces to query it directly. Note that when you do
    that, the indexes, if any, automatically become top level columns.

    >>> mydf = ps.range(10)
    >>> x = range(4)
    >>> ps.sql("SELECT * from {mydf} WHERE id IN {x}")
       id
    0   0
    1   1
    2   2
    3   3

    Queries can also be arbitrarily nested in functions:

    >>> def statement():
    ...     mydf2 = ps.DataFrame({"x": range(2)})
    ...     return ps.sql("SELECT * from {mydf2}")
    >>> statement()
       x
    0  0
    1  1

    Mixing pandas-on-Spark and pandas DataFrames in a join operation. Note that the index is
    dropped.

    >>> ps.sql('''
    ...   SELECT m1.a, m2.b
    ...   FROM {table1} m1 INNER JOIN {table2} m2
    ...   ON m1.key = m2.key
    ...   ORDER BY m1.a, m2.b''',
    ...   table1=ps.DataFrame({"a": [1,2], "key": ["a", "b"]}),
    ...   table2=pd.DataFrame({"b": [3,4,5], "key": ["a", "b", "b"]}))
       a  b
    0  1  3
    1  2  4
    2  2  5

    Also, it is possible to query using Series.

    >>> myser = ps.Series({'a': [1.0, 2.0, 3.0], 'b': [15.0, 30.0, 45.0]})
    >>> ps.sql("SELECT * from {myser}")
                       0
    0     [1.0, 2.0, 3.0]
    1  [15.0, 30.0, 45.0]
    """
    if globals is None:
        globals = _get_ipython_scope()
    _globals = builtin_globals() if globals is None else dict(globals)
    _locals = builtin_locals() if locals is None else dict(locals)
    # The default choice is the globals
    _dict = dict(_globals)
    # The vars:
    _scope = _get_local_scope()
    _dict.update(_scope)
    # Then the locals
    _dict.update(_locals)
    # Highest order of precedence is the locals
    _dict.update(kwargs)
    return SQLProcessor(_dict, query, default_session()).execute()

_CAPTURE_SCOPES = 2


def _get_local_scope() -> Dict[str, Any]:
    # Get 2 scopes above (_get_local_scope -> sql -> ...) to capture the vars there.
    try:
        return inspect.stack()[_CAPTURE_SCOPES][0].f_locals
    except Exception as e:
        # TODO (rxin, thunterdb): use a more narrow scope exception.
        # See https://github.com/pyspark.pandas/pull/448
        return {}
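# Illustration (hypothetical caller, not part of the original module): with
# _CAPTURE_SCOPES = 2, inspect.stack()[2] is the frame of whatever called sql(),
# so a variable defined there can be referenced from the query directly:
#
#     def run_query():
#         threshold = 7  # picked up via _get_local_scope()
#         return sql("select * from range(10) where id > {threshold}")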

def _get_ipython_scope() -> Optional[Dict[str, Any]]:
    """
    Tries to extract the dictionary of variables if the program is running
    in an IPython notebook environment.
    """
    try:
        from IPython import get_ipython  # type: ignore

        shell = get_ipython()
        return shell.user_ns
    except Exception as e:
        # TODO (rxin, thunterdb): use a more narrow scope exception.
        # See https://github.com/pyspark.pandas/pull/448
        return None


# Originally from pymysql package
_escape_table = [chr(x) for x in range(128)]
_escape_table[0] = "\\0"
_escape_table[ord("\\")] = "\\\\"
_escape_table[ord("\n")] = "\\n"
_escape_table[ord("\r")] = "\\r"
_escape_table[ord("\032")] = "\\Z"
_escape_table[ord('"')] = '\\"'
_escape_table[ord("'")] = "\\'"
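# str.translate() looks each character up by its code point; ordinals beyond the
# 128-entry table (i.e. non-ASCII characters) miss the lookup and pass through
# unchanged, so only the ASCII control and quote characters above are rewritten.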

def escape_sql_string(value: str) -> str:
    """Escapes value without adding quotes.

    >>> escape_sql_string("foo\\nbar")
    'foo\\\\nbar'

    >>> escape_sql_string("'abc'de")
    "\\\\'abc\\\\'de"

    >>> escape_sql_string('"abc"de')
    '\\\\"abc\\\\"de'
    """
    return value.translate(_escape_table)


class SQLProcessor(object):
    def __init__(self, scope: Dict[str, Any], statement: str, session: SparkSession):
        self._scope = scope
        self._statement = statement
        # All the temporary views created when executing this statement
        # The key is the name of the variable in {}
        # The value is the cached Spark DataFrame.
        self._temp_views = {}  # type: Dict[str, SDataFrame]
        # All the other variables, converted to a normalized form.
        # The normalized form is typically a string
        self._cached_vars = {}  # type: Dict[str, Any]
        # The SQL statement after:
        # - all the dataframes have been registered as temporary views
        # - all the values have been converted to equivalent SQL representations
        self._normalized_statement = None  # type: Optional[str]
        self._session = session

    def execute(self) -> DataFrame:
        """
        Returns a DataFrame for which the SQL statement has been executed by
        the underlying SQL engine.

        >>> str0 = 'abc'
        >>> ps.sql("select {str0}")
           abc
        0  abc

        >>> str1 = 'abc"abc'
        >>> str2 = "abc'abc"
        >>> ps.sql("select {str0}, {str1}, {str2}")
           abc  abc"abc  abc'abc
        0  abc  abc"abc  abc'abc

        >>> strs = ['a', 'b']
        >>> ps.sql("select 'a' in {strs} as cond1, 'c' in {strs} as cond2")
           cond1  cond2
        0   True  False
        """
        blocks = _string.formatter_parser(self._statement)
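        # _string.formatter_parser is the CPython helper behind str.format(); each
        # block is a (literal_text, field_name, format_spec, conversion) tuple, and
        # the loop below rebuilds the statement with every {field_name} replaced by
        # its SQL representation.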
        # TODO: use a string builder
        res = ""
        try:
            for (pre, inner, _, _) in blocks:
                var_next = "" if inner is None else self._convert(inner)
                res = res + pre + var_next
            self._normalized_statement = res

            sdf = self._session.sql(self._normalized_statement)
        finally:
            for v in self._temp_views:
                self._session.catalog.dropTempView(v)
        return DataFrame(sdf)

    def _convert(self, key: str) -> Any:
        """
        Given a {} key, returns an equivalent SQL representation.
        This conversion performs all the necessary escaping so that the string
        returned can be directly injected into the SQL statement.
        """
        # Already cached?
        if key in self._cached_vars:
            return self._cached_vars[key]
        # Analyze:
        if key not in self._scope:
            raise ValueError(
                "The key {} in the SQL statement was not found in global,"
                " local or parameters variables".format(key)
            )
        var = self._scope[key]
        fillin = self._convert_var(var)
        self._cached_vars[key] = fillin
        return fillin

    def _convert_var(self, var: Any) -> Any:
        """
        Converts a python object into a string that is legal SQL.
        """
        if isinstance(var, (int, float)):
            return str(var)
        if isinstance(var, Series):
            return self._convert_var(var.to_dataframe())
        if isinstance(var, pd.DataFrame):
            return self._convert_var(ps.DataFrame(var))
        if isinstance(var, DataFrame):
            df_id = "pandas_on_spark_" + str(id(var))
            if df_id not in self._temp_views:
                sdf = var.to_spark()
                sdf.createOrReplaceTempView(df_id)
                self._temp_views[df_id] = sdf
            return df_id
        if isinstance(var, str):
            return '"' + escape_sql_string(var) + '"'
        if isinstance(var, list):
            return "(" + ", ".join([self._convert_var(v) for v in var]) + ")"
        if isinstance(var, (tuple, range)):
            return self._convert_var(list(var))
        raise ValueError("Unsupported variable type {}: {}".format(type(var).__name__, str(var)))

def _test() -> None:
    import os
    import doctest
    import sys
    from pyspark.sql import SparkSession
    import pyspark.pandas.sql_processor

    os.chdir(os.environ["SPARK_HOME"])

    globs = pyspark.pandas.sql_processor.__dict__.copy()
    globs["ps"] = pyspark.pandas
    spark = (
        SparkSession.builder.master("local[4]")
        .appName("pyspark.pandas.sql_processor tests")
        .getOrCreate()
    )
    (failure_count, test_count) = doctest.testmod(
        pyspark.pandas.sql_processor,
        globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE,
    )
    spark.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()