[SPARK-35809][PYTHON] Add index_col argument for ps.sql

### What changes were proposed in this pull request?

This PR proposes adding an `index_col` argument to the `ps.sql` function so that users can preserve the index when they want to.

NOTE that `reset_index()` has to be performed before passing the DataFrame to `ps.sql` with `index_col`.

```python
>>> psdf
   A  B
a  1  4
b  2  5
c  3  6
>>> psdf_reset_index = psdf.reset_index()
>>> ps.sql("SELECT * from {psdf_reset_index} WHERE A > 1", index_col="index")
       A  B
index
b      2  5
c      3  6
```

Note that `reset_index()` on an unnamed index produces a column named `index`, which is why `index_col="index"` matches in the example above. Otherwise, the index is always lost.

```python
>>> ps.sql("SELECT * from {psdf} WHERE A > 1")
   A  B
0  2  5
1  3  6
```
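For a MultiIndex, `index_col` also accepts a list of column names. Below is a sketch based on the docstring and test added in this PR (assuming `pandas` is imported as `pd`; output alignment is approximate):

```python
>>> psdf = ps.DataFrame(
...     {"A": [1, 2, 3], "B": [4, 5, 6]},
...     index=pd.MultiIndex.from_tuples(
...         [("a", "b"), ("c", "d"), ("e", "f")], names=["index1", "index2"]
...     ),
... )
>>> psdf_reset_index = psdf.reset_index()
>>> ps.sql("SELECT * from {psdf_reset_index} WHERE A > 1", index_col=["index1", "index2"])
               A  B
index1 index2
c      d       2  5
e      f       3  6
```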

### Why are the changes needed?

The index is one of the key objects for existing pandas users, so we should provide a way to keep it after running `ps.sql`.

### Does this PR introduce _any_ user-facing change?

Yes, a new `index_col` argument is added to `ps.sql`.

### How was this patch tested?

Added a unit test and manually checked that the build passes.
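For reference, a minimal sketch of running the new test case directly with `unittest`; the module path `pyspark.pandas.tests.test_sql` is an assumption (the diff below does not show file paths), while the class and method names come from the added test:

```python
import unittest

# Module path is assumed; SQLTest.test_sql_with_index_col is taken from the diff below.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "pyspark.pandas.tests.test_sql.SQLTest.test_sql_with_index_col"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```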

Closes #33450 from itholic/SPARK-35809.

Authored-by: itholic <haejoon.lee@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
(cherry picked from commit 6578f0b135)
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
2 changed files with 79 additions and 5 deletions


@@ -16,7 +16,7 @@
 #
 import _string # type: ignore
-from typing import Any, Dict, Optional # noqa: F401 (SPARK-34943)
+from typing import Any, Dict, Optional, Union, List # noqa: F401 (SPARK-34943)
 import inspect
 import pandas as pd
@@ -26,6 +26,8 @@ from pyspark import pandas as ps # For running doctests and reference resolutio
 from pyspark.pandas.utils import default_session
 from pyspark.pandas.frame import DataFrame
 from pyspark.pandas.series import Series
+from pyspark.pandas.internal import InternalFrame
+from pyspark.pandas.namespace import _get_index_map
 __all__ = ["sql"]
@@ -36,6 +38,7 @@ from builtins import locals as builtin_locals
 def sql(
     query: str,
+    index_col: Optional[Union[str, List[str]]] = None,
     globals: Optional[Dict[str, Any]] = None,
     locals: Optional[Dict[str, Any]] = None,
     **kwargs: Any
@@ -65,6 +68,44 @@ def sql(
     ----------
     query : str
         the SQL query
+    index_col : str or list of str, optional
+        Column names to be used in Spark to represent pandas-on-Spark's index. The index name
+        in pandas-on-Spark is ignored. By default, the index is always lost.
+
+        .. note:: If you want to preserve the index, explicitly use :func:`DataFrame.reset_index`,
+            and pass it to the sql statement with `index_col` parameter.
+
+            For example,
+
+            >>> psdf = ps.DataFrame({"A": [1, 2, 3], "B":[4, 5, 6]}, index=['a', 'b', 'c'])
+            >>> psdf_reset_index = psdf.reset_index()
+            >>> ps.sql("SELECT * FROM {psdf_reset_index}", index_col="index")
+            ... # doctest: +NORMALIZE_WHITESPACE
+                   A  B
+            index
+            a      1  4
+            b      2  5
+            c      3  6
+
+            For MultiIndex,
+
+            >>> psdf = ps.DataFrame(
+            ...     {"A": [1, 2, 3], "B": [4, 5, 6]},
+            ...     index=pd.MultiIndex.from_tuples(
+            ...         [("a", "b"), ("c", "d"), ("e", "f")], names=["index1", "index2"]
+            ...     ),
+            ... )
+            >>> psdf_reset_index = psdf.reset_index()
+            >>> ps.sql("SELECT * FROM {psdf_reset_index}", index_col=["index1", "index2"])
+            ... # doctest: +NORMALIZE_WHITESPACE
+                           A  B
+            index1 index2
+            a      b       1  4
+            c      d       2  5
+            e      f       3  6
+
+            Also note that the index name(s) should be matched to the existing name.
     globals : dict, optional
         the dictionary of global variables, if explicitly set by the user
     locals : dict, optional
@@ -151,7 +192,7 @@ def sql(
     _dict.update(_locals)
     # Highest order of precedence is the locals
     _dict.update(kwargs)
-    return SQLProcessor(_dict, query, default_session()).execute()
+    return SQLProcessor(_dict, query, default_session()).execute(index_col)

 _CAPTURE_SCOPES = 2
@@ -221,12 +262,12 @@ class SQLProcessor(object):
         # The normalized form is typically a string
         self._cached_vars = {} # type: Dict[str, Any]
         # The SQL statement after:
-        # - all the dataframes have been have been registered as temporary views
+        # - all the dataframes have been registered as temporary views
         # - all the values have been converted normalized to equivalent SQL representations
         self._normalized_statement = None # type: Optional[str]
         self._session = session

-    def execute(self) -> DataFrame:
+    def execute(self, index_col: Optional[Union[str, List[str]]]) -> DataFrame:
         """
         Returns a DataFrame for which the SQL statement has been executed by
         the underlying SQL engine.
@@ -260,7 +301,14 @@ class SQLProcessor(object):
         finally:
             for v in self._temp_views:
                 self._session.catalog.dropTempView(v)
-        return DataFrame(sdf)
+        index_spark_columns, index_names = _get_index_map(sdf, index_col)
+        return DataFrame(
+            InternalFrame(
+                spark_frame=sdf, index_spark_columns=index_spark_columns, index_names=index_names
+            )
+        )

     def _convert(self, key: str) -> Any:
         """


@@ -37,6 +37,32 @@ class SQLTest(PandasOnSparkTestCase, SQLTestUtils):
         with self.assertRaises(ParseException):
             ps.sql("this is not valid sql")

+    def test_sql_with_index_col(self):
+        import pandas as pd
+
+        # Index
+        psdf = ps.DataFrame(
+            {"A": [1, 2, 3], "B": [4, 5, 6]}, index=pd.Index(["a", "b", "c"], name="index")
+        )
+        psdf_reset_index = psdf.reset_index()
+        actual = ps.sql("select * from {psdf_reset_index} where A > 1", index_col="index")
+        expected = psdf.iloc[[1, 2]]
+        self.assert_eq(actual, expected)
+
+        # MultiIndex
+        psdf = ps.DataFrame(
+            {"A": [1, 2, 3], "B": [4, 5, 6]},
+            index=pd.MultiIndex.from_tuples(
+                [("a", "b"), ("c", "d"), ("e", "f")], names=["index1", "index2"]
+            ),
+        )
+        psdf_reset_index = psdf.reset_index()
+        actual = ps.sql(
+            "select * from {psdf_reset_index} where A > 1", index_col=["index1", "index2"]
+        )
+        expected = psdf.iloc[[1, 2]]
+        self.assert_eq(actual, expected)

 if __name__ == "__main__":
     import unittest