[SPARK-36813][SQL][PYTHON] Propose an infrastructure for as-of join and implement ps.merge_asof

### What changes were proposed in this pull request?

Proposes an infrastructure for as-of join and implements `ps.merge_asof` here.

1. Introduces an `AsOfJoin` logical plan.
2. Rewrites the plan in the optimizer:

- From something like this (the SQL syntax is not determined yet):

```sql
SELECT * FROM left ASOF JOIN right ON (condition, as_of on(left.t, right.t), tolerance)
```

- To:

```sql
SELECT left.*, __right__.*
FROM (
     SELECT
          left.*,
          (
               SELECT MIN_BY(STRUCT(right.*), left.t - right.t) AS __nearest_right__
               FROM right
               WHERE condition AND left.t >= right.t AND right.t >= left.t - tolerance
          ) as __right__
     FROM left
     )
WHERE __right__ IS NOT NULL
```

3. The rewritten scalar subquery, which is correlated with the outer `left` relation through references such as `left.t`, is then handled by the existing decorrelation framework.

Note: APIs on SQL DataFrames and the SQL syntax are TBD (e.g., [SPARK-22947](https://issues.apache.org/jira/browse/SPARK-22947)), although temporary private APIs are added here.
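
For context, the temporary API on the PySpark side is a private `DataFrame._joinAsOf` method (see the diff below); here is a minimal sketch mirroring its doctest, assuming an active `spark` session:

```py
>>> left = spark.createDataFrame([(1, "a"), (5, "b"), (10, "c")], ["a", "left_val"])
>>> right = spark.createDataFrame([(1, 1), (2, 2), (3, 3), (6, 6), (7, 7)],
...                               ["a", "right_val"])
>>> left._joinAsOf(
...     right, leftAsOfColumn="a", rightAsOfColumn="a"
... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
[Row(a=1, left_val='a', right_val=1),
 Row(a=5, left_val='b', right_val=3),
 Row(a=10, left_val='c', right_val=7)]
```

A corresponding `private[sql] def joinAsOf` is added to `Dataset` on the Scala side.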

### Why are the changes needed?

Pandas' `merge_asof`, i.e. an as-of join for SQL/DataFrames, is useful for time-series analysis.

### Does this PR introduce _any_ user-facing change?

Yes. `ps.merge_asof` can be used.

```py
>>> quotes
                     time ticker     bid     ask
0 2016-05-25 13:30:00.023   GOOG  720.50  720.93
1 2016-05-25 13:30:00.023   MSFT   51.95   51.96
2 2016-05-25 13:30:00.030   MSFT   51.97   51.98
3 2016-05-25 13:30:00.041   MSFT   51.99   52.00
4 2016-05-25 13:30:00.048   GOOG  720.50  720.93
5 2016-05-25 13:30:00.049   AAPL   97.99   98.01
6 2016-05-25 13:30:00.072   GOOG  720.50  720.88
7 2016-05-25 13:30:00.075   MSFT   52.01   52.03

>>> trades
                     time ticker   price  quantity
0 2016-05-25 13:30:00.023   MSFT   51.95        75
1 2016-05-25 13:30:00.038   MSFT   51.95       155
2 2016-05-25 13:30:00.048   GOOG  720.77       100
3 2016-05-25 13:30:00.048   GOOG  720.92       100
4 2016-05-25 13:30:00.048   AAPL   98.00       100

>>> ps.merge_asof(
...    trades, quotes, on="time", by="ticker"
... ).sort_values(["time", "ticker", "price"]).reset_index(drop=True)
                     time ticker   price  quantity     bid     ask
0 2016-05-25 13:30:00.023   MSFT   51.95        75   51.95   51.96
1 2016-05-25 13:30:00.038   MSFT   51.95       155   51.97   51.98
2 2016-05-25 13:30:00.048   AAPL   98.00       100     NaN     NaN
3 2016-05-25 13:30:00.048   GOOG  720.77       100  720.50  720.93
4 2016-05-25 13:30:00.048   GOOG  720.92       100  720.50  720.93

>>> ps.merge_asof(
...     trades,
...     quotes,
...     on="time",
...     by="ticker",
...     tolerance=F.expr("INTERVAL 2 MILLISECONDS")  # pd.Timedelta("2ms")
... ).sort_values(["time", "ticker", "price"]).reset_index(drop=True)
                     time ticker   price  quantity     bid     ask
0 2016-05-25 13:30:00.023   MSFT   51.95        75   51.95   51.96
1 2016-05-25 13:30:00.038   MSFT   51.95       155     NaN     NaN
2 2016-05-25 13:30:00.048   AAPL   98.00       100     NaN     NaN
3 2016-05-25 13:30:00.048   GOOG  720.77       100  720.50  720.93
4 2016-05-25 13:30:00.048   GOOG  720.92       100  720.50  720.93
```

Note: As `IntervalType` literals are not supported yet, the `IntervalType` value has to be specified via `F.expr` as a workaround (see the sketch below).
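
For comparison, a sketch of the pandas-style call next to the current workaround; `trades` and `quotes` are the frames from the example above, and `trades_pdf`/`quotes_pdf` are hypothetical pandas equivalents:

```py
import pandas as pd
from pyspark import pandas as ps
from pyspark.sql import functions as F

# pandas accepts a Timedelta directly:
#   pd.merge_asof(trades_pdf, quotes_pdf, on="time", by="ticker",
#                 tolerance=pd.Timedelta("2ms"))

# pandas-on-Spark currently needs a Spark interval expression instead:
ps.merge_asof(trades, quotes, on="time", by="ticker",
              tolerance=F.expr("INTERVAL 2 MILLISECONDS"))
```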

### How was this patch tested?

Added tests.

Closes #34053 from ueshin/issues/SPARK-36813/merge_asof.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>

@ -73,6 +73,7 @@ from pyspark.pandas.utils import (
align_diff_frames,
default_session,
is_name_like_tuple,
is_name_like_value,
name_like_string,
same_anchor,
scol_for,
@ -83,11 +84,13 @@ from pyspark.pandas.internal import (
InternalFrame,
DEFAULT_SERIES_NAME,
HIDDEN_COLUMNS,
SPARK_INDEX_NAME_FORMAT,
)
from pyspark.pandas.series import Series, first_series
from pyspark.pandas.spark import functions as SF
from pyspark.pandas.spark.utils import as_nullable_spark_type, force_decimal_precision_scale
from pyspark.pandas.indexes import Index, DatetimeIndex
from pyspark.pandas.indexes.multi import MultiIndex
__all__ = [
@ -115,6 +118,7 @@ __all__ = [
"read_sql",
"read_json",
"merge",
"merge_asof",
"to_numeric",
"broadcast",
"read_orc",
@ -2747,6 +2751,499 @@ def merge(
)
def merge_asof(
left: Union[DataFrame, Series],
right: Union[DataFrame, Series],
on: Optional[Name] = None,
left_on: Optional[Name] = None,
right_on: Optional[Name] = None,
left_index: bool = False,
right_index: bool = False,
by: Optional[Union[Name, List[Name]]] = None,
left_by: Optional[Union[Name, List[Name]]] = None,
right_by: Optional[Union[Name, List[Name]]] = None,
suffixes: Tuple[str, str] = ("_x", "_y"),
tolerance: Optional[Any] = None,
allow_exact_matches: bool = True,
direction: str = "backward",
) -> DataFrame:
"""
Perform an asof merge.
This is similar to a left-join except that we match on nearest
key rather than equal keys.
For each row in the left DataFrame:
- A "backward" search selects the last row in the right DataFrame whose
'on' key is less than or equal to the left's key.
- A "forward" search selects the first row in the right DataFrame whose
'on' key is greater than or equal to the left's key.
- A "nearest" search selects the row in the right DataFrame whose 'on'
key is closest in absolute distance to the left's key.
Optionally match on equivalent keys with 'by' before searching with 'on'.
.. versionadded:: 3.3.0
Parameters
----------
left : DataFrame or named Series
right : DataFrame or named Series
on : label
Field name to join on. Must be found in both DataFrames.
The data MUST be ordered. Furthermore this must be a numeric column,
such as datetimelike, integer, or float. On or left_on/right_on
must be given.
left_on : label
Field name to join on in left DataFrame.
right_on : label
Field name to join on in right DataFrame.
left_index : bool
Use the index of the left DataFrame as the join key.
right_index : bool
Use the index of the right DataFrame as the join key.
by : column name or list of column names
Match on these columns before performing merge operation.
left_by : column name
Field names to match on in the left DataFrame.
right_by : column name
Field names to match on in the right DataFrame.
suffixes : 2-length sequence (tuple, list, ...)
Suffix to apply to overlapping column names in the left and right
side, respectively.
tolerance : int or Timedelta, optional, default None
Select asof tolerance within this range; must be compatible
with the merge index.
allow_exact_matches : bool, default True
- If True, allow matching with the same 'on' value
(i.e. less-than-or-equal-to / greater-than-or-equal-to)
- If False, don't match the same 'on' value
(i.e., strictly less-than / strictly greater-than).
direction : 'backward' (default), 'forward', or 'nearest'
Whether to search for prior, subsequent, or closest matches.
Returns
-------
merged : DataFrame
See Also
--------
merge : Merge with a database-style join.
merge_ordered : Merge with optional filling/interpolation.
Examples
--------
>>> left = ps.DataFrame({"a": [1, 5, 10], "left_val": ["a", "b", "c"]})
>>> left
a left_val
0 1 a
1 5 b
2 10 c
>>> right = ps.DataFrame({"a": [1, 2, 3, 6, 7], "right_val": [1, 2, 3, 6, 7]})
>>> right
a right_val
0 1 1
1 2 2
2 3 3
3 6 6
4 7 7
>>> ps.merge_asof(left, right, on="a").sort_values("a").reset_index(drop=True)
a left_val right_val
0 1 a 1
1 5 b 3
2 10 c 7
>>> ps.merge_asof(
... left,
... right,
... on="a",
... allow_exact_matches=False
... ).sort_values("a").reset_index(drop=True)
a left_val right_val
0 1 a NaN
1 5 b 3.0
2 10 c 7.0
>>> ps.merge_asof(
... left,
... right,
... on="a",
... direction="forward"
... ).sort_values("a").reset_index(drop=True)
a left_val right_val
0 1 a 1.0
1 5 b 6.0
2 10 c NaN
>>> ps.merge_asof(
... left,
... right,
... on="a",
... direction="nearest"
... ).sort_values("a").reset_index(drop=True)
a left_val right_val
0 1 a 1
1 5 b 6
2 10 c 7
We can use indexed DataFrames as well.
>>> left = ps.DataFrame({"left_val": ["a", "b", "c"]}, index=[1, 5, 10])
>>> left
left_val
1 a
5 b
10 c
>>> right = ps.DataFrame({"right_val": [1, 2, 3, 6, 7]}, index=[1, 2, 3, 6, 7])
>>> right
right_val
1 1
2 2
3 3
6 6
7 7
>>> ps.merge_asof(left, right, left_index=True, right_index=True).sort_index()
left_val right_val
1 a 1
5 b 3
10 c 7
Here is a real-world times-series example
>>> quotes = ps.DataFrame(
... {
... "time": [
... pd.Timestamp("2016-05-25 13:30:00.023"),
... pd.Timestamp("2016-05-25 13:30:00.023"),
... pd.Timestamp("2016-05-25 13:30:00.030"),
... pd.Timestamp("2016-05-25 13:30:00.041"),
... pd.Timestamp("2016-05-25 13:30:00.048"),
... pd.Timestamp("2016-05-25 13:30:00.049"),
... pd.Timestamp("2016-05-25 13:30:00.072"),
... pd.Timestamp("2016-05-25 13:30:00.075")
... ],
... "ticker": [
... "GOOG",
... "MSFT",
... "MSFT",
... "MSFT",
... "GOOG",
... "AAPL",
... "GOOG",
... "MSFT"
... ],
... "bid": [720.50, 51.95, 51.97, 51.99, 720.50, 97.99, 720.50, 52.01],
... "ask": [720.93, 51.96, 51.98, 52.00, 720.93, 98.01, 720.88, 52.03]
... }
... )
>>> quotes
time ticker bid ask
0 2016-05-25 13:30:00.023 GOOG 720.50 720.93
1 2016-05-25 13:30:00.023 MSFT 51.95 51.96
2 2016-05-25 13:30:00.030 MSFT 51.97 51.98
3 2016-05-25 13:30:00.041 MSFT 51.99 52.00
4 2016-05-25 13:30:00.048 GOOG 720.50 720.93
5 2016-05-25 13:30:00.049 AAPL 97.99 98.01
6 2016-05-25 13:30:00.072 GOOG 720.50 720.88
7 2016-05-25 13:30:00.075 MSFT 52.01 52.03
>>> trades = ps.DataFrame(
... {
... "time": [
... pd.Timestamp("2016-05-25 13:30:00.023"),
... pd.Timestamp("2016-05-25 13:30:00.038"),
... pd.Timestamp("2016-05-25 13:30:00.048"),
... pd.Timestamp("2016-05-25 13:30:00.048"),
... pd.Timestamp("2016-05-25 13:30:00.048")
... ],
... "ticker": ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"],
... "price": [51.95, 51.95, 720.77, 720.92, 98.0],
... "quantity": [75, 155, 100, 100, 100]
... }
... )
>>> trades
time ticker price quantity
0 2016-05-25 13:30:00.023 MSFT 51.95 75
1 2016-05-25 13:30:00.038 MSFT 51.95 155
2 2016-05-25 13:30:00.048 GOOG 720.77 100
3 2016-05-25 13:30:00.048 GOOG 720.92 100
4 2016-05-25 13:30:00.048 AAPL 98.00 100
By default we are taking the asof of the quotes
>>> ps.merge_asof(
... trades, quotes, on="time", by="ticker"
... ).sort_values(["time", "ticker", "price"]).reset_index(drop=True)
time ticker price quantity bid ask
0 2016-05-25 13:30:00.023 MSFT 51.95 75 51.95 51.96
1 2016-05-25 13:30:00.038 MSFT 51.95 155 51.97 51.98
2 2016-05-25 13:30:00.048 AAPL 98.00 100 NaN NaN
3 2016-05-25 13:30:00.048 GOOG 720.77 100 720.50 720.93
4 2016-05-25 13:30:00.048 GOOG 720.92 100 720.50 720.93
We only asof within 2ms between the quote time and the trade time
>>> ps.merge_asof(
... trades,
... quotes,
... on="time",
... by="ticker",
... tolerance=F.expr("INTERVAL 2 MILLISECONDS") # pd.Timedelta("2ms")
... ).sort_values(["time", "ticker", "price"]).reset_index(drop=True)
time ticker price quantity bid ask
0 2016-05-25 13:30:00.023 MSFT 51.95 75 51.95 51.96
1 2016-05-25 13:30:00.038 MSFT 51.95 155 NaN NaN
2 2016-05-25 13:30:00.048 AAPL 98.00 100 NaN NaN
3 2016-05-25 13:30:00.048 GOOG 720.77 100 720.50 720.93
4 2016-05-25 13:30:00.048 GOOG 720.92 100 720.50 720.93
We only asof within 10ms between the quote time and the trade time
and we exclude exact matches on time. However *prior* data will
propagate forward
>>> ps.merge_asof(
... trades,
... quotes,
... on="time",
... by="ticker",
... tolerance=F.expr("INTERVAL 10 MILLISECONDS"), # pd.Timedelta("10ms")
... allow_exact_matches=False
... ).sort_values(["time", "ticker", "price"]).reset_index(drop=True)
time ticker price quantity bid ask
0 2016-05-25 13:30:00.023 MSFT 51.95 75 NaN NaN
1 2016-05-25 13:30:00.038 MSFT 51.95 155 51.97 51.98
2 2016-05-25 13:30:00.048 AAPL 98.00 100 NaN NaN
3 2016-05-25 13:30:00.048 GOOG 720.77 100 NaN NaN
4 2016-05-25 13:30:00.048 GOOG 720.92 100 NaN NaN
"""
def to_list(os: Optional[Union[Name, List[Name]]]) -> List[Label]:
if os is None:
return []
elif is_name_like_tuple(os):
return [os] # type: ignore
elif is_name_like_value(os):
return [(os,)]
else:
return [o if is_name_like_tuple(o) else (o,) for o in os]
if isinstance(left, Series):
left = left.to_frame()
if isinstance(right, Series):
right = right.to_frame()
if on:
if left_on or right_on:
raise ValueError(
'Can only pass argument "on" OR "left_on" and "right_on", '
"not a combination of both."
)
left_as_of_names = list(map(left._internal.spark_column_name_for, to_list(on)))
right_as_of_names = list(map(right._internal.spark_column_name_for, to_list(on)))
else:
if left_index:
if isinstance(left.index, MultiIndex):
raise ValueError("left can only have one index")
left_as_of_names = left._internal.index_spark_column_names
else:
left_as_of_names = list(map(left._internal.spark_column_name_for, to_list(left_on)))
if right_index:
if isinstance(right.index, MultiIndex):
raise ValueError("right can only have one index")
right_as_of_names = right._internal.index_spark_column_names
else:
right_as_of_names = list(map(right._internal.spark_column_name_for, to_list(right_on)))
if left_as_of_names and not right_as_of_names:
raise ValueError("Must pass right_on or right_index=True")
if right_as_of_names and not left_as_of_names:
raise ValueError("Must pass left_on or left_index=True")
if not left_as_of_names and not right_as_of_names:
common = list(left.columns.intersection(right.columns))
if len(common) == 0:
raise ValueError(
"No common columns to perform merge on. Merge options: "
"left_on=None, right_on=None, left_index=False, right_index=False"
)
left_as_of_names = list(map(left._internal.spark_column_name_for, to_list(common)))
right_as_of_names = list(map(right._internal.spark_column_name_for, to_list(common)))
if len(left_as_of_names) != 1:
raise ValueError("can only asof on a key for left")
if len(right_as_of_names) != 1:
raise ValueError("can only asof on a key for right")
if by:
if left_by or right_by:
raise ValueError('Can only pass argument "on" OR "left_by" and "right_by".')
left_join_on_names = list(map(left._internal.spark_column_name_for, to_list(by)))
right_join_on_names = list(map(right._internal.spark_column_name_for, to_list(by)))
else:
left_join_on_names = list(map(left._internal.spark_column_name_for, to_list(left_by)))
right_join_on_names = list(map(right._internal.spark_column_name_for, to_list(right_by)))
if left_join_on_names and not right_join_on_names:
raise ValueError("missing right_by")
if right_join_on_names and not left_join_on_names:
raise ValueError("missing left_by")
if len(left_join_on_names) != len(right_join_on_names):
raise ValueError("left_by and right_by must be same length")
# We should distinguish the name to avoid ambiguous column name after merging.
right_prefix = "__right_"
right_as_of_names = [right_prefix + right_as_of_name for right_as_of_name in right_as_of_names]
right_join_on_names = [
right_prefix + right_join_on_name for right_join_on_name in right_join_on_names
]
left_as_of_name = left_as_of_names[0]
right_as_of_name = right_as_of_names[0]
def resolve(internal: InternalFrame, side: str) -> InternalFrame:
rename = lambda col: "__{}_{}".format(side, col)
internal = internal.resolved_copy
sdf = internal.spark_frame
sdf = sdf.select(
*[
scol_for(sdf, col).alias(rename(col))
for col in sdf.columns
if col not in HIDDEN_COLUMNS
],
*HIDDEN_COLUMNS
)
return internal.copy(
spark_frame=sdf,
index_spark_columns=[
scol_for(sdf, rename(col)) for col in internal.index_spark_column_names
],
index_fields=[field.copy(name=rename(field.name)) for field in internal.index_fields],
data_spark_columns=[
scol_for(sdf, rename(col)) for col in internal.data_spark_column_names
],
data_fields=[field.copy(name=rename(field.name)) for field in internal.data_fields],
)
left_internal = left._internal.resolved_copy
right_internal = resolve(right._internal, "right")
left_table = left_internal.spark_frame.alias("left_table")
right_table = right_internal.spark_frame.alias("right_table")
left_as_of_column = scol_for(left_table, left_as_of_name)
right_as_of_column = scol_for(right_table, right_as_of_name)
if left_join_on_names:
left_join_on_columns = [scol_for(left_table, label) for label in left_join_on_names]
right_join_on_columns = [scol_for(right_table, label) for label in right_join_on_names]
on = reduce(
lambda l, r: l & r,
[l == r for l, r in zip(left_join_on_columns, right_join_on_columns)],
)
else:
on = None
if tolerance is not None and not isinstance(tolerance, Column):
tolerance = SF.lit(tolerance)
as_of_joined_table = left_table._joinAsOf(
right_table,
leftAsOfColumn=left_as_of_column,
rightAsOfColumn=right_as_of_column,
on=on,
how="left",
tolerance=tolerance,
allowExactMatches=allow_exact_matches,
direction=direction,
)
# Unpack suffixes tuple for convenience
left_suffix = suffixes[0]
right_suffix = suffixes[1]
# Append suffixes to columns with the same name to avoid conflicts later
duplicate_columns = set(left_internal.column_labels) & set(right_internal.column_labels)
exprs = []
data_columns = []
column_labels = []
left_scol_for = lambda label: scol_for(
as_of_joined_table, left_internal.spark_column_name_for(label)
)
right_scol_for = lambda label: scol_for(
as_of_joined_table, right_internal.spark_column_name_for(label)
)
for label in left_internal.column_labels:
col = left_internal.spark_column_name_for(label)
scol = left_scol_for(label)
if label in duplicate_columns:
spark_column_name = left_internal.spark_column_name_for(label)
if spark_column_name in (left_as_of_names + left_join_on_names) and (
(right_prefix + spark_column_name) in (right_as_of_names + right_join_on_names)
):
pass
else:
col = col + left_suffix
scol = scol.alias(col)
label = tuple([str(label[0]) + left_suffix] + list(label[1:]))
exprs.append(scol)
data_columns.append(col)
column_labels.append(label)
for label in right_internal.column_labels:
# recover `right_prefix` here.
col = right_internal.spark_column_name_for(label)[len(right_prefix) :]
scol = right_scol_for(label).alias(col)
if label in duplicate_columns:
spark_column_name = left_internal.spark_column_name_for(label)
if spark_column_name in left_as_of_names + left_join_on_names and (
(right_prefix + spark_column_name) in right_as_of_names + right_join_on_names
):
continue
else:
col = col + right_suffix
scol = scol.alias(col)
label = tuple([str(label[0]) + right_suffix] + list(label[1:]))
exprs.append(scol)
data_columns.append(col)
column_labels.append(label)
# Retain indices if they are used for joining
if left_index or right_index:
index_spark_column_names = [
SPARK_INDEX_NAME_FORMAT(i) for i in range(len(left_internal.index_spark_column_names))
]
left_index_scols = [
scol.alias(name)
for scol, name in zip(left_internal.index_spark_columns, index_spark_column_names)
]
exprs.extend(left_index_scols)
index_names = left_internal.index_names
else:
index_spark_column_names = []
index_names = []
selected_columns = as_of_joined_table.select(*exprs)
internal = InternalFrame(
spark_frame=selected_columns,
index_spark_columns=[scol_for(selected_columns, col) for col in index_spark_column_names],
index_names=index_names,
column_labels=column_labels,
data_spark_columns=[scol_for(selected_columns, col) for col in data_columns],
)
return DataFrame(internal)
@no_type_check
def to_numeric(arg, errors="raise"):
"""


@ -24,6 +24,7 @@ import pandas as pd
from pyspark import pandas as ps
from pyspark.pandas.utils import name_like_string
from pyspark.sql.utils import AnalysisException
from pyspark.testing.pandasutils import PandasOnSparkTestCase
@ -283,6 +284,143 @@ class ReshapeTest(PandasOnSparkTestCase):
pd.get_dummies(pdf, columns=("x", 1), dtype=np.int8).rename(columns=name_like_string),
)
def test_merge_asof(self):
pdf_left = pd.DataFrame(
{"a": [1, 5, 10], "b": ["x", "y", "z"], "left_val": ["a", "b", "c"]}, index=[10, 20, 30]
)
pdf_right = pd.DataFrame(
{"a": [1, 2, 3, 6, 7], "b": ["v", "w", "x", "y", "z"], "right_val": [1, 2, 3, 6, 7]},
index=[100, 101, 102, 103, 104],
)
psdf_left = ps.from_pandas(pdf_left)
psdf_right = ps.from_pandas(pdf_right)
self.assert_eq(
pd.merge_asof(pdf_left, pdf_right, on="a").sort_values("a").reset_index(drop=True),
ps.merge_asof(psdf_left, psdf_right, on="a").sort_values("a").reset_index(drop=True),
)
self.assert_eq(
(
pd.merge_asof(pdf_left, pdf_right, left_on="a", right_on="a")
.sort_values("a")
.reset_index(drop=True)
),
(
ps.merge_asof(psdf_left, psdf_right, left_on="a", right_on="a")
.sort_values("a")
.reset_index(drop=True)
),
)
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
self.assert_eq(
pd.merge_asof(
pdf_left.set_index("a"), pdf_right, left_index=True, right_on="a"
).sort_index(),
ps.merge_asof(
psdf_left.set_index("a"), psdf_right, left_index=True, right_on="a"
).sort_index(),
)
else:
expected = pd.DataFrame(
{
"b_x": ["x", "y", "z"],
"left_val": ["a", "b", "c"],
"a": [1, 3, 7],
"b_y": ["v", "x", "z"],
"right_val": [1, 3, 7],
},
index=pd.Index([1, 5, 10], name="a"),
)
self.assert_eq(
expected,
ps.merge_asof(
psdf_left.set_index("a"), psdf_right, left_index=True, right_on="a"
).sort_index(),
)
self.assert_eq(
pd.merge_asof(
pdf_left, pdf_right.set_index("a"), left_on="a", right_index=True
).sort_index(),
ps.merge_asof(
psdf_left, psdf_right.set_index("a"), left_on="a", right_index=True
).sort_index(),
)
self.assert_eq(
pd.merge_asof(
pdf_left.set_index("a"), pdf_right.set_index("a"), left_index=True, right_index=True
).sort_index(),
ps.merge_asof(
psdf_left.set_index("a"),
psdf_right.set_index("a"),
left_index=True,
right_index=True,
).sort_index(),
)
self.assert_eq(
(
pd.merge_asof(pdf_left, pdf_right, on="a", by="b")
.sort_values("a")
.reset_index(drop=True)
),
(
ps.merge_asof(psdf_left, psdf_right, on="a", by="b")
.sort_values("a")
.reset_index(drop=True)
),
)
self.assert_eq(
(
pd.merge_asof(pdf_left, pdf_right, on="a", tolerance=1)
.sort_values("a")
.reset_index(drop=True)
),
(
ps.merge_asof(psdf_left, psdf_right, on="a", tolerance=1)
.sort_values("a")
.reset_index(drop=True)
),
)
self.assert_eq(
(
pd.merge_asof(pdf_left, pdf_right, on="a", allow_exact_matches=False)
.sort_values("a")
.reset_index(drop=True)
),
(
ps.merge_asof(psdf_left, psdf_right, on="a", allow_exact_matches=False)
.sort_values("a")
.reset_index(drop=True)
),
)
self.assert_eq(
(
pd.merge_asof(pdf_left, pdf_right, on="a", direction="forward")
.sort_values("a")
.reset_index(drop=True)
),
(
ps.merge_asof(psdf_left, psdf_right, on="a", direction="forward")
.sort_values("a")
.reset_index(drop=True)
),
)
self.assert_eq(
(
pd.merge_asof(pdf_left, pdf_right, on="a", direction="nearest")
.sort_values("a")
.reset_index(drop=True)
),
(
ps.merge_asof(psdf_left, psdf_right, on="a", direction="nearest")
.sort_values("a")
.reset_index(drop=True)
),
)
self.assertRaises(
AnalysisException, lambda: ps.merge_asof(psdf_left, psdf_right, on="a", tolerance=-1)
)
if __name__ == "__main__":
import unittest


@ -1357,6 +1357,123 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
jdf = self._jdf.join(other._jdf, on, how)
return DataFrame(jdf, self.sql_ctx)
# TODO(SPARK-22947): Fix the DataFrame API.
def _joinAsOf(
self,
other,
leftAsOfColumn,
rightAsOfColumn,
on=None,
how=None,
*,
tolerance=None,
allowExactMatches=True,
direction="backward",
):
"""
Perform an as-of join.
This is similar to a left-join except that we match on nearest
key rather than equal keys.
.. versionadded:: 3.3.0
Parameters
----------
other : :class:`DataFrame`
Right side of the join
leftAsOfColumn : str or :class:`Column`
a string for the as-of join column name, or a Column
rightAsOfColumn : str or :class:`Column`
a string for the as-of join column name, or a Column
on : str, list or :class:`Column`, optional
a string for the join column name, a list of column names,
a join expression (Column), or a list of Columns.
If `on` is a string or a list of strings indicating the name of the join column(s),
the column(s) must exist on both sides, and this performs an equi-join.
how : str, optional
default ``inner``. Must be one of: ``inner`` and ``left``.
tolerance : :class:`Column`, optional
an asof tolerance within this range; must be compatible
with the merge index.
allowExactMatches : bool, optional
default ``True``.
direction : str, optional
default ``backward``. Must be one of: ``backward``, ``forward``, and ``nearest``.
Examples
--------
The following performs an as-of join between ``left`` and ``right``.
>>> left = spark.createDataFrame([(1, "a"), (5, "b"), (10, "c")], ["a", "left_val"])
>>> right = spark.createDataFrame([(1, 1), (2, 2), (3, 3), (6, 6), (7, 7)],
... ["a", "right_val"])
>>> left._joinAsOf(
... right, leftAsOfColumn="a", rightAsOfColumn="a"
... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
[Row(a=1, left_val='a', right_val=1),
Row(a=5, left_val='b', right_val=3),
Row(a=10, left_val='c', right_val=7)]
>>> from pyspark.sql import functions as F
>>> left._joinAsOf(
... right, leftAsOfColumn="a", rightAsOfColumn="a", tolerance=F.lit(1)
... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
[Row(a=1, left_val='a', right_val=1)]
>>> left._joinAsOf(
... right, leftAsOfColumn="a", rightAsOfColumn="a", how="left", tolerance=F.lit(1)
... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
[Row(a=1, left_val='a', right_val=1),
Row(a=5, left_val='b', right_val=None),
Row(a=10, left_val='c', right_val=None)]
>>> left._joinAsOf(
... right, leftAsOfColumn="a", rightAsOfColumn="a", allowExactMatches=False
... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
[Row(a=5, left_val='b', right_val=3),
Row(a=10, left_val='c', right_val=7)]
>>> left._joinAsOf(
... right, leftAsOfColumn="a", rightAsOfColumn="a", direction="forward"
... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
[Row(a=1, left_val='a', right_val=1),
Row(a=5, left_val='b', right_val=6)]
"""
if isinstance(leftAsOfColumn, str):
leftAsOfColumn = self[leftAsOfColumn]
left_as_of_jcol = leftAsOfColumn._jc
if isinstance(rightAsOfColumn, str):
rightAsOfColumn = other[rightAsOfColumn]
right_as_of_jcol = rightAsOfColumn._jc
if on is not None and not isinstance(on, list):
on = [on]
if on is not None:
if isinstance(on[0], str):
on = self._jseq(on)
else:
assert isinstance(on[0], Column), "on should be Column or list of Column"
on = reduce(lambda x, y: x.__and__(y), on)
on = on._jc
if how is None:
how = "inner"
assert isinstance(how, str), "how should be a string"
if tolerance is not None:
assert isinstance(tolerance, Column), "tolerance should be Column"
tolerance = tolerance._jc
jdf = self._jdf.joinAsOf(
other._jdf,
left_as_of_jcol, right_as_of_jcol,
on,
how, tolerance, allowExactMatches, direction
)
return DataFrame(jdf, self.sql_ctx)
def sortWithinPartitions(self, *cols, **kwargs):
"""Returns a new :class:`DataFrame` with each partition sorted by the specified column(s).


@ -249,6 +249,20 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
s"join condition '${condition.sql}' " +
s"of type ${condition.dataType.catalogString} is not a boolean.")
case j @ AsOfJoin(_, _, _, Some(condition), _, _, _)
if condition.dataType != BooleanType =>
failAnalysis(
s"join condition '${condition.sql}' " +
s"of type ${condition.dataType.catalogString} is not a boolean.")
case j @ AsOfJoin(_, _, _, _, _, _, Some(toleranceAssertion)) =>
if (!toleranceAssertion.foldable) {
failAnalysis("Input argument tolerance must be a constant.")
}
if (!toleranceAssertion.eval().asInstanceOf[Boolean]) {
failAnalysis("Input argument tolerance must be non-negative.")
}
case a @ Aggregate(groupingExprs, aggregateExprs, child) =>
def isAggregateExpression(expr: Expression): Boolean = {
expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr)
@ -506,6 +520,15 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
|Conflicting attributes: ${conflictingAttributes.mkString(",")}
""".stripMargin)
case j: AsOfJoin if !j.duplicateResolved =>
val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet)
failAnalysis(
s"""
|Failure when resolving conflicting references in AsOfJoin:
|$plan
|Conflicting attributes: ${conflictingAttributes.mkString(",")}
|""".stripMargin)
// TODO: although map type is not orderable, technically map type should be able to be
// used in equality comparison, remove this type check once we support it.
case o if mapColumnInSetOperation(o).isDefined =>


@ -41,7 +41,8 @@ case class ReferenceEqualPlanWrapper(plan: LogicalPlan) {
object DeduplicateRelations extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = {
renewDuplicatedRelations(mutable.HashSet.empty, plan)._1.resolveOperatorsUpWithPruning(
_.containsAnyPattern(JOIN, LATERAL_JOIN, INTERSECT, EXCEPT, UNION, COMMAND), ruleId) {
_.containsAnyPattern(JOIN, LATERAL_JOIN, AS_OF_JOIN, INTERSECT, EXCEPT, UNION, COMMAND),
ruleId) {
case p: LogicalPlan if !p.childrenResolved => p
// To resolve duplicate expression IDs for Join.
case j @ Join(left, right, _, _, _) if !j.duplicateResolved =>
@ -49,6 +50,9 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
// Resolve duplicate output for LateralJoin.
case j @ LateralJoin(left, right, _, _) if right.resolved && !j.duplicateResolved =>
j.copy(right = right.withNewPlan(dedupRight(left, right.plan)))
// Resolve duplicate output for AsOfJoin.
case j @ AsOfJoin(left, right, _, _, _, _, _) if !j.duplicateResolved =>
j.copy(right = dedupRight(left, right))
// intersect/except will be rewritten to join at the beginning of optimizer. Here we need to
// deduplicate the right side plan, so that we won't produce an invalid self-join later.
case i @ Intersect(left, right, _) if !i.duplicateResolved =>


@ -159,7 +159,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
PullOutGroupingExpressions,
ComputeCurrentTime,
ReplaceCurrentLike(catalogManager),
SpecialDatetimeValues) ::
SpecialDatetimeValues,
RewriteAsOfJoin) ::
//////////////////////////////////////////////////////////////////////////////////////////
// Optimizer rules start here
//////////////////////////////////////////////////////////////////////////////////////////
@ -282,7 +283,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
RewritePredicateSubquery.ruleName ::
NormalizeFloatingNumbers.ruleName ::
ReplaceUpdateFieldsExpression.ruleName ::
PullOutGroupingExpressions.ruleName :: Nil
PullOutGroupingExpressions.ruleName ::
RewriteAsOfJoin.ruleName :: Nil
/**
* Optimize all the subqueries inside expression.


@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreePattern._
/**
* Replaces logical [[AsOfJoin]] operator using a combination of Join and Aggregate operator.
*
* Input Pseudo-Query:
* {{{
* SELECT * FROM left ASOF JOIN right ON (condition, as_of on(left.t, right.t), tolerance)
* }}}
*
* Rewritten Query:
* {{{
* SELECT left.*, __right__.*
* FROM (
* SELECT
* left.*,
* (
* SELECT MIN_BY(STRUCT(right.*), left.t - right.t) AS __nearest_right__
* FROM right
* WHERE condition AND left.t >= right.t AND right.t >= left.t - tolerance
* ) as __right__
* FROM left
* )
* WHERE __right__ IS NOT NULL
* }}}
*/
object RewriteAsOfJoin extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsPattern(AS_OF_JOIN), ruleId) {
case AsOfJoin(left, right, asOfCondition, condition, joinType, orderExpression, _) =>
val conditionWithOuterReference =
condition.map(And(_, asOfCondition)).getOrElse(asOfCondition).transformUp {
case a: AttributeReference if left.outputSet.contains(a) =>
OuterReference(a)
}
val filtered = Filter(conditionWithOuterReference, right)
val orderExpressionWithOuterReference = orderExpression.transformUp {
case a: AttributeReference if left.outputSet.contains(a) =>
OuterReference(a)
}
val rightStruct = CreateStruct(right.output)
val nearestRight = MinBy(rightStruct, orderExpressionWithOuterReference)
.toAggregateExpression()
val aggExpr = Alias(nearestRight, "__nearest_right__")()
val aggregate = Aggregate(Seq.empty, Seq(aggExpr), filtered)
val projectWithScalarSubquery = Project(
left.output :+ Alias(ScalarSubquery(aggregate, left.output), "__right__")(),
left)
val filterRight = joinType match {
case LeftOuter => projectWithScalarSubquery
case _ =>
Filter(IsNotNull(projectWithScalarSubquery.output.last), projectWithScalarSubquery)
}
Project(
left.output ++ right.output.zipWithIndex.map {
case (out, idx) =>
Alias(GetStructField(filterRight.output.last, idx), out.name)(exprId = out.exprId)
},
filterRight)
}
}


@ -121,3 +121,24 @@ object LeftSemiOrAnti {
case _ => None
}
}
object AsOfJoinDirection {
def apply(direction: String): AsOfJoinDirection = {
direction.toLowerCase(Locale.ROOT) match {
case "forward" => Forward
case "backward" => Backward
case "nearest" => Nearest
case _ =>
val supported = Seq("forward", "backward", "nearest")
throw new IllegalArgumentException(s"Unsupported as-of join direction '$direction'. " +
"Supported as-of join direction include: " + supported.mkString("'", "', '", "'") + ".")
}
}
}
sealed abstract class AsOfJoinDirection
case object Forward extends AsOfJoinDirection
case object Backward extends AsOfJoinDirection
case object Nearest extends AsOfJoinDirection


@ -1597,3 +1597,119 @@ case class LateralJoin(
copy(left = newChild)
}
}
/**
* A logical plan for as-of join.
*/
case class AsOfJoin(
left: LogicalPlan,
right: LogicalPlan,
asOfCondition: Expression,
condition: Option[Expression],
joinType: JoinType,
orderExpression: Expression,
toleranceAssertion: Option[Expression]) extends BinaryNode {
require(Seq(Inner, LeftOuter).contains(joinType),
s"Unsupported as-of join type $joinType")
override protected def stringArgs: Iterator[Any] = super.stringArgs.take(5)
override def output: Seq[Attribute] = {
joinType match {
case LeftOuter =>
left.output ++ right.output.map(_.withNullability(true))
case _ =>
left.output ++ right.output
}
}
def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
override lazy val resolved: Boolean = {
childrenResolved &&
expressions.forall(_.resolved) &&
duplicateResolved &&
asOfCondition.dataType == BooleanType &&
condition.forall(_.dataType == BooleanType) &&
toleranceAssertion.forall { assertion =>
assertion.foldable && assertion.eval().asInstanceOf[Boolean]
}
}
final override val nodePatterns: Seq[TreePattern] = Seq(AS_OF_JOIN)
override protected def withNewChildrenInternal(
newLeft: LogicalPlan, newRight: LogicalPlan): AsOfJoin = {
copy(left = newLeft, right = newRight)
}
}
object AsOfJoin {
def apply(
left: LogicalPlan,
right: LogicalPlan,
leftAsOf: Expression,
rightAsOf: Expression,
condition: Option[Expression],
joinType: JoinType,
tolerance: Option[Expression],
allowExactMatches: Boolean,
direction: AsOfJoinDirection): AsOfJoin = {
val asOfCond = makeAsOfCond(leftAsOf, rightAsOf, tolerance, allowExactMatches, direction)
val orderingExpr = makeOrderingExpr(leftAsOf, rightAsOf, direction)
AsOfJoin(left, right, asOfCond, condition, joinType,
orderingExpr, tolerance.map(t => GreaterThanOrEqual(t, Literal.default(t.dataType))))
}
private def makeAsOfCond(
leftAsOf: Expression,
rightAsOf: Expression,
tolerance: Option[Expression],
allowExactMatches: Boolean,
direction: AsOfJoinDirection): Expression = {
val base = (allowExactMatches, direction) match {
case (true, Backward) => GreaterThanOrEqual(leftAsOf, rightAsOf)
case (false, Backward) => GreaterThan(leftAsOf, rightAsOf)
case (true, Forward) => LessThanOrEqual(leftAsOf, rightAsOf)
case (false, Forward) => LessThan(leftAsOf, rightAsOf)
case (true, Nearest) => Literal.TrueLiteral
case (false, Nearest) => Not(EqualTo(leftAsOf, rightAsOf))
}
tolerance match {
case Some(tolerance) =>
(allowExactMatches, direction) match {
case (true, Backward) =>
And(base, GreaterThanOrEqual(rightAsOf, Subtract(leftAsOf, tolerance)))
case (false, Backward) =>
And(base, GreaterThan(rightAsOf, Subtract(leftAsOf, tolerance)))
case (true, Forward) =>
And(base, LessThanOrEqual(rightAsOf, Add(leftAsOf, tolerance)))
case (false, Forward) =>
And(base, LessThan(rightAsOf, Add(leftAsOf, tolerance)))
case (true, Nearest) =>
And(GreaterThanOrEqual(rightAsOf, Subtract(leftAsOf, tolerance)),
LessThanOrEqual(rightAsOf, Add(leftAsOf, tolerance)))
case (false, Nearest) =>
And(base,
And(GreaterThan(rightAsOf, Subtract(leftAsOf, tolerance)),
LessThan(rightAsOf, Add(leftAsOf, tolerance))))
}
case None => base
}
}
private def makeOrderingExpr(
leftAsOf: Expression,
rightAsOf: Expression,
direction: AsOfJoinDirection): Expression = {
direction match {
case Backward => Subtract(leftAsOf, rightAsOf)
case Forward => Subtract(rightAsOf, leftAsOf)
case Nearest =>
If(GreaterThan(leftAsOf, rightAsOf),
Subtract(leftAsOf, rightAsOf), Subtract(rightAsOf, leftAsOf))
}
}
}


@ -144,6 +144,7 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.optimizer.ReplaceIntersectWithSemiJoin" ::
"org.apache.spark.sql.catalyst.optimizer.RewriteExceptAll" ::
"org.apache.spark.sql.catalyst.optimizer.RewriteIntersectAll" ::
"org.apache.spark.sql.catalyst.optimizer.RewriteAsOfJoin" ::
"org.apache.spark.sql.catalyst.optimizer.SimplifyBinaryComparison" ::
"org.apache.spark.sql.catalyst.optimizer.SimplifyCaseConversionExpressions" ::
"org.apache.spark.sql.catalyst.optimizer.SimplifyCasts" ::


@ -91,6 +91,7 @@ object TreePattern extends Enumeration {
// Logical plan patterns (alphabetically ordered)
val AGGREGATE: Value = Value
val AS_OF_JOIN: Value = Value
val COMMAND: Value = Value
val CTE: Value = Value
val DISTINCT_LIKE: Value = Value


@ -0,0 +1,289 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{CreateStruct, GetStructField, If, OuterReference, ScalarSubquery}
import org.apache.spark.sql.catalyst.expressions.aggregate.MinBy
import org.apache.spark.sql.catalyst.plans.{AsOfJoinDirection, Inner, LeftOuter, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{AsOfJoin, LocalRelation}
class RewriteAsOfJoinSuite extends PlanTest {
test("simple") {
val left = LocalRelation('a.int, 'b.int, 'c.int)
val right = LocalRelation('a.int, 'b.int, 'd.int)
val query = AsOfJoin(left, right, left.output(0), right.output(0), None, Inner,
tolerance = None, allowExactMatches = true, direction = AsOfJoinDirection("backward"))
val rewritten = RewriteAsOfJoin(query.analyze)
val filter = OuterReference(left.output(0)) >= right.output(0)
val rightStruct = CreateStruct(right.output)
val orderExpression = OuterReference(left.output(0)) - right.output(0)
val nearestRight = MinBy(rightStruct, orderExpression)
.toAggregateExpression().as("__nearest_right__")
val scalarSubquery = left.select(
left.output :+ ScalarSubquery(
right.where(filter).groupBy()(nearestRight),
left.output).as("__right__"): _*)
val correctAnswer = scalarSubquery
.where(scalarSubquery.output.last.isNotNull)
.select(left.output :+
GetStructField(scalarSubquery.output.last, 0).as("a") :+
GetStructField(scalarSubquery.output.last, 1).as("b") :+
GetStructField(scalarSubquery.output.last, 2).as("d"): _*)
comparePlans(rewritten, correctAnswer, checkAnalysis = false)
}
test("condition") {
val left = LocalRelation('a.int, 'b.int, 'c.int)
val right = LocalRelation('a.int, 'b.int, 'd.int)
val query = AsOfJoin(left, right, left.output(0), right.output(0),
Some(left.output(1) === right.output(1)), Inner,
tolerance = None, allowExactMatches = true, direction = AsOfJoinDirection("backward"))
val rewritten = RewriteAsOfJoin(query.analyze)
val filter = OuterReference(left.output(1)) === right.output(1) &&
OuterReference(left.output(0)) >= right.output(0)
val rightStruct = CreateStruct(right.output)
val orderExpression = OuterReference(left.output(0)) - right.output(0)
val nearestRight = MinBy(rightStruct, orderExpression)
.toAggregateExpression().as("__nearest_right__")
val scalarSubquery = left.select(
left.output :+ ScalarSubquery(
right.where(filter).groupBy()(nearestRight),
left.output).as("__right__"): _*)
val correctAnswer = scalarSubquery
.where(scalarSubquery.output.last.isNotNull)
.select(left.output :+
GetStructField(scalarSubquery.output.last, 0).as("a") :+
GetStructField(scalarSubquery.output.last, 1).as("b") :+
GetStructField(scalarSubquery.output.last, 2).as("d"): _*)
comparePlans(rewritten, correctAnswer, checkAnalysis = false)
}
test("left outer") {
val left = LocalRelation('a.int, 'b.int, 'c.int)
val right = LocalRelation('a.int, 'b.int, 'd.int)
val query = AsOfJoin(left, right, left.output(0), right.output(0), None, Inner,
tolerance = None, allowExactMatches = true, direction = AsOfJoinDirection("backward"))
val rewritten = RewriteAsOfJoin(query.analyze)
val filter = OuterReference(left.output(0)) >= right.output(0)
val rightStruct = CreateStruct(right.output)
val orderExpression = OuterReference(left.output(0)) - right.output(0)
val nearestRight = MinBy(rightStruct, orderExpression)
.toAggregateExpression().as("__nearest_right__")
val scalarSubquery = left.select(
left.output :+ ScalarSubquery(
right.where(filter).groupBy()(nearestRight),
left.output).as("__right__"): _*)
val correctAnswer = scalarSubquery
.where(scalarSubquery.output.last.isNotNull)
.select(left.output :+
GetStructField(scalarSubquery.output.last, 0).as("a") :+
GetStructField(scalarSubquery.output.last, 1).as("b") :+
GetStructField(scalarSubquery.output.last, 2).as("d"): _*)
comparePlans(rewritten, correctAnswer, checkAnalysis = false)
}
test("tolerance") {
val left = LocalRelation('a.int, 'b.int, 'c.int)
val right = LocalRelation('a.int, 'b.int, 'd.int)
val query = AsOfJoin(left, right, left.output(0), right.output(0), None, Inner,
tolerance = Some(1), allowExactMatches = true, direction = AsOfJoinDirection("backward"))
val rewritten = RewriteAsOfJoin(query.analyze)
val filter = OuterReference(left.output(0)) >= right.output(0) &&
right.output(0) >= OuterReference(left.output(0)) - 1
val rightStruct = CreateStruct(right.output)
val orderExpression = OuterReference(left.output(0)) - right.output(0)
val nearestRight = MinBy(rightStruct, orderExpression)
.toAggregateExpression().as("__nearest_right__")
val scalarSubquery = left.select(
left.output :+ ScalarSubquery(
right.where(filter).groupBy()(nearestRight),
left.output).as("__right__"): _*)
val correctAnswer = scalarSubquery
.where(scalarSubquery.output.last.isNotNull)
.select(left.output :+
GetStructField(scalarSubquery.output.last, 0).as("a") :+
GetStructField(scalarSubquery.output.last, 1).as("b") :+
GetStructField(scalarSubquery.output.last, 2).as("d"): _*)
comparePlans(rewritten, correctAnswer, checkAnalysis = false)
}
test("allowExactMatches = false") {
val left = LocalRelation('a.int, 'b.int, 'c.int)
val right = LocalRelation('a.int, 'b.int, 'd.int)
val query = AsOfJoin(left, right, left.output(0), right.output(0), None, LeftOuter,
tolerance = None, allowExactMatches = false, direction = AsOfJoinDirection("backward"))
val rewritten = RewriteAsOfJoin(query.analyze)
val filter = OuterReference(left.output(0)) > right.output(0)
val rightStruct = CreateStruct(right.output)
val orderExpression = OuterReference(left.output(0)) - right.output(0)
val nearestRight = MinBy(rightStruct, orderExpression)
.toAggregateExpression().as("__nearest_right__")
val scalarSubquery = left.select(
left.output :+ ScalarSubquery(
right.where(filter).groupBy()(nearestRight),
left.output).as("__right__"): _*)
val correctAnswer = scalarSubquery
.select(left.output :+
GetStructField(scalarSubquery.output.last, 0).as("a") :+
GetStructField(scalarSubquery.output.last, 1).as("b") :+
GetStructField(scalarSubquery.output.last, 2).as("d"): _*)
comparePlans(rewritten, correctAnswer, checkAnalysis = false)
}
test("tolerance & allowExactMatches = false") {
val left = LocalRelation('a.int, 'b.int, 'c.int)
val right = LocalRelation('a.int, 'b.int, 'd.int)
val query = AsOfJoin(left, right, left.output(0), right.output(0), None, Inner,
tolerance = Some(1), allowExactMatches = false, direction = AsOfJoinDirection("backward"))
val rewritten = RewriteAsOfJoin(query.analyze)
val filter = OuterReference(left.output(0)) > right.output(0) &&
right.output(0) > OuterReference(left.output(0)) - 1
val rightStruct = CreateStruct(right.output)
val orderExpression = OuterReference(left.output(0)) - right.output(0)
val nearestRight = MinBy(rightStruct, orderExpression)
.toAggregateExpression().as("__nearest_right__")
val scalarSubquery = left.select(
left.output :+ ScalarSubquery(
right.where(filter).groupBy()(nearestRight),
left.output).as("__right__"): _*)
val correctAnswer = scalarSubquery
.where(scalarSubquery.output.last.isNotNull)
.select(left.output :+
GetStructField(scalarSubquery.output.last, 0).as("a") :+
GetStructField(scalarSubquery.output.last, 1).as("b") :+
GetStructField(scalarSubquery.output.last, 2).as("d"): _*)
comparePlans(rewritten, correctAnswer, checkAnalysis = false)
}
test("direction = forward") {
val left = LocalRelation('a.int, 'b.int, 'c.int)
val right = LocalRelation('a.int, 'b.int, 'd.int)
val query = AsOfJoin(left, right, left.output(0), right.output(0), None, Inner,
tolerance = None, allowExactMatches = true, direction = AsOfJoinDirection("forward"))
val rewritten = RewriteAsOfJoin(query.analyze)
val filter = OuterReference(left.output(0)) <= right.output(0)
val rightStruct = CreateStruct(right.output)
val orderExpression = right.output(0) - OuterReference(left.output(0))
val nearestRight = MinBy(rightStruct, orderExpression)
.toAggregateExpression().as("__nearest_right__")
val scalarSubquery = left.select(
left.output :+ ScalarSubquery(
right.where(filter).groupBy()(nearestRight),
left.output).as("__right__"): _*)
val correctAnswer = scalarSubquery
.where(scalarSubquery.output.last.isNotNull)
.select(left.output :+
GetStructField(scalarSubquery.output.last, 0).as("a") :+
GetStructField(scalarSubquery.output.last, 1).as("b") :+
GetStructField(scalarSubquery.output.last, 2).as("d"): _*)
comparePlans(rewritten, correctAnswer, checkAnalysis = false)
}
test("direction = nearest") {
val left = LocalRelation('a.int, 'b.int, 'c.int)
val right = LocalRelation('a.int, 'b.int, 'd.int)
val query = AsOfJoin(left, right, left.output(0), right.output(0), None, Inner,
tolerance = None, allowExactMatches = true, direction = AsOfJoinDirection("nearest"))
val rewritten = RewriteAsOfJoin(query.analyze)
val filter = true
val rightStruct = CreateStruct(right.output)
val orderExpression = If(OuterReference(left.output(0)) > right.output(0),
OuterReference(left.output(0)) - right.output(0),
right.output(0) - OuterReference(left.output(0)))
val nearestRight = MinBy(rightStruct, orderExpression)
.toAggregateExpression().as("__nearest_right__")
val scalarSubquery = left.select(
left.output :+ ScalarSubquery(
right.where(filter).groupBy()(nearestRight),
left.output).as("__right__"): _*)
val correctAnswer = scalarSubquery
.where(scalarSubquery.output.last.isNotNull)
.select(left.output :+
GetStructField(scalarSubquery.output.last, 0).as("a") :+
GetStructField(scalarSubquery.output.last, 1).as("b") :+
GetStructField(scalarSubquery.output.last, 2).as("d"): _*)
comparePlans(rewritten, correctAnswer, checkAnalysis = false)
}
test("tolerance & allowExactMatches = false & direction = nearest") {
val left = LocalRelation('a.int, 'b.int, 'c.int)
val right = LocalRelation('a.int, 'b.int, 'd.int)
val query = AsOfJoin(left, right, left.output(0), right.output(0), None, Inner,
tolerance = Some(1), allowExactMatches = false, direction = AsOfJoinDirection("nearest"))
val rewritten = RewriteAsOfJoin(query.analyze)
val filter = (!(OuterReference(left.output(0)) === right.output(0))) &&
((right.output(0) > OuterReference(left.output(0)) - 1) &&
(right.output(0) < OuterReference(left.output(0)) + 1))
val rightStruct = CreateStruct(right.output)
val orderExpression = If(OuterReference(left.output(0)) > right.output(0),
OuterReference(left.output(0)) - right.output(0),
right.output(0) - OuterReference(left.output(0)))
val nearestRight = MinBy(rightStruct, orderExpression)
.toAggregateExpression().as("__nearest_right__")
val scalarSubquery = left.select(
left.output :+ ScalarSubquery(
right.where(filter).groupBy()(nearestRight),
left.output).as("__right__"): _*)
val correctAnswer = scalarSubquery
.where(scalarSubquery.output.last.isNotNull)
.select(left.output :+
GetStructField(scalarSubquery.output.last, 0).as("a") :+
GetStructField(scalarSubquery.output.last, 1).as("b") :+
GetStructField(scalarSubquery.output.last, 2).as("d"): _*)
comparePlans(rewritten, correctAnswer, checkAnalysis = false)
}
}


@ -1046,6 +1046,47 @@ class Dataset[T] private[sql](
plan.copy(condition = cond)
}
/**
* find the trivially true predicates and automatically resolves them to both sides.
*/
private def resolveSelfJoinCondition(
right: Dataset[_],
joinExprs: Option[Column],
joinType: String): Join = {
// Note that in this function, we introduce a hack in the case of self-join to automatically
// resolve ambiguous join conditions into ones that might make sense [SPARK-6231].
// Consider this case: df.join(df, df("key") === df("key"))
// Since df("key") === df("key") is a trivially true condition, this actually becomes a
// cartesian join. However, most likely users expect to perform a self join using "key".
// With that assumption, this hack turns the trivially true condition into equality on join
// keys that are resolved to both sides.
// Trigger analysis so in the case of self-join, the analyzer will clone the plan.
// After the cloning, left and right side will have distinct expression ids.
val plan = withPlan(
Join(logicalPlan, right.logicalPlan,
JoinType(joinType), joinExprs.map(_.expr), JoinHint.NONE))
.queryExecution.analyzed.asInstanceOf[Join]
// If auto self join alias is disabled, return the plan.
if (!sparkSession.sessionState.conf.dataFrameSelfJoinAutoResolveAmbiguity) {
return plan
}
// If left/right have no output set intersection, return the plan.
val lanalyzed = this.queryExecution.analyzed
val ranalyzed = right.queryExecution.analyzed
if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) {
return plan
}
// Otherwise, find the trivially true predicates and automatically resolves them to both sides.
// By the time we get here, since we have already run analysis, all attributes should've been
// resolved and become AttributeReference.
resolveSelfJoinCondition(plan)
}
/**
* Join with another `DataFrame`, using the given join expression. The following performs
* a full outer join between `df1` and `df2`.
@ -1071,38 +1112,8 @@ class Dataset[T] private[sql](
* @since 2.0.0
*/
def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame = {
// Note that in this function, we introduce a hack in the case of self-join to automatically
// resolve ambiguous join conditions into ones that might make sense [SPARK-6231].
// Consider this case: df.join(df, df("key") === df("key"))
// Since df("key") === df("key") is a trivially true condition, this actually becomes a
// cartesian join. However, most likely users expect to perform a self join using "key".
// With that assumption, this hack turns the trivially true condition into equality on join
// keys that are resolved to both sides.
// Trigger analysis so in the case of self-join, the analyzer will clone the plan.
// After the cloning, left and right side will have distinct expression ids.
val plan = withPlan(
Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr), JoinHint.NONE))
.queryExecution.analyzed.asInstanceOf[Join]
// If auto self join alias is disabled, return the plan.
if (!sparkSession.sessionState.conf.dataFrameSelfJoinAutoResolveAmbiguity) {
return withPlan(plan)
}
// If left/right have no output set intersection, return the plan.
val lanalyzed = this.queryExecution.analyzed
val ranalyzed = right.queryExecution.analyzed
if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) {
return withPlan(plan)
}
// Otherwise, find the trivially true predicates and automatically resolves them to both sides.
// By the time we get here, since we have already run analysis, all attributes should've been
// resolved and become AttributeReference.
withPlan {
resolveSelfJoinCondition(plan)
resolveSelfJoinCondition(right, Some(joinExprs), joinType)
}
}
@ -1232,6 +1243,58 @@ class Dataset[T] private[sql](
joinWith(other, condition, "inner")
}
// TODO(SPARK-22947): Fix the DataFrame API.
private[sql] def joinAsOf(
other: Dataset[_],
leftAsOf: Column,
rightAsOf: Column,
usingColumns: Seq[String],
joinType: String,
tolerance: Column,
allowExactMatches: Boolean,
direction: String): DataFrame = {
val joinExprs = usingColumns.map { column =>
EqualTo(resolve(column), other.resolve(column))
}.reduceOption(And).map(Column.apply).orNull
joinAsOf(other, leftAsOf, rightAsOf, joinExprs, joinType,
tolerance, allowExactMatches, direction)
}
// TODO(SPARK-22947): Fix the DataFrame API.
private[sql] def joinAsOf(
other: Dataset[_],
leftAsOf: Column,
rightAsOf: Column,
joinExprs: Column,
joinType: String,
tolerance: Column,
allowExactMatches: Boolean,
direction: String): DataFrame = {
val joined = resolveSelfJoinCondition(other, Option(joinExprs), joinType)
val leftAsOfExpr = leftAsOf.expr.transformUp {
case a: AttributeReference if logicalPlan.outputSet.contains(a) =>
val index = logicalPlan.output.indexWhere(_.exprId == a.exprId)
joined.left.output(index)
}
val rightAsOfExpr = rightAsOf.expr.transformUp {
case a: AttributeReference if other.logicalPlan.outputSet.contains(a) =>
val index = other.logicalPlan.output.indexWhere(_.exprId == a.exprId)
joined.right.output(index)
}
withPlan {
AsOfJoin(
joined.left, joined.right,
leftAsOfExpr, rightAsOfExpr,
joined.condition,
joined.joinType,
Option(tolerance).map(_.expr),
allowExactMatches,
AsOfJoinDirection(direction)
)
}
}
/**
* Returns a new Dataset with each partition sorted by the given expressions.
*


@ -0,0 +1,169 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import scala.collection.JavaConverters._
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
class DataFrameAsOfJoinSuite extends QueryTest
with SharedSparkSession
with AdaptiveSparkPlanHelper {
def prepareForAsOfJoin(): (DataFrame, DataFrame) = {
val schema1 = StructType(
StructField("a", IntegerType, false) ::
StructField("b", StringType, false) ::
StructField("left_val", StringType, false) :: Nil)
val rowSeq1: List[Row] = List(Row(1, "x", "a"), Row(5, "y", "b"), Row(10, "z", "c"))
val df1 = spark.createDataFrame(rowSeq1.asJava, schema1)
val schema2 = StructType(
StructField("a", IntegerType) ::
StructField("b", StringType) ::
StructField("right_val", IntegerType) :: Nil)
val rowSeq2: List[Row] = List(Row(1, "v", 1), Row(2, "w", 2), Row(3, "x", 3),
Row(6, "y", 6), Row(7, "z", 7))
val df2 = spark.createDataFrame(rowSeq2.asJava, schema2)
(df1, df2)
}
test("as-of join - simple") {
val (df1, df2) = prepareForAsOfJoin()
checkAnswer(
df1.joinAsOf(
df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
joinType = "inner", tolerance = null, allowExactMatches = true, direction = "backward"),
Seq(
Row(1, "x", "a", 1, "v", 1),
Row(5, "y", "b", 3, "x", 3),
Row(10, "z", "c", 7, "z", 7)
)
)
}
test("as-of join - usingColumns") {
val (df1, df2) = prepareForAsOfJoin()
checkAnswer(
df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq("b"),
joinType = "inner", tolerance = null, allowExactMatches = true, direction = "backward"),
Seq(
Row(10, "z", "c", 7, "z", 7)
)
)
}
test("as-of join - usingColumns, left outer") {
val (df1, df2) = prepareForAsOfJoin()
checkAnswer(
df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq("b"),
joinType = "left", tolerance = null, allowExactMatches = true, direction = "backward"),
Seq(
Row(1, "x", "a", null, null, null),
Row(5, "y", "b", null, null, null),
Row(10, "z", "c", 7, "z", 7)
)
)
}
test("as-of join - tolerance = 1") {
val (df1, df2) = prepareForAsOfJoin()
checkAnswer(
df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
joinType = "inner", tolerance = lit(1), allowExactMatches = true, direction = "backward"),
Seq(
Row(1, "x", "a", 1, "v", 1)
)
)
}
test("as-of join - tolerance should be a constant") {
val (df1, df2) = prepareForAsOfJoin()
val errMsg = intercept[AnalysisException] {
df1.joinAsOf(
df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
joinType = "inner", tolerance = df1.col("b"), allowExactMatches = true,
direction = "backward")
}.getMessage
assert(errMsg.contains("Input argument tolerance must be a constant."))
}
test("as-of join - tolerance should be non-negative") {
val (df1, df2) = prepareForAsOfJoin()
val errMsg = intercept[AnalysisException] {
df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
joinType = "inner", tolerance = lit(-1), allowExactMatches = true, direction = "backward")
}.getMessage
assert(errMsg.contains("Input argument tolerance must be non-negative."))
}
test("as-of join - allowExactMatches = false") {
val (df1, df2) = prepareForAsOfJoin()
checkAnswer(
df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
joinType = "inner", tolerance = null, allowExactMatches = false, direction = "backward"),
Seq(
Row(5, "y", "b", 3, "x", 3),
Row(10, "z", "c", 7, "z", 7)
)
)
}
test("as-of join - direction = \"forward\"") {
val (df1, df2) = prepareForAsOfJoin()
checkAnswer(
df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
joinType = "inner", tolerance = null, allowExactMatches = true, direction = "forward"),
Seq(
Row(1, "x", "a", 1, "v", 1),
Row(5, "y", "b", 6, "y", 6)
)
)
}
test("as-of join - direction = \"nearest\"") {
val (df1, df2) = prepareForAsOfJoin()
checkAnswer(
df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
joinType = "inner", tolerance = null, allowExactMatches = true, direction = "nearest"),
Seq(
Row(1, "x", "a", 1, "v", 1),
Row(5, "y", "b", 6, "y", 6),
Row(10, "z", "c", 7, "z", 7)
)
)
}
test("as-of join - self") {
val (df1, _) = prepareForAsOfJoin()
checkAnswer(
df1.joinAsOf(
df1, df1.col("a"), df1.col("a"), usingColumns = Seq.empty,
joinType = "left", tolerance = null, allowExactMatches = false, direction = "nearest"),
Seq(
Row(1, "x", "a", 5, "y", "b"),
Row(5, "y", "b", 1, "x", "a"),
Row(10, "z", "c", 5, "y", "b")
)
)
}
}