[SPARK-35809][PYTHON] Add index_col
argument for ps.sql
### What changes were proposed in this pull request?
This PR proposes adding an argument `index_col` for `ps.sql` function, to preserve the index when users want.
NOTE that `reset_index()` has to be performed before using `ps.sql` with `index_col`.
```python
>>> psdf
A B
a 1 4
b 2 5
c 3 6
>>> psdf_reset_index = psdf.reset_index()
>>> ps.sql("SELECT * from {psdf_reset_index} WHERE A > 1", index_col="index")
A B
index
b 2 5
c 3 6
```
Otherwise, the index is always lost.
```python
>>> ps.sql("SELECT * from {psdf} WHERE A > 1")
A B
0 2 5
1 3 6
```
### Why are the changes needed?
The index is one of the key objects for existing pandas users, so we should provide a way to keep the index after computing `ps.sql`.
### Does this PR introduce _any_ user-facing change?
Yes, the new argument is added.
### How was this patch tested?
Added a unit test and manually checked that the build passes.
Closes #33450 from itholic/SPARK-35809.
Authored-by: itholic <haejoon.lee@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
(cherry picked from commit 6578f0b135)
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
This commit is contained in:
parent
0e94e42cd3
commit
3a18864c5f
|
@ -16,7 +16,7 @@
|
|||
#
|
||||
|
||||
import _string # type: ignore
|
||||
from typing import Any, Dict, Optional # noqa: F401 (SPARK-34943)
|
||||
from typing import Any, Dict, Optional, Union, List # noqa: F401 (SPARK-34943)
|
||||
import inspect
|
||||
import pandas as pd
|
||||
|
||||
|
@ -26,6 +26,8 @@ from pyspark import pandas as ps # For running doctests and reference resolutio
|
|||
from pyspark.pandas.utils import default_session
|
||||
from pyspark.pandas.frame import DataFrame
|
||||
from pyspark.pandas.series import Series
|
||||
from pyspark.pandas.internal import InternalFrame
|
||||
from pyspark.pandas.namespace import _get_index_map
|
||||
|
||||
|
||||
__all__ = ["sql"]
|
||||
|
@ -36,6 +38,7 @@ from builtins import locals as builtin_locals
|
|||
|
||||
def sql(
|
||||
query: str,
|
||||
index_col: Optional[Union[str, List[str]]] = None,
|
||||
globals: Optional[Dict[str, Any]] = None,
|
||||
locals: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any
|
||||
|
@ -65,6 +68,44 @@ def sql(
|
|||
----------
|
||||
query : str
|
||||
the SQL query
|
||||
index_col : str or list of str, optional
|
||||
Column names to be used in Spark to represent pandas-on-Spark's index. The index name
|
||||
in pandas-on-Spark is ignored. By default, the index is always lost.
|
||||
|
||||
.. note:: If you want to preserve the index, explicitly use :func:`DataFrame.reset_index`,
|
||||
and pass it to the sql statement with `index_col` parameter.
|
||||
|
||||
For example,
|
||||
|
||||
>>> psdf = ps.DataFrame({"A": [1, 2, 3], "B":[4, 5, 6]}, index=['a', 'b', 'c'])
|
||||
>>> psdf_reset_index = psdf.reset_index()
|
||||
>>> ps.sql("SELECT * FROM {psdf_reset_index}", index_col="index")
|
||||
... # doctest: +NORMALIZE_WHITESPACE
|
||||
A B
|
||||
index
|
||||
a 1 4
|
||||
b 2 5
|
||||
c 3 6
|
||||
|
||||
For MultiIndex,
|
||||
|
||||
>>> psdf = ps.DataFrame(
|
||||
... {"A": [1, 2, 3], "B": [4, 5, 6]},
|
||||
... index=pd.MultiIndex.from_tuples(
|
||||
... [("a", "b"), ("c", "d"), ("e", "f")], names=["index1", "index2"]
|
||||
... ),
|
||||
... )
|
||||
>>> psdf_reset_index = psdf.reset_index()
|
||||
>>> ps.sql("SELECT * FROM {psdf_reset_index}", index_col=["index1", "index2"])
|
||||
... # doctest: +NORMALIZE_WHITESPACE
|
||||
A B
|
||||
index1 index2
|
||||
a b 1 4
|
||||
c d 2 5
|
||||
e f 3 6
|
||||
|
||||
Also note that the index name(s) should be matched to the existing name.
|
||||
|
||||
globals : dict, optional
|
||||
the dictionary of global variables, if explicitly set by the user
|
||||
locals : dict, optional
|
||||
|
@ -151,7 +192,7 @@ def sql(
|
|||
_dict.update(_locals)
|
||||
# Highest order of precedence is the locals
|
||||
_dict.update(kwargs)
|
||||
return SQLProcessor(_dict, query, default_session()).execute()
|
||||
return SQLProcessor(_dict, query, default_session()).execute(index_col)
|
||||
|
||||
|
||||
_CAPTURE_SCOPES = 2
|
||||
|
@ -221,12 +262,12 @@ class SQLProcessor(object):
|
|||
# The normalized form is typically a string
|
||||
self._cached_vars = {} # type: Dict[str, Any]
|
||||
# The SQL statement after:
|
||||
# - all the dataframes have been have been registered as temporary views
|
||||
# - all the dataframes have been registered as temporary views
|
||||
# - all the values have been converted normalized to equivalent SQL representations
|
||||
self._normalized_statement = None # type: Optional[str]
|
||||
self._session = session
|
||||
|
||||
def execute(self) -> DataFrame:
|
||||
def execute(self, index_col: Optional[Union[str, List[str]]]) -> DataFrame:
|
||||
"""
|
||||
Returns a DataFrame for which the SQL statement has been executed by
|
||||
the underlying SQL engine.
|
||||
|
@ -260,7 +301,14 @@ class SQLProcessor(object):
|
|||
finally:
|
||||
for v in self._temp_views:
|
||||
self._session.catalog.dropTempView(v)
|
||||
return DataFrame(sdf)
|
||||
|
||||
index_spark_columns, index_names = _get_index_map(sdf, index_col)
|
||||
|
||||
return DataFrame(
|
||||
InternalFrame(
|
||||
spark_frame=sdf, index_spark_columns=index_spark_columns, index_names=index_names
|
||||
)
|
||||
)
|
||||
|
||||
def _convert(self, key: str) -> Any:
|
||||
"""
|
||||
|
|
|
@ -37,6 +37,32 @@ class SQLTest(PandasOnSparkTestCase, SQLTestUtils):
|
|||
with self.assertRaises(ParseException):
|
||||
ps.sql("this is not valid sql")
|
||||
|
||||
def test_sql_with_index_col(self):
|
||||
import pandas as pd
|
||||
|
||||
# Index
|
||||
psdf = ps.DataFrame(
|
||||
{"A": [1, 2, 3], "B": [4, 5, 6]}, index=pd.Index(["a", "b", "c"], name="index")
|
||||
)
|
||||
psdf_reset_index = psdf.reset_index()
|
||||
actual = ps.sql("select * from {psdf_reset_index} where A > 1", index_col="index")
|
||||
expected = psdf.iloc[[1, 2]]
|
||||
self.assert_eq(actual, expected)
|
||||
|
||||
# MultiIndex
|
||||
psdf = ps.DataFrame(
|
||||
{"A": [1, 2, 3], "B": [4, 5, 6]},
|
||||
index=pd.MultiIndex.from_tuples(
|
||||
[("a", "b"), ("c", "d"), ("e", "f")], names=["index1", "index2"]
|
||||
),
|
||||
)
|
||||
psdf_reset_index = psdf.reset_index()
|
||||
actual = ps.sql(
|
||||
"select * from {psdf_reset_index} where A > 1", index_col=["index1", "index2"]
|
||||
)
|
||||
expected = psdf.iloc[[1, 2]]
|
||||
self.assert_eq(actual, expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
|
|
Loading…
Reference in a new issue