[SPARK-36813][SQL][PYTHON] Propose an infrastructure for as-of join and implement ps.merge_asof
### What changes were proposed in this pull request?

Proposes an infrastructure for as-of join and implements `ps.merge_asof` here.

1. Introduce `AsOfJoin` logical plan
2. Rewrite the plan in the optimize phase (see the sketch after this list for the other directions):

   - From something like (SQL syntax is not determined):

     ```sql
     SELECT * FROM left ASOF JOIN right ON (condition, as_of on(left.t, right.t), tolerance)
     ```

   - To:

     ```sql
     SELECT left.*, __right__.*
     FROM (
          SELECT
               left.*,
               (
                    SELECT MIN_BY(STRUCT(right.*), left.t - right.t) AS __nearest_right__
                    FROM right
                    WHERE condition AND left.t >= right.t AND right.t >= left.t - tolerance
               ) as __right__
          FROM left
     )
     WHERE __right__ IS NOT NULL
     ```

3. The rewritten scalar subquery will be handled by the existing decorrelation framework.

Note: APIs on SQL DataFrames and SQL syntax are TBD (e.g., [SPARK-22947](https://issues.apache.org/jira/browse/SPARK-22947)), although there are temporary APIs added here.

### Why are the changes needed?

Pandas' `merge_asof`, or an as-of join for SQL/DataFrame, is useful for time-series analysis.

### Does this PR introduce _any_ user-facing change?

Yes. `ps.merge_asof` can be used.

```py
>>> quotes
                     time ticker     bid     ask
0 2016-05-25 13:30:00.023   GOOG  720.50  720.93
1 2016-05-25 13:30:00.023   MSFT   51.95   51.96
2 2016-05-25 13:30:00.030   MSFT   51.97   51.98
3 2016-05-25 13:30:00.041   MSFT   51.99   52.00
4 2016-05-25 13:30:00.048   GOOG  720.50  720.93
5 2016-05-25 13:30:00.049   AAPL   97.99   98.01
6 2016-05-25 13:30:00.072   GOOG  720.50  720.88
7 2016-05-25 13:30:00.075   MSFT   52.01   52.03

>>> trades
                     time ticker   price  quantity
0 2016-05-25 13:30:00.023   MSFT   51.95        75
1 2016-05-25 13:30:00.038   MSFT   51.95       155
2 2016-05-25 13:30:00.048   GOOG  720.77       100
3 2016-05-25 13:30:00.048   GOOG  720.92       100
4 2016-05-25 13:30:00.048   AAPL   98.00       100

>>> ps.merge_asof(
...     trades, quotes, on="time", by="ticker"
... ).sort_values(["time", "ticker", "price"]).reset_index(drop=True)
                     time ticker   price  quantity     bid     ask
0 2016-05-25 13:30:00.023   MSFT   51.95        75   51.95   51.96
1 2016-05-25 13:30:00.038   MSFT   51.95       155   51.97   51.98
2 2016-05-25 13:30:00.048   AAPL   98.00       100     NaN     NaN
3 2016-05-25 13:30:00.048   GOOG  720.77       100  720.50  720.93
4 2016-05-25 13:30:00.048   GOOG  720.92       100  720.50  720.93

>>> ps.merge_asof(
...     trades,
...     quotes,
...     on="time",
...     by="ticker",
...     tolerance=F.expr("INTERVAL 2 MILLISECONDS")  # pd.Timedelta("2ms")
... ).sort_values(["time", "ticker", "price"]).reset_index(drop=True)
                     time ticker   price  quantity     bid     ask
0 2016-05-25 13:30:00.023   MSFT   51.95        75   51.95   51.96
1 2016-05-25 13:30:00.038   MSFT   51.95       155     NaN     NaN
2 2016-05-25 13:30:00.048   AAPL   98.00       100     NaN     NaN
3 2016-05-25 13:30:00.048   GOOG  720.77       100  720.50  720.93
4 2016-05-25 13:30:00.048   GOOG  720.92       100  720.50  720.93
```

Note: As `IntervalType` literal is not supported yet, we have to specify the `IntervalType` value with `F.expr` as a workaround.

### How was this patch tested?

Added tests.

Closes #34053 from ueshin/issues/SPARK-36813/merge_asof.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
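The rewrite shown above covers only the default backward direction. As a reference, here is a sketch of the correlated predicate and `MIN_BY` ordering the rewrite produces for each direction, derived from `makeAsOfCond`/`makeOrderingExpr` in the diff below (with `allow_exact_matches` enabled; `left.t`/`right.t` stand for the as-of columns):

```sql
-- backward: pick the last right.t at or before left.t
WHERE condition AND left.t >= right.t AND right.t >= left.t - tolerance
-- ordered by: left.t - right.t

-- forward: pick the first right.t at or after left.t
WHERE condition AND left.t <= right.t AND right.t <= left.t + tolerance
-- ordered by: right.t - left.t

-- nearest: pick any right.t within the tolerance window
WHERE condition AND right.t >= left.t - tolerance AND right.t <= left.t + tolerance
-- ordered by: IF(left.t > right.t, left.t - right.t, right.t - left.t)
```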
This commit is contained in:
parent a9b4c27f12
commit 05c0fa5738
@@ -73,6 +73,7 @@ from pyspark.pandas.utils import (
    align_diff_frames,
    default_session,
    is_name_like_tuple,
    is_name_like_value,
    name_like_string,
    same_anchor,
    scol_for,
@@ -83,11 +84,13 @@ from pyspark.pandas.internal import (
    InternalFrame,
    DEFAULT_SERIES_NAME,
    HIDDEN_COLUMNS,
    SPARK_INDEX_NAME_FORMAT,
)
from pyspark.pandas.series import Series, first_series
from pyspark.pandas.spark import functions as SF
from pyspark.pandas.spark.utils import as_nullable_spark_type, force_decimal_precision_scale
from pyspark.pandas.indexes import Index, DatetimeIndex
from pyspark.pandas.indexes.multi import MultiIndex


__all__ = [
@@ -115,6 +118,7 @@ __all__ = [
    "read_sql",
    "read_json",
    "merge",
    "merge_asof",
    "to_numeric",
    "broadcast",
    "read_orc",
@@ -2747,6 +2751,499 @@ def merge(
    )


def merge_asof(
    left: Union[DataFrame, Series],
    right: Union[DataFrame, Series],
    on: Optional[Name] = None,
    left_on: Optional[Name] = None,
    right_on: Optional[Name] = None,
    left_index: bool = False,
    right_index: bool = False,
    by: Optional[Union[Name, List[Name]]] = None,
    left_by: Optional[Union[Name, List[Name]]] = None,
    right_by: Optional[Union[Name, List[Name]]] = None,
    suffixes: Tuple[str, str] = ("_x", "_y"),
    tolerance: Optional[Any] = None,
    allow_exact_matches: bool = True,
    direction: str = "backward",
) -> DataFrame:
    """
    Perform an asof merge.

    This is similar to a left-join except that we match on nearest
    key rather than equal keys.

    For each row in the left DataFrame:

    - A "backward" search selects the last row in the right DataFrame whose
      'on' key is less than or equal to the left's key.

    - A "forward" search selects the first row in the right DataFrame whose
      'on' key is greater than or equal to the left's key.

    - A "nearest" search selects the row in the right DataFrame whose 'on'
      key is closest in absolute distance to the left's key.

    Optionally match on equivalent keys with 'by' before searching with 'on'.

    .. versionadded:: 3.3.0

    Parameters
    ----------
    left : DataFrame or named Series
    right : DataFrame or named Series
    on : label
        Field name to join on. Must be found in both DataFrames.
        The data MUST be ordered. Furthermore this must be a numeric column,
        such as datetimelike, integer, or float. On or left_on/right_on
        must be given.
    left_on : label
        Field name to join on in left DataFrame.
    right_on : label
        Field name to join on in right DataFrame.
    left_index : bool
        Use the index of the left DataFrame as the join key.
    right_index : bool
        Use the index of the right DataFrame as the join key.
    by : column name or list of column names
        Match on these columns before performing merge operation.
    left_by : column name
        Field names to match on in the left DataFrame.
    right_by : column name
        Field names to match on in the right DataFrame.
    suffixes : 2-length sequence (tuple, list, ...)
        Suffix to apply to overlapping column names in the left and right
        side, respectively.
    tolerance : int or Timedelta, optional, default None
        Select asof tolerance within this range; must be compatible
        with the merge index.
    allow_exact_matches : bool, default True

        - If True, allow matching with the same 'on' value
          (i.e. less-than-or-equal-to / greater-than-or-equal-to)
        - If False, don't match the same 'on' value
          (i.e., strictly less-than / strictly greater-than).

    direction : 'backward' (default), 'forward', or 'nearest'
        Whether to search for prior, subsequent, or closest matches.

    Returns
    -------
    merged : DataFrame

    See Also
    --------
    merge : Merge with a database-style join.
    merge_ordered : Merge with optional filling/interpolation.

    Examples
    --------
    >>> left = ps.DataFrame({"a": [1, 5, 10], "left_val": ["a", "b", "c"]})
    >>> left
        a left_val
    0   1        a
    1   5        b
    2  10        c

    >>> right = ps.DataFrame({"a": [1, 2, 3, 6, 7], "right_val": [1, 2, 3, 6, 7]})
    >>> right
       a  right_val
    0  1          1
    1  2          2
    2  3          3
    3  6          6
    4  7          7

    >>> ps.merge_asof(left, right, on="a").sort_values("a").reset_index(drop=True)
        a left_val  right_val
    0   1        a          1
    1   5        b          3
    2  10        c          7

    >>> ps.merge_asof(
    ...     left,
    ...     right,
    ...     on="a",
    ...     allow_exact_matches=False
    ... ).sort_values("a").reset_index(drop=True)
        a left_val  right_val
    0   1        a        NaN
    1   5        b        3.0
    2  10        c        7.0

    >>> ps.merge_asof(
    ...     left,
    ...     right,
    ...     on="a",
    ...     direction="forward"
    ... ).sort_values("a").reset_index(drop=True)
        a left_val  right_val
    0   1        a        1.0
    1   5        b        6.0
    2  10        c        NaN

    >>> ps.merge_asof(
    ...     left,
    ...     right,
    ...     on="a",
    ...     direction="nearest"
    ... ).sort_values("a").reset_index(drop=True)
        a left_val  right_val
    0   1        a          1
    1   5        b          6
    2  10        c          7

    We can use indexed DataFrames as well.

    >>> left = ps.DataFrame({"left_val": ["a", "b", "c"]}, index=[1, 5, 10])
    >>> left
       left_val
    1         a
    5         b
    10        c

    >>> right = ps.DataFrame({"right_val": [1, 2, 3, 6, 7]}, index=[1, 2, 3, 6, 7])
    >>> right
       right_val
    1          1
    2          2
    3          3
    6          6
    7          7

    >>> ps.merge_asof(left, right, left_index=True, right_index=True).sort_index()
       left_val  right_val
    1         a          1
    5         b          3
    10        c          7

    Here is a real-world times-series example

    >>> quotes = ps.DataFrame(
    ...     {
    ...         "time": [
    ...             pd.Timestamp("2016-05-25 13:30:00.023"),
    ...             pd.Timestamp("2016-05-25 13:30:00.023"),
    ...             pd.Timestamp("2016-05-25 13:30:00.030"),
    ...             pd.Timestamp("2016-05-25 13:30:00.041"),
    ...             pd.Timestamp("2016-05-25 13:30:00.048"),
    ...             pd.Timestamp("2016-05-25 13:30:00.049"),
    ...             pd.Timestamp("2016-05-25 13:30:00.072"),
    ...             pd.Timestamp("2016-05-25 13:30:00.075")
    ...         ],
    ...         "ticker": [
    ...             "GOOG",
    ...             "MSFT",
    ...             "MSFT",
    ...             "MSFT",
    ...             "GOOG",
    ...             "AAPL",
    ...             "GOOG",
    ...             "MSFT"
    ...         ],
    ...         "bid": [720.50, 51.95, 51.97, 51.99, 720.50, 97.99, 720.50, 52.01],
    ...         "ask": [720.93, 51.96, 51.98, 52.00, 720.93, 98.01, 720.88, 52.03]
    ...     }
    ... )
    >>> quotes
                         time ticker     bid     ask
    0 2016-05-25 13:30:00.023   GOOG  720.50  720.93
    1 2016-05-25 13:30:00.023   MSFT   51.95   51.96
    2 2016-05-25 13:30:00.030   MSFT   51.97   51.98
    3 2016-05-25 13:30:00.041   MSFT   51.99   52.00
    4 2016-05-25 13:30:00.048   GOOG  720.50  720.93
    5 2016-05-25 13:30:00.049   AAPL   97.99   98.01
    6 2016-05-25 13:30:00.072   GOOG  720.50  720.88
    7 2016-05-25 13:30:00.075   MSFT   52.01   52.03

    >>> trades = ps.DataFrame(
    ...     {
    ...         "time": [
    ...             pd.Timestamp("2016-05-25 13:30:00.023"),
    ...             pd.Timestamp("2016-05-25 13:30:00.038"),
    ...             pd.Timestamp("2016-05-25 13:30:00.048"),
    ...             pd.Timestamp("2016-05-25 13:30:00.048"),
    ...             pd.Timestamp("2016-05-25 13:30:00.048")
    ...         ],
    ...         "ticker": ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"],
    ...         "price": [51.95, 51.95, 720.77, 720.92, 98.0],
    ...         "quantity": [75, 155, 100, 100, 100]
    ...     }
    ... )
    >>> trades
                         time ticker   price  quantity
    0 2016-05-25 13:30:00.023   MSFT   51.95        75
    1 2016-05-25 13:30:00.038   MSFT   51.95       155
    2 2016-05-25 13:30:00.048   GOOG  720.77       100
    3 2016-05-25 13:30:00.048   GOOG  720.92       100
    4 2016-05-25 13:30:00.048   AAPL   98.00       100

    By default we are taking the asof of the quotes

    >>> ps.merge_asof(
    ...     trades, quotes, on="time", by="ticker"
    ... ).sort_values(["time", "ticker", "price"]).reset_index(drop=True)
                         time ticker   price  quantity     bid     ask
    0 2016-05-25 13:30:00.023   MSFT   51.95        75   51.95   51.96
    1 2016-05-25 13:30:00.038   MSFT   51.95       155   51.97   51.98
    2 2016-05-25 13:30:00.048   AAPL   98.00       100     NaN     NaN
    3 2016-05-25 13:30:00.048   GOOG  720.77       100  720.50  720.93
    4 2016-05-25 13:30:00.048   GOOG  720.92       100  720.50  720.93

    We only asof within 2ms between the quote time and the trade time

    >>> ps.merge_asof(
    ...     trades,
    ...     quotes,
    ...     on="time",
    ...     by="ticker",
    ...     tolerance=F.expr("INTERVAL 2 MILLISECONDS")  # pd.Timedelta("2ms")
    ... ).sort_values(["time", "ticker", "price"]).reset_index(drop=True)
                         time ticker   price  quantity     bid     ask
    0 2016-05-25 13:30:00.023   MSFT   51.95        75   51.95   51.96
    1 2016-05-25 13:30:00.038   MSFT   51.95       155     NaN     NaN
    2 2016-05-25 13:30:00.048   AAPL   98.00       100     NaN     NaN
    3 2016-05-25 13:30:00.048   GOOG  720.77       100  720.50  720.93
    4 2016-05-25 13:30:00.048   GOOG  720.92       100  720.50  720.93

    We only asof within 10ms between the quote time and the trade time
    and we exclude exact matches on time. However *prior* data will
    propagate forward

    >>> ps.merge_asof(
    ...     trades,
    ...     quotes,
    ...     on="time",
    ...     by="ticker",
    ...     tolerance=F.expr("INTERVAL 10 MILLISECONDS"),  # pd.Timedelta("10ms")
    ...     allow_exact_matches=False
    ... ).sort_values(["time", "ticker", "price"]).reset_index(drop=True)
                         time ticker   price  quantity     bid     ask
    0 2016-05-25 13:30:00.023   MSFT   51.95        75     NaN     NaN
    1 2016-05-25 13:30:00.038   MSFT   51.95       155   51.97   51.98
    2 2016-05-25 13:30:00.048   AAPL   98.00       100     NaN     NaN
    3 2016-05-25 13:30:00.048   GOOG  720.77       100     NaN     NaN
    4 2016-05-25 13:30:00.048   GOOG  720.92       100     NaN     NaN
    """

    def to_list(os: Optional[Union[Name, List[Name]]]) -> List[Label]:
        if os is None:
            return []
        elif is_name_like_tuple(os):
            return [os]  # type: ignore
        elif is_name_like_value(os):
            return [(os,)]
        else:
            return [o if is_name_like_tuple(o) else (o,) for o in os]

    if isinstance(left, Series):
        left = left.to_frame()
    if isinstance(right, Series):
        right = right.to_frame()

    if on:
        if left_on or right_on:
            raise ValueError(
                'Can only pass argument "on" OR "left_on" and "right_on", '
                "not a combination of both."
            )
        left_as_of_names = list(map(left._internal.spark_column_name_for, to_list(on)))
        right_as_of_names = list(map(right._internal.spark_column_name_for, to_list(on)))
    else:
        if left_index:
            if isinstance(left.index, MultiIndex):
                raise ValueError("left can only have one index")
            left_as_of_names = left._internal.index_spark_column_names
        else:
            left_as_of_names = list(map(left._internal.spark_column_name_for, to_list(left_on)))
        if right_index:
            if isinstance(right.index, MultiIndex):
                raise ValueError("right can only have one index")
            right_as_of_names = right._internal.index_spark_column_names
        else:
            right_as_of_names = list(map(right._internal.spark_column_name_for, to_list(right_on)))

    if left_as_of_names and not right_as_of_names:
        raise ValueError("Must pass right_on or right_index=True")
    if right_as_of_names and not left_as_of_names:
        raise ValueError("Must pass left_on or left_index=True")
    if not left_as_of_names and not right_as_of_names:
        common = list(left.columns.intersection(right.columns))
        if len(common) == 0:
            raise ValueError(
                "No common columns to perform merge on. Merge options: "
                "left_on=None, right_on=None, left_index=False, right_index=False"
            )
        left_as_of_names = list(map(left._internal.spark_column_name_for, to_list(common)))
        right_as_of_names = list(map(right._internal.spark_column_name_for, to_list(common)))

    if len(left_as_of_names) != 1:
        raise ValueError("can only asof on a key for left")
    if len(right_as_of_names) != 1:
        raise ValueError("can only asof on a key for right")

    if by:
        if left_by or right_by:
            raise ValueError('Can only pass argument "by" OR "left_by" and "right_by".')
        left_join_on_names = list(map(left._internal.spark_column_name_for, to_list(by)))
        right_join_on_names = list(map(right._internal.spark_column_name_for, to_list(by)))
    else:
        left_join_on_names = list(map(left._internal.spark_column_name_for, to_list(left_by)))
        right_join_on_names = list(map(right._internal.spark_column_name_for, to_list(right_by)))

    if left_join_on_names and not right_join_on_names:
        raise ValueError("missing right_by")
    if right_join_on_names and not left_join_on_names:
        raise ValueError("missing left_by")
    if len(left_join_on_names) != len(right_join_on_names):
        raise ValueError("left_by and right_by must be same length")

    # We should distinguish the name to avoid ambiguous column name after merging.
    right_prefix = "__right_"
    right_as_of_names = [right_prefix + right_as_of_name for right_as_of_name in right_as_of_names]
    right_join_on_names = [
        right_prefix + right_join_on_name for right_join_on_name in right_join_on_names
    ]

    left_as_of_name = left_as_of_names[0]
    right_as_of_name = right_as_of_names[0]

    def resolve(internal: InternalFrame, side: str) -> InternalFrame:
        rename = lambda col: "__{}_{}".format(side, col)
        internal = internal.resolved_copy
        sdf = internal.spark_frame
        sdf = sdf.select(
            *[
                scol_for(sdf, col).alias(rename(col))
                for col in sdf.columns
                if col not in HIDDEN_COLUMNS
            ],
            *HIDDEN_COLUMNS
        )
        return internal.copy(
            spark_frame=sdf,
            index_spark_columns=[
                scol_for(sdf, rename(col)) for col in internal.index_spark_column_names
            ],
            index_fields=[field.copy(name=rename(field.name)) for field in internal.index_fields],
            data_spark_columns=[
                scol_for(sdf, rename(col)) for col in internal.data_spark_column_names
            ],
            data_fields=[field.copy(name=rename(field.name)) for field in internal.data_fields],
        )

    left_internal = left._internal.resolved_copy
    right_internal = resolve(right._internal, "right")

    left_table = left_internal.spark_frame.alias("left_table")
    right_table = right_internal.spark_frame.alias("right_table")

    left_as_of_column = scol_for(left_table, left_as_of_name)
    right_as_of_column = scol_for(right_table, right_as_of_name)

    if left_join_on_names:
        left_join_on_columns = [scol_for(left_table, label) for label in left_join_on_names]
        right_join_on_columns = [scol_for(right_table, label) for label in right_join_on_names]
        on = reduce(
            lambda l, r: l & r,
            [l == r for l, r in zip(left_join_on_columns, right_join_on_columns)],
        )
    else:
        on = None

    if tolerance is not None and not isinstance(tolerance, Column):
        tolerance = SF.lit(tolerance)

    as_of_joined_table = left_table._joinAsOf(
        right_table,
        leftAsOfColumn=left_as_of_column,
        rightAsOfColumn=right_as_of_column,
        on=on,
        how="left",
        tolerance=tolerance,
        allowExactMatches=allow_exact_matches,
        direction=direction,
    )

    # Unpack suffixes tuple for convenience
    left_suffix = suffixes[0]
    right_suffix = suffixes[1]

    # Append suffixes to columns with the same name to avoid conflicts later
    duplicate_columns = set(left_internal.column_labels) & set(right_internal.column_labels)

    exprs = []
    data_columns = []
    column_labels = []

    left_scol_for = lambda label: scol_for(
        as_of_joined_table, left_internal.spark_column_name_for(label)
    )
    right_scol_for = lambda label: scol_for(
        as_of_joined_table, right_internal.spark_column_name_for(label)
    )

    for label in left_internal.column_labels:
        col = left_internal.spark_column_name_for(label)
        scol = left_scol_for(label)
        if label in duplicate_columns:
            spark_column_name = left_internal.spark_column_name_for(label)
            if spark_column_name in (left_as_of_names + left_join_on_names) and (
                (right_prefix + spark_column_name) in (right_as_of_names + right_join_on_names)
            ):
                pass
            else:
                col = col + left_suffix
                scol = scol.alias(col)
                label = tuple([str(label[0]) + left_suffix] + list(label[1:]))
        exprs.append(scol)
        data_columns.append(col)
        column_labels.append(label)
    for label in right_internal.column_labels:
        # recover `right_prefix` here.
        col = right_internal.spark_column_name_for(label)[len(right_prefix) :]
        scol = right_scol_for(label).alias(col)
        if label in duplicate_columns:
            spark_column_name = left_internal.spark_column_name_for(label)
            if spark_column_name in left_as_of_names + left_join_on_names and (
                (right_prefix + spark_column_name) in right_as_of_names + right_join_on_names
            ):
                continue
            else:
                col = col + right_suffix
                scol = scol.alias(col)
                label = tuple([str(label[0]) + right_suffix] + list(label[1:]))
        exprs.append(scol)
        data_columns.append(col)
        column_labels.append(label)

    # Retain indices if they are used for joining
    if left_index or right_index:
        index_spark_column_names = [
            SPARK_INDEX_NAME_FORMAT(i) for i in range(len(left_internal.index_spark_column_names))
        ]
        left_index_scols = [
            scol.alias(name)
            for scol, name in zip(left_internal.index_spark_columns, index_spark_column_names)
        ]
        exprs.extend(left_index_scols)
        index_names = left_internal.index_names
    else:
        index_spark_column_names = []
        index_names = []

    selected_columns = as_of_joined_table.select(*exprs)

    internal = InternalFrame(
        spark_frame=selected_columns,
        index_spark_columns=[scol_for(selected_columns, col) for col in index_spark_column_names],
        index_names=index_names,
        column_labels=column_labels,
        data_spark_columns=[scol_for(selected_columns, col) for col in data_columns],
    )
    return DataFrame(internal)


@no_type_check
def to_numeric(arg, errors="raise"):
    """
@@ -24,6 +24,7 @@ import pandas as pd

from pyspark import pandas as ps
from pyspark.pandas.utils import name_like_string
from pyspark.sql.utils import AnalysisException
from pyspark.testing.pandasutils import PandasOnSparkTestCase
@@ -283,6 +284,143 @@ class ReshapeTest(PandasOnSparkTestCase):
            pd.get_dummies(pdf, columns=("x", 1), dtype=np.int8).rename(columns=name_like_string),
        )

    def test_merge_asof(self):
        pdf_left = pd.DataFrame(
            {"a": [1, 5, 10], "b": ["x", "y", "z"], "left_val": ["a", "b", "c"]}, index=[10, 20, 30]
        )
        pdf_right = pd.DataFrame(
            {"a": [1, 2, 3, 6, 7], "b": ["v", "w", "x", "y", "z"], "right_val": [1, 2, 3, 6, 7]},
            index=[100, 101, 102, 103, 104],
        )
        psdf_left = ps.from_pandas(pdf_left)
        psdf_right = ps.from_pandas(pdf_right)

        self.assert_eq(
            pd.merge_asof(pdf_left, pdf_right, on="a").sort_values("a").reset_index(drop=True),
            ps.merge_asof(psdf_left, psdf_right, on="a").sort_values("a").reset_index(drop=True),
        )
        self.assert_eq(
            (
                pd.merge_asof(pdf_left, pdf_right, left_on="a", right_on="a")
                .sort_values("a")
                .reset_index(drop=True)
            ),
            (
                ps.merge_asof(psdf_left, psdf_right, left_on="a", right_on="a")
                .sort_values("a")
                .reset_index(drop=True)
            ),
        )
        if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
            self.assert_eq(
                pd.merge_asof(
                    pdf_left.set_index("a"), pdf_right, left_index=True, right_on="a"
                ).sort_index(),
                ps.merge_asof(
                    psdf_left.set_index("a"), psdf_right, left_index=True, right_on="a"
                ).sort_index(),
            )
        else:
            expected = pd.DataFrame(
                {
                    "b_x": ["x", "y", "z"],
                    "left_val": ["a", "b", "c"],
                    "a": [1, 3, 7],
                    "b_y": ["v", "x", "z"],
                    "right_val": [1, 3, 7],
                },
                index=pd.Index([1, 5, 10], name="a"),
            )
            self.assert_eq(
                expected,
                ps.merge_asof(
                    psdf_left.set_index("a"), psdf_right, left_index=True, right_on="a"
                ).sort_index(),
            )
        self.assert_eq(
            pd.merge_asof(
                pdf_left, pdf_right.set_index("a"), left_on="a", right_index=True
            ).sort_index(),
            ps.merge_asof(
                psdf_left, psdf_right.set_index("a"), left_on="a", right_index=True
            ).sort_index(),
        )
        self.assert_eq(
            pd.merge_asof(
                pdf_left.set_index("a"), pdf_right.set_index("a"), left_index=True, right_index=True
            ).sort_index(),
            ps.merge_asof(
                psdf_left.set_index("a"),
                psdf_right.set_index("a"),
                left_index=True,
                right_index=True,
            ).sort_index(),
        )
        self.assert_eq(
            (
                pd.merge_asof(pdf_left, pdf_right, on="a", by="b")
                .sort_values("a")
                .reset_index(drop=True)
            ),
            (
                ps.merge_asof(psdf_left, psdf_right, on="a", by="b")
                .sort_values("a")
                .reset_index(drop=True)
            ),
        )
        self.assert_eq(
            (
                pd.merge_asof(pdf_left, pdf_right, on="a", tolerance=1)
                .sort_values("a")
                .reset_index(drop=True)
            ),
            (
                ps.merge_asof(psdf_left, psdf_right, on="a", tolerance=1)
                .sort_values("a")
                .reset_index(drop=True)
            ),
        )
        self.assert_eq(
            (
                pd.merge_asof(pdf_left, pdf_right, on="a", allow_exact_matches=False)
                .sort_values("a")
                .reset_index(drop=True)
            ),
            (
                ps.merge_asof(psdf_left, psdf_right, on="a", allow_exact_matches=False)
                .sort_values("a")
                .reset_index(drop=True)
            ),
        )
        self.assert_eq(
            (
                pd.merge_asof(pdf_left, pdf_right, on="a", direction="forward")
                .sort_values("a")
                .reset_index(drop=True)
            ),
            (
                ps.merge_asof(psdf_left, psdf_right, on="a", direction="forward")
                .sort_values("a")
                .reset_index(drop=True)
            ),
        )
        self.assert_eq(
            (
                pd.merge_asof(pdf_left, pdf_right, on="a", direction="nearest")
                .sort_values("a")
                .reset_index(drop=True)
            ),
            (
                ps.merge_asof(psdf_left, psdf_right, on="a", direction="nearest")
                .sort_values("a")
                .reset_index(drop=True)
            ),
        )

        self.assertRaises(
            AnalysisException, lambda: ps.merge_asof(psdf_left, psdf_right, on="a", tolerance=-1)
        )


if __name__ == "__main__":
    import unittest
@@ -1357,6 +1357,123 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
        jdf = self._jdf.join(other._jdf, on, how)
        return DataFrame(jdf, self.sql_ctx)

    # TODO(SPARK-22947): Fix the DataFrame API.
    def _joinAsOf(
        self,
        other,
        leftAsOfColumn,
        rightAsOfColumn,
        on=None,
        how=None,
        *,
        tolerance=None,
        allowExactMatches=True,
        direction="backward",
    ):
        """
        Perform an as-of join.

        This is similar to a left-join except that we match on nearest
        key rather than equal keys.

        .. versionadded:: 3.3.0

        Parameters
        ----------
        other : :class:`DataFrame`
            Right side of the join
        leftAsOfColumn : str or :class:`Column`
            a string for the as-of join column name, or a Column
        rightAsOfColumn : str or :class:`Column`
            a string for the as-of join column name, or a Column
        on : str, list or :class:`Column`, optional
            a string for the join column name, a list of column names,
            a join expression (Column), or a list of Columns.
            If `on` is a string or a list of strings indicating the name of the join column(s),
            the column(s) must exist on both sides, and this performs an equi-join.
        how : str, optional
            default ``inner``. Must be one of: ``inner`` and ``left``.
        tolerance : :class:`Column`, optional
            an asof tolerance within this range; must be compatible
            with the merge index.
        allowExactMatches : bool, optional
            default ``True``.
        direction : str, optional
            default ``backward``. Must be one of: ``backward``, ``forward``, and ``nearest``.

        Examples
        --------
        The following performs an as-of join between ``left`` and ``right``.

        >>> left = spark.createDataFrame([(1, "a"), (5, "b"), (10, "c")], ["a", "left_val"])
        >>> right = spark.createDataFrame([(1, 1), (2, 2), (3, 3), (6, 6), (7, 7)],
        ...                               ["a", "right_val"])
        >>> left._joinAsOf(
        ...     right, leftAsOfColumn="a", rightAsOfColumn="a"
        ... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
        [Row(a=1, left_val='a', right_val=1),
         Row(a=5, left_val='b', right_val=3),
         Row(a=10, left_val='c', right_val=7)]

        >>> from pyspark.sql import functions as F
        >>> left._joinAsOf(
        ...     right, leftAsOfColumn="a", rightAsOfColumn="a", tolerance=F.lit(1)
        ... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
        [Row(a=1, left_val='a', right_val=1)]

        >>> left._joinAsOf(
        ...     right, leftAsOfColumn="a", rightAsOfColumn="a", how="left", tolerance=F.lit(1)
        ... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
        [Row(a=1, left_val='a', right_val=1),
         Row(a=5, left_val='b', right_val=None),
         Row(a=10, left_val='c', right_val=None)]

        >>> left._joinAsOf(
        ...     right, leftAsOfColumn="a", rightAsOfColumn="a", allowExactMatches=False
        ... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
        [Row(a=5, left_val='b', right_val=3),
         Row(a=10, left_val='c', right_val=7)]

        >>> left._joinAsOf(
        ...     right, leftAsOfColumn="a", rightAsOfColumn="a", direction="forward"
        ... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
        [Row(a=1, left_val='a', right_val=1),
         Row(a=5, left_val='b', right_val=6)]
        """
        if isinstance(leftAsOfColumn, str):
            leftAsOfColumn = self[leftAsOfColumn]
        left_as_of_jcol = leftAsOfColumn._jc
        if isinstance(rightAsOfColumn, str):
            rightAsOfColumn = other[rightAsOfColumn]
        right_as_of_jcol = rightAsOfColumn._jc

        if on is not None and not isinstance(on, list):
            on = [on]

        if on is not None:
            if isinstance(on[0], str):
                on = self._jseq(on)
            else:
                assert isinstance(on[0], Column), "on should be Column or list of Column"
                on = reduce(lambda x, y: x.__and__(y), on)
                on = on._jc

        if how is None:
            how = "inner"
        assert isinstance(how, str), "how should be a string"

        if tolerance is not None:
            assert isinstance(tolerance, Column), "tolerance should be Column"
            tolerance = tolerance._jc

        jdf = self._jdf.joinAsOf(
            other._jdf,
            left_as_of_jcol, right_as_of_jcol,
            on,
            how, tolerance, allowExactMatches, direction
        )
        return DataFrame(jdf, self.sql_ctx)

    def sortWithinPartitions(self, *cols, **kwargs):
        """Returns a new :class:`DataFrame` with each partition sorted by the specified column(s).
@@ -249,6 +249,20 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
              s"join condition '${condition.sql}' " +
                s"of type ${condition.dataType.catalogString} is not a boolean.")

          case j @ AsOfJoin(_, _, _, Some(condition), _, _, _)
              if condition.dataType != BooleanType =>
            failAnalysis(
              s"join condition '${condition.sql}' " +
                s"of type ${condition.dataType.catalogString} is not a boolean.")

          case j @ AsOfJoin(_, _, _, _, _, _, Some(toleranceAssertion)) =>
            if (!toleranceAssertion.foldable) {
              failAnalysis("Input argument tolerance must be a constant.")
            }
            if (!toleranceAssertion.eval().asInstanceOf[Boolean]) {
              failAnalysis("Input argument tolerance must be non-negative.")
            }

          case a @ Aggregate(groupingExprs, aggregateExprs, child) =>
            def isAggregateExpression(expr: Expression): Boolean = {
              expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr)

@@ -506,6 +520,15 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
                 |Conflicting attributes: ${conflictingAttributes.mkString(",")}
               """.stripMargin)

          case j: AsOfJoin if !j.duplicateResolved =>
            val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet)
            failAnalysis(
              s"""
                 |Failure when resolving conflicting references in AsOfJoin:
                 |$plan
                 |Conflicting attributes: ${conflictingAttributes.mkString(",")}
                 |""".stripMargin)

          // TODO: although map type is not orderable, technically map type should be able to be
          // used in equality comparison, remove this type check once we support it.
          case o if mapColumnInSetOperation(o).isDefined =>
@@ -41,7 +41,8 @@ case class ReferenceEqualPlanWrapper(plan: LogicalPlan) {
object DeduplicateRelations extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = {
    renewDuplicatedRelations(mutable.HashSet.empty, plan)._1.resolveOperatorsUpWithPruning(
-     _.containsAnyPattern(JOIN, LATERAL_JOIN, INTERSECT, EXCEPT, UNION, COMMAND), ruleId) {
+     _.containsAnyPattern(JOIN, LATERAL_JOIN, AS_OF_JOIN, INTERSECT, EXCEPT, UNION, COMMAND),
+     ruleId) {
      case p: LogicalPlan if !p.childrenResolved => p
      // To resolve duplicate expression IDs for Join.
      case j @ Join(left, right, _, _, _) if !j.duplicateResolved =>

@@ -49,6 +50,9 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
      // Resolve duplicate output for LateralJoin.
      case j @ LateralJoin(left, right, _, _) if right.resolved && !j.duplicateResolved =>
        j.copy(right = right.withNewPlan(dedupRight(left, right.plan)))
+     // Resolve duplicate output for AsOfJoin.
+     case j @ AsOfJoin(left, right, _, _, _, _, _) if !j.duplicateResolved =>
+       j.copy(right = dedupRight(left, right))
      // intersect/except will be rewritten to join at the beginning of optimizer. Here we need to
      // deduplicate the right side plan, so that we won't produce an invalid self-join later.
      case i @ Intersect(left, right, _) if !i.duplicateResolved =>
@@ -159,7 +159,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
        PullOutGroupingExpressions,
        ComputeCurrentTime,
        ReplaceCurrentLike(catalogManager),
-       SpecialDatetimeValues) ::
+       SpecialDatetimeValues,
+       RewriteAsOfJoin) ::
    //////////////////////////////////////////////////////////////////////////////////////////
    // Optimizer rules start here
    //////////////////////////////////////////////////////////////////////////////////////////

@@ -282,7 +283,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
      RewritePredicateSubquery.ruleName ::
      NormalizeFloatingNumbers.ruleName ::
      ReplaceUpdateFieldsExpression.ruleName ::
-     PullOutGroupingExpressions.ruleName :: Nil
+     PullOutGroupingExpressions.ruleName ::
+     RewriteAsOfJoin.ruleName :: Nil

  /**
   * Optimize all the subqueries inside expression.
@@ -0,0 +1,89 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreePattern._

/**
 * Replaces logical [[AsOfJoin]] operator using a combination of Join and Aggregate operators.
 *
 * Input Pseudo-Query:
 * {{{
 *   SELECT * FROM left ASOF JOIN right ON (condition, as_of on(left.t, right.t), tolerance)
 * }}}
 *
 * Rewritten Query:
 * {{{
 *   SELECT left.*, __right__.*
 *   FROM (
 *        SELECT
 *             left.*,
 *             (
 *                  SELECT MIN_BY(STRUCT(right.*), left.t - right.t) AS __nearest_right__
 *                  FROM right
 *                  WHERE condition AND left.t >= right.t AND right.t >= left.t - tolerance
 *             ) as __right__
 *        FROM left
 *   )
 *   WHERE __right__ IS NOT NULL
 * }}}
 */
object RewriteAsOfJoin extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
    _.containsPattern(AS_OF_JOIN), ruleId) {
    case AsOfJoin(left, right, asOfCondition, condition, joinType, orderExpression, _) =>
      val conditionWithOuterReference =
        condition.map(And(_, asOfCondition)).getOrElse(asOfCondition).transformUp {
          case a: AttributeReference if left.outputSet.contains(a) =>
            OuterReference(a)
        }
      val filtered = Filter(conditionWithOuterReference, right)

      val orderExpressionWithOuterReference = orderExpression.transformUp {
        case a: AttributeReference if left.outputSet.contains(a) =>
          OuterReference(a)
      }
      val rightStruct = CreateStruct(right.output)
      val nearestRight = MinBy(rightStruct, orderExpressionWithOuterReference)
        .toAggregateExpression()
      val aggExpr = Alias(nearestRight, "__nearest_right__")()
      val aggregate = Aggregate(Seq.empty, Seq(aggExpr), filtered)

      val projectWithScalarSubquery = Project(
        left.output :+ Alias(ScalarSubquery(aggregate, left.output), "__right__")(),
        left)

      val filterRight = joinType match {
        case LeftOuter => projectWithScalarSubquery
        case _ =>
          Filter(IsNotNull(projectWithScalarSubquery.output.last), projectWithScalarSubquery)
      }

      Project(
        left.output ++ right.output.zipWithIndex.map {
          case (out, idx) =>
            Alias(GetStructField(filterRight.output.last, idx), out.name)(exprId = out.exprId)
        },
        filterRight)
  }
}
@@ -121,3 +121,24 @@ object LeftSemiOrAnti {
    case _ => None
  }
}

object AsOfJoinDirection {

  def apply(direction: String): AsOfJoinDirection = {
    direction.toLowerCase(Locale.ROOT) match {
      case "forward" => Forward
      case "backward" => Backward
      case "nearest" => Nearest
      case _ =>
        val supported = Seq("forward", "backward", "nearest")
        throw new IllegalArgumentException(s"Unsupported as-of join direction '$direction'. " +
          "Supported as-of join directions include: " + supported.mkString("'", "', '", "'") + ".")
    }
  }
}

sealed abstract class AsOfJoinDirection

case object Forward extends AsOfJoinDirection
case object Backward extends AsOfJoinDirection
case object Nearest extends AsOfJoinDirection
@@ -1597,3 +1597,119 @@ case class LateralJoin(
    copy(left = newChild)
  }
}

/**
 * A logical plan for as-of join.
 */
case class AsOfJoin(
    left: LogicalPlan,
    right: LogicalPlan,
    asOfCondition: Expression,
    condition: Option[Expression],
    joinType: JoinType,
    orderExpression: Expression,
    toleranceAssertion: Option[Expression]) extends BinaryNode {

  require(Seq(Inner, LeftOuter).contains(joinType),
    s"Unsupported as-of join type $joinType")

  override protected def stringArgs: Iterator[Any] = super.stringArgs.take(5)

  override def output: Seq[Attribute] = {
    joinType match {
      case LeftOuter =>
        left.output ++ right.output.map(_.withNullability(true))
      case _ =>
        left.output ++ right.output
    }
  }

  def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty

  override lazy val resolved: Boolean = {
    childrenResolved &&
      expressions.forall(_.resolved) &&
      duplicateResolved &&
      asOfCondition.dataType == BooleanType &&
      condition.forall(_.dataType == BooleanType) &&
      toleranceAssertion.forall { assertion =>
        assertion.foldable && assertion.eval().asInstanceOf[Boolean]
      }
  }

  final override val nodePatterns: Seq[TreePattern] = Seq(AS_OF_JOIN)

  override protected def withNewChildrenInternal(
      newLeft: LogicalPlan, newRight: LogicalPlan): AsOfJoin = {
    copy(left = newLeft, right = newRight)
  }
}

object AsOfJoin {

  def apply(
      left: LogicalPlan,
      right: LogicalPlan,
      leftAsOf: Expression,
      rightAsOf: Expression,
      condition: Option[Expression],
      joinType: JoinType,
      tolerance: Option[Expression],
      allowExactMatches: Boolean,
      direction: AsOfJoinDirection): AsOfJoin = {
    val asOfCond = makeAsOfCond(leftAsOf, rightAsOf, tolerance, allowExactMatches, direction)
    val orderingExpr = makeOrderingExpr(leftAsOf, rightAsOf, direction)
    AsOfJoin(left, right, asOfCond, condition, joinType,
      orderingExpr, tolerance.map(t => GreaterThanOrEqual(t, Literal.default(t.dataType))))
  }

  private def makeAsOfCond(
      leftAsOf: Expression,
      rightAsOf: Expression,
      tolerance: Option[Expression],
      allowExactMatches: Boolean,
      direction: AsOfJoinDirection): Expression = {
    val base = (allowExactMatches, direction) match {
      case (true, Backward) => GreaterThanOrEqual(leftAsOf, rightAsOf)
      case (false, Backward) => GreaterThan(leftAsOf, rightAsOf)
      case (true, Forward) => LessThanOrEqual(leftAsOf, rightAsOf)
      case (false, Forward) => LessThan(leftAsOf, rightAsOf)
      case (true, Nearest) => Literal.TrueLiteral
      case (false, Nearest) => Not(EqualTo(leftAsOf, rightAsOf))
    }
    tolerance match {
      case Some(tolerance) =>
        (allowExactMatches, direction) match {
          case (true, Backward) =>
            And(base, GreaterThanOrEqual(rightAsOf, Subtract(leftAsOf, tolerance)))
          case (false, Backward) =>
            And(base, GreaterThan(rightAsOf, Subtract(leftAsOf, tolerance)))
          case (true, Forward) =>
            And(base, LessThanOrEqual(rightAsOf, Add(leftAsOf, tolerance)))
          case (false, Forward) =>
            And(base, LessThan(rightAsOf, Add(leftAsOf, tolerance)))
          case (true, Nearest) =>
            And(GreaterThanOrEqual(rightAsOf, Subtract(leftAsOf, tolerance)),
              LessThanOrEqual(rightAsOf, Add(leftAsOf, tolerance)))
          case (false, Nearest) =>
            And(base,
              And(GreaterThan(rightAsOf, Subtract(leftAsOf, tolerance)),
                LessThan(rightAsOf, Add(leftAsOf, tolerance))))
        }
      case None => base
    }
  }

  private def makeOrderingExpr(
      leftAsOf: Expression,
      rightAsOf: Expression,
      direction: AsOfJoinDirection): Expression = {
    direction match {
      case Backward => Subtract(leftAsOf, rightAsOf)
      case Forward => Subtract(rightAsOf, leftAsOf)
      case Nearest =>
        If(GreaterThan(leftAsOf, rightAsOf),
          Subtract(leftAsOf, rightAsOf), Subtract(rightAsOf, leftAsOf))
    }
  }
}
@@ -144,6 +144,7 @@ object RuleIdCollection {
      "org.apache.spark.sql.catalyst.optimizer.ReplaceIntersectWithSemiJoin" ::
      "org.apache.spark.sql.catalyst.optimizer.RewriteExceptAll" ::
      "org.apache.spark.sql.catalyst.optimizer.RewriteIntersectAll" ::
      "org.apache.spark.sql.catalyst.optimizer.RewriteAsOfJoin" ::
      "org.apache.spark.sql.catalyst.optimizer.SimplifyBinaryComparison" ::
      "org.apache.spark.sql.catalyst.optimizer.SimplifyCaseConversionExpressions" ::
      "org.apache.spark.sql.catalyst.optimizer.SimplifyCasts" ::
@@ -91,6 +91,7 @@ object TreePattern extends Enumeration {

  // Logical plan patterns (alphabetically ordered)
  val AGGREGATE: Value = Value
  val AS_OF_JOIN: Value = Value
  val COMMAND: Value = Value
  val CTE: Value = Value
  val DISTINCT_LIKE: Value = Value
@ -0,0 +1,289 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.optimizer
|
||||
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.catalyst.dsl.plans._
|
||||
import org.apache.spark.sql.catalyst.expressions.{CreateStruct, GetStructField, If, OuterReference, ScalarSubquery}
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.MinBy
|
||||
import org.apache.spark.sql.catalyst.plans.{AsOfJoinDirection, Inner, LeftOuter, PlanTest}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{AsOfJoin, LocalRelation}
|
||||
|
||||
class RewriteAsOfJoinSuite extends PlanTest {
|
||||
|
||||
test("simple") {
|
||||
val left = LocalRelation('a.int, 'b.int, 'c.int)
|
||||
val right = LocalRelation('a.int, 'b.int, 'd.int)
|
||||
val query = AsOfJoin(left, right, left.output(0), right.output(0), None, Inner,
|
||||
tolerance = None, allowExactMatches = true, direction = AsOfJoinDirection("backward"))
|
||||
|
||||
val rewritten = RewriteAsOfJoin(query.analyze)
|
||||
|
||||
val filter = OuterReference(left.output(0)) >= right.output(0)
|
||||
val rightStruct = CreateStruct(right.output)
|
||||
val orderExpression = OuterReference(left.output(0)) - right.output(0)
|
||||
val nearestRight = MinBy(rightStruct, orderExpression)
|
||||
.toAggregateExpression().as("__nearest_right__")
|
||||
|
||||
val scalarSubquery = left.select(
|
||||
left.output :+ ScalarSubquery(
|
||||
right.where(filter).groupBy()(nearestRight),
|
||||
left.output).as("__right__"): _*)
|
||||
val correctAnswer = scalarSubquery
|
||||
.where(scalarSubquery.output.last.isNotNull)
|
||||
.select(left.output :+
|
||||
GetStructField(scalarSubquery.output.last, 0).as("a") :+
|
||||
GetStructField(scalarSubquery.output.last, 1).as("b") :+
|
||||
GetStructField(scalarSubquery.output.last, 2).as("d"): _*)
|
||||
|
||||
comparePlans(rewritten, correctAnswer, checkAnalysis = false)
|
||||
}
|
||||
|
||||
test("condition") {
|
||||
val left = LocalRelation('a.int, 'b.int, 'c.int)
|
||||
val right = LocalRelation('a.int, 'b.int, 'd.int)
|
||||
val query = AsOfJoin(left, right, left.output(0), right.output(0),
|
||||
Some(left.output(1) === right.output(1)), Inner,
|
||||
tolerance = None, allowExactMatches = true, direction = AsOfJoinDirection("backward"))
|
||||
|
||||
val rewritten = RewriteAsOfJoin(query.analyze)
|
||||
|
||||
val filter = OuterReference(left.output(1)) === right.output(1) &&
|
||||
OuterReference(left.output(0)) >= right.output(0)
|
||||
val rightStruct = CreateStruct(right.output)
|
||||
val orderExpression = OuterReference(left.output(0)) - right.output(0)
|
||||
val nearestRight = MinBy(rightStruct, orderExpression)
|
||||
.toAggregateExpression().as("__nearest_right__")
|
||||
|
||||
val scalarSubquery = left.select(
|
||||
left.output :+ ScalarSubquery(
|
||||
right.where(filter).groupBy()(nearestRight),
|
||||
left.output).as("__right__"): _*)
|
||||
val correctAnswer = scalarSubquery
|
||||
.where(scalarSubquery.output.last.isNotNull)
|
||||
.select(left.output :+
|
||||
GetStructField(scalarSubquery.output.last, 0).as("a") :+
|
||||
GetStructField(scalarSubquery.output.last, 1).as("b") :+
|
||||
GetStructField(scalarSubquery.output.last, 2).as("d"): _*)
|
||||
|
||||
comparePlans(rewritten, correctAnswer, checkAnalysis = false)
|
||||
}
|
||||
|
||||
test("left outer") {
|
||||
val left = LocalRelation('a.int, 'b.int, 'c.int)
|
||||
val right = LocalRelation('a.int, 'b.int, 'd.int)
|
||||
val query = AsOfJoin(left, right, left.output(0), right.output(0), None, Inner,
|
||||
tolerance = None, allowExactMatches = true, direction = AsOfJoinDirection("backward"))
|
||||
|
||||
val rewritten = RewriteAsOfJoin(query.analyze)
|
||||
|
||||
val filter = OuterReference(left.output(0)) >= right.output(0)
|
||||
val rightStruct = CreateStruct(right.output)
|
||||
val orderExpression = OuterReference(left.output(0)) - right.output(0)
|
||||
val nearestRight = MinBy(rightStruct, orderExpression)
|
||||
.toAggregateExpression().as("__nearest_right__")
|
||||
|
||||
val scalarSubquery = left.select(
|
||||
left.output :+ ScalarSubquery(
|
||||
right.where(filter).groupBy()(nearestRight),
|
||||
left.output).as("__right__"): _*)
|
||||
val correctAnswer = scalarSubquery
|
||||
.where(scalarSubquery.output.last.isNotNull)
|
||||
.select(left.output :+
|
||||
GetStructField(scalarSubquery.output.last, 0).as("a") :+
|
||||
GetStructField(scalarSubquery.output.last, 1).as("b") :+
|
||||
GetStructField(scalarSubquery.output.last, 2).as("d"): _*)
|
||||
|
||||
comparePlans(rewritten, correctAnswer, checkAnalysis = false)
|
||||
}
|
||||
|
||||
test("tolerance") {
|
||||
val left = LocalRelation('a.int, 'b.int, 'c.int)
|
||||
val right = LocalRelation('a.int, 'b.int, 'd.int)
|
||||
val query = AsOfJoin(left, right, left.output(0), right.output(0), None, Inner,
|
||||
tolerance = Some(1), allowExactMatches = true, direction = AsOfJoinDirection("backward"))
|
||||
|
||||
val rewritten = RewriteAsOfJoin(query.analyze)
|
||||
|
||||
val filter = OuterReference(left.output(0)) >= right.output(0) &&
|
||||
right.output(0) >= OuterReference(left.output(0)) - 1
|
||||
val rightStruct = CreateStruct(right.output)
|
||||
val orderExpression = OuterReference(left.output(0)) - right.output(0)
|
||||
val nearestRight = MinBy(rightStruct, orderExpression)
|
||||
.toAggregateExpression().as("__nearest_right__")
|
||||
|
||||
val scalarSubquery = left.select(
|
||||
left.output :+ ScalarSubquery(
|
||||
right.where(filter).groupBy()(nearestRight),
|
||||
left.output).as("__right__"): _*)
|
||||
val correctAnswer = scalarSubquery
|
||||
.where(scalarSubquery.output.last.isNotNull)
|
||||
.select(left.output :+
|
||||
GetStructField(scalarSubquery.output.last, 0).as("a") :+
|
||||
GetStructField(scalarSubquery.output.last, 1).as("b") :+
|
||||
GetStructField(scalarSubquery.output.last, 2).as("d"): _*)
|
||||
|
||||
comparePlans(rewritten, correctAnswer, checkAnalysis = false)
|
||||
}
|
||||
|
||||
test("allowExactMatches = false") {
|
||||
val left = LocalRelation('a.int, 'b.int, 'c.int)
|
||||
val right = LocalRelation('a.int, 'b.int, 'd.int)
|
||||
val query = AsOfJoin(left, right, left.output(0), right.output(0), None, LeftOuter,
|
||||
tolerance = None, allowExactMatches = false, direction = AsOfJoinDirection("backward"))
|
||||
|
||||
val rewritten = RewriteAsOfJoin(query.analyze)
|
||||
|
||||
val filter = OuterReference(left.output(0)) > right.output(0)
|
||||
val rightStruct = CreateStruct(right.output)
|
||||
val orderExpression = OuterReference(left.output(0)) - right.output(0)
|
||||
val nearestRight = MinBy(rightStruct, orderExpression)
|
||||
.toAggregateExpression().as("__nearest_right__")
|
||||
|
||||
val scalarSubquery = left.select(
|
||||
left.output :+ ScalarSubquery(
|
||||
right.where(filter).groupBy()(nearestRight),
|
||||
left.output).as("__right__"): _*)
|
||||
val correctAnswer = scalarSubquery
|
||||
.select(left.output :+
|
||||
GetStructField(scalarSubquery.output.last, 0).as("a") :+
|
||||
GetStructField(scalarSubquery.output.last, 1).as("b") :+
|
||||
GetStructField(scalarSubquery.output.last, 2).as("d"): _*)
|
||||
|
||||
comparePlans(rewritten, correctAnswer, checkAnalysis = false)
|
||||
}
|
||||
|
||||
test("tolerance & allowExactMatches = false") {
|
||||
val left = LocalRelation('a.int, 'b.int, 'c.int)
|
||||
val right = LocalRelation('a.int, 'b.int, 'd.int)
|
||||
val query = AsOfJoin(left, right, left.output(0), right.output(0), None, Inner,
|
||||
tolerance = Some(1), allowExactMatches = false, direction = AsOfJoinDirection("backward"))
|
||||
|
||||
val rewritten = RewriteAsOfJoin(query.analyze)
|
||||
|
||||
val filter = OuterReference(left.output(0)) > right.output(0) &&
|
||||
right.output(0) > OuterReference(left.output(0)) - 1
|
||||
val rightStruct = CreateStruct(right.output)
|
||||
val orderExpression = OuterReference(left.output(0)) - right.output(0)
|
||||
val nearestRight = MinBy(rightStruct, orderExpression)
|
||||
.toAggregateExpression().as("__nearest_right__")
|
||||
|
||||
val scalarSubquery = left.select(
|
||||
left.output :+ ScalarSubquery(
|
||||
right.where(filter).groupBy()(nearestRight),
|
||||
left.output).as("__right__"): _*)
|
||||
val correctAnswer = scalarSubquery
|
||||
.where(scalarSubquery.output.last.isNotNull)
|
||||
.select(left.output :+
|
||||
GetStructField(scalarSubquery.output.last, 0).as("a") :+
|
||||
GetStructField(scalarSubquery.output.last, 1).as("b") :+
|
||||
GetStructField(scalarSubquery.output.last, 2).as("d"): _*)
|
||||
|
||||
comparePlans(rewritten, correctAnswer, checkAnalysis = false)
|
||||
}
|
||||
|
||||
test("direction = forward") {
|
||||
val left = LocalRelation('a.int, 'b.int, 'c.int)
|
||||
val right = LocalRelation('a.int, 'b.int, 'd.int)
|
||||
val query = AsOfJoin(left, right, left.output(0), right.output(0), None, Inner,
|
||||
tolerance = None, allowExactMatches = true, direction = AsOfJoinDirection("forward"))
|
||||
|
||||
val rewritten = RewriteAsOfJoin(query.analyze)
|
||||
|
||||
val filter = OuterReference(left.output(0)) <= right.output(0)
|
||||
val rightStruct = CreateStruct(right.output)
|
||||
val orderExpression = right.output(0) - OuterReference(left.output(0))
    val nearestRight = MinBy(rightStruct, orderExpression)
      .toAggregateExpression().as("__nearest_right__")

    val scalarSubquery = left.select(
      left.output :+ ScalarSubquery(
        right.where(filter).groupBy()(nearestRight),
        left.output).as("__right__"): _*)
    val correctAnswer = scalarSubquery
      .where(scalarSubquery.output.last.isNotNull)
      .select(left.output :+
        GetStructField(scalarSubquery.output.last, 0).as("a") :+
        GetStructField(scalarSubquery.output.last, 1).as("b") :+
        GetStructField(scalarSubquery.output.last, 2).as("d"): _*)

    comparePlans(rewritten, correctAnswer, checkAnalysis = false)
  }

test("direction = nearest") {
|
||||
val left = LocalRelation('a.int, 'b.int, 'c.int)
|
||||
val right = LocalRelation('a.int, 'b.int, 'd.int)
|
||||
val query = AsOfJoin(left, right, left.output(0), right.output(0), None, Inner,
|
||||
tolerance = None, allowExactMatches = true, direction = AsOfJoinDirection("nearest"))
|
||||
|
||||
val rewritten = RewriteAsOfJoin(query.analyze)
|
||||
|
||||
val filter = true
|
||||
val rightStruct = CreateStruct(right.output)
|
||||
val orderExpression = If(OuterReference(left.output(0)) > right.output(0),
|
||||
OuterReference(left.output(0)) - right.output(0),
|
||||
right.output(0) - OuterReference(left.output(0)))
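    // The nearest direction keeps every right row (the filter is trivially true) and ranks
    // candidates by absolute distance, expressed with an If.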
    val nearestRight = MinBy(rightStruct, orderExpression)
      .toAggregateExpression().as("__nearest_right__")

    val scalarSubquery = left.select(
      left.output :+ ScalarSubquery(
        right.where(filter).groupBy()(nearestRight),
        left.output).as("__right__"): _*)
    val correctAnswer = scalarSubquery
      .where(scalarSubquery.output.last.isNotNull)
      .select(left.output :+
        GetStructField(scalarSubquery.output.last, 0).as("a") :+
        GetStructField(scalarSubquery.output.last, 1).as("b") :+
        GetStructField(scalarSubquery.output.last, 2).as("d"): _*)

    comparePlans(rewritten, correctAnswer, checkAnalysis = false)
  }

test("tolerance & allowExactMatches = false & direction = nearest") {
|
||||
val left = LocalRelation('a.int, 'b.int, 'c.int)
|
||||
val right = LocalRelation('a.int, 'b.int, 'd.int)
|
||||
val query = AsOfJoin(left, right, left.output(0), right.output(0), None, Inner,
|
||||
tolerance = Some(1), allowExactMatches = false, direction = AsOfJoinDirection("nearest"))
|
||||
|
||||
val rewritten = RewriteAsOfJoin(query.analyze)
|
||||
|
||||
val filter = (!(OuterReference(left.output(0)) === right.output(0))) &&
|
||||
((right.output(0) > OuterReference(left.output(0)) - 1) &&
|
||||
(right.output(0) < OuterReference(left.output(0)) + 1))
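    // All options combined: exact matches are excluded, and the tolerance bounds the distance
    // strictly on both sides of left.a.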
    val rightStruct = CreateStruct(right.output)
    val orderExpression = If(OuterReference(left.output(0)) > right.output(0),
      OuterReference(left.output(0)) - right.output(0),
      right.output(0) - OuterReference(left.output(0)))
    val nearestRight = MinBy(rightStruct, orderExpression)
      .toAggregateExpression().as("__nearest_right__")

    val scalarSubquery = left.select(
      left.output :+ ScalarSubquery(
        right.where(filter).groupBy()(nearestRight),
        left.output).as("__right__"): _*)
    val correctAnswer = scalarSubquery
      .where(scalarSubquery.output.last.isNotNull)
      .select(left.output :+
        GetStructField(scalarSubquery.output.last, 0).as("a") :+
        GetStructField(scalarSubquery.output.last, 1).as("b") :+
        GetStructField(scalarSubquery.output.last, 2).as("d"): _*)

    comparePlans(rewritten, correctAnswer, checkAnalysis = false)
  }
}
@@ -1046,6 +1046,47 @@ class Dataset[T] private[sql](
    plan.copy(condition = cond)
  }

  /**
   * Finds the trivially true predicates and automatically resolves them to both sides.
   */
  private def resolveSelfJoinCondition(
      right: Dataset[_],
      joinExprs: Option[Column],
      joinType: String): Join = {
    // Note that in this function, we introduce a hack in the case of self-join to automatically
    // resolve ambiguous join conditions into ones that might make sense [SPARK-6231].
    // Consider this case: df.join(df, df("key") === df("key"))
    // Since df("key") === df("key") is a trivially true condition, this actually becomes a
    // cartesian join. However, most likely users expect to perform a self join using "key".
    // With that assumption, this hack turns the trivially true condition into equality on join
    // keys that are resolved to both sides.

    // Trigger analysis so in the case of self-join, the analyzer will clone the plan.
    // After the cloning, left and right side will have distinct expression ids.
    val plan = withPlan(
      Join(logicalPlan, right.logicalPlan,
        JoinType(joinType), joinExprs.map(_.expr), JoinHint.NONE))
      .queryExecution.analyzed.asInstanceOf[Join]

    // If auto self join alias is disabled, return the plan.
    if (!sparkSession.sessionState.conf.dataFrameSelfJoinAutoResolveAmbiguity) {
      return plan
    }

    // If left/right have no output set intersection, return the plan.
    val lanalyzed = this.queryExecution.analyzed
    val ranalyzed = right.queryExecution.analyzed
    if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) {
      return plan
    }

    // Otherwise, find the trivially true predicates and automatically resolve them to both sides.
    // By the time we get here, since we have already run analysis, all attributes should've been
    // resolved and become AttributeReference.

    resolveSelfJoinCondition(plan)
  }
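
  // For example, in df.join(df, df("key") === df("key")) both column references resolve to the
  // same attribute, so the condition is trivially true; the rewrite re-targets it as an
  // equality between the analyzer-cloned left and right plans.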

  /**
   * Join with another `DataFrame`, using the given join expression. The following performs
   * a full outer join between `df1` and `df2`.

@@ -1071,38 +1112,8 @@ class Dataset[T] private[sql](
   * @since 2.0.0
   */
  def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame = {
    withPlan {
      resolveSelfJoinCondition(right, Some(joinExprs), joinType)
    }
  }

@@ -1232,6 +1243,58 @@ class Dataset[T] private[sql](
    joinWith(other, condition, "inner")
  }

  // TODO(SPARK-22947): Fix the DataFrame API.
  private[sql] def joinAsOf(
      other: Dataset[_],
      leftAsOf: Column,
      rightAsOf: Column,
      usingColumns: Seq[String],
      joinType: String,
      tolerance: Column,
      allowExactMatches: Boolean,
      direction: String): DataFrame = {
    val joinExprs = usingColumns.map { column =>
      EqualTo(resolve(column), other.resolve(column))
    }.reduceOption(And).map(Column.apply).orNull

    joinAsOf(other, leftAsOf, rightAsOf, joinExprs, joinType,
      tolerance, allowExactMatches, direction)
  }
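
  // For example, usingColumns = Seq("ticker") becomes the extra equi-condition
  // left.ticker = right.ticker ANDed into the join; an empty list yields a null condition
  // (no extra predicate).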

  // TODO(SPARK-22947): Fix the DataFrame API.
  private[sql] def joinAsOf(
      other: Dataset[_],
      leftAsOf: Column,
      rightAsOf: Column,
      joinExprs: Column,
      joinType: String,
      tolerance: Column,
      allowExactMatches: Boolean,
      direction: String): DataFrame = {
    val joined = resolveSelfJoinCondition(other, Option(joinExprs), joinType)
    val leftAsOfExpr = leftAsOf.expr.transformUp {
      case a: AttributeReference if logicalPlan.outputSet.contains(a) =>
        val index = logicalPlan.output.indexWhere(_.exprId == a.exprId)
        joined.left.output(index)
    }
    val rightAsOfExpr = rightAsOf.expr.transformUp {
      case a: AttributeReference if other.logicalPlan.outputSet.contains(a) =>
        val index = other.logicalPlan.output.indexWhere(_.exprId == a.exprId)
        joined.right.output(index)
    }
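    // The as-of columns were resolved against the original plans; for a self join the analyzer
    // clones those plans, so each reference is remapped to the corresponding output of the
    // joined children by ordinal position.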
    withPlan {
      AsOfJoin(
        joined.left, joined.right,
        leftAsOfExpr, rightAsOfExpr,
        joined.condition,
        joined.joinType,
        Option(tolerance).map(_.expr),
        allowExactMatches,
        AsOfJoinDirection(direction)
      )
    }
  }

  /**
   * Returns a new Dataset with each partition sorted by the given expressions.
   *
@@ -0,0 +1,169 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import scala.collection.JavaConverters._

import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._

class DataFrameAsOfJoinSuite extends QueryTest
  with SharedSparkSession
  with AdaptiveSparkPlanHelper {

  def prepareForAsOfJoin(): (DataFrame, DataFrame) = {
    val schema1 = StructType(
      StructField("a", IntegerType, false) ::
        StructField("b", StringType, false) ::
        StructField("left_val", StringType, false) :: Nil)
    val rowSeq1: List[Row] = List(Row(1, "x", "a"), Row(5, "y", "b"), Row(10, "z", "c"))
    val df1 = spark.createDataFrame(rowSeq1.asJava, schema1)

    val schema2 = StructType(
      StructField("a", IntegerType) ::
        StructField("b", StringType) ::
        StructField("right_val", IntegerType) :: Nil)
    val rowSeq2: List[Row] = List(Row(1, "v", 1), Row(2, "w", 2), Row(3, "x", 3),
      Row(6, "y", 6), Row(7, "z", 7))
    val df2 = spark.createDataFrame(rowSeq2.asJava, schema2)

    (df1, df2)
  }
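
  // Fixture summary: df1 carries as-of keys a = 1, 5, 10; df2 carries a = 1, 2, 3, 6, 7,
  // with column "b" serving as the optional equi-match key in the usingColumns tests.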

  test("as-of join - simple") {
    val (df1, df2) = prepareForAsOfJoin()
    checkAnswer(
      df1.joinAsOf(
        df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
        joinType = "inner", tolerance = null, allowExactMatches = true, direction = "backward"),
      Seq(
        Row(1, "x", "a", 1, "v", 1),
        Row(5, "y", "b", 3, "x", 3),
        Row(10, "z", "c", 7, "z", 7)
      )
    )
  }

  test("as-of join - usingColumns") {
    val (df1, df2) = prepareForAsOfJoin()
    checkAnswer(
      df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq("b"),
        joinType = "inner", tolerance = null, allowExactMatches = true, direction = "backward"),
      Seq(
        Row(10, "z", "c", 7, "z", 7)
      )
    )
  }

  test("as-of join - usingColumns, left outer") {
    val (df1, df2) = prepareForAsOfJoin()
    checkAnswer(
      df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq("b"),
        joinType = "left", tolerance = null, allowExactMatches = true, direction = "backward"),
      Seq(
        Row(1, "x", "a", null, null, null),
        Row(5, "y", "b", null, null, null),
        Row(10, "z", "c", 7, "z", 7)
      )
    )
  }

  test("as-of join - tolerance = 1") {
    val (df1, df2) = prepareForAsOfJoin()
    checkAnswer(
      df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
        joinType = "inner", tolerance = lit(1), allowExactMatches = true, direction = "backward"),
      Seq(
        Row(1, "x", "a", 1, "v", 1)
      )
    )
  }

  test("as-of join - tolerance should be a constant") {
    val (df1, df2) = prepareForAsOfJoin()
    val errMsg = intercept[AnalysisException] {
      df1.joinAsOf(
        df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
        joinType = "inner", tolerance = df1.col("b"), allowExactMatches = true,
        direction = "backward")
    }.getMessage
    assert(errMsg.contains("Input argument tolerance must be a constant."))
  }

  test("as-of join - tolerance should be non-negative") {
    val (df1, df2) = prepareForAsOfJoin()
    val errMsg = intercept[AnalysisException] {
      df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
        joinType = "inner", tolerance = lit(-1), allowExactMatches = true, direction = "backward")
    }.getMessage
    assert(errMsg.contains("Input argument tolerance must be non-negative."))
  }

  test("as-of join - allowExactMatches = false") {
    val (df1, df2) = prepareForAsOfJoin()
    checkAnswer(
      df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
        joinType = "inner", tolerance = null, allowExactMatches = false, direction = "backward"),
      Seq(
        Row(5, "y", "b", 3, "x", 3),
        Row(10, "z", "c", 7, "z", 7)
      )
    )
  }

  test("as-of join - direction = \"forward\"") {
    val (df1, df2) = prepareForAsOfJoin()
    checkAnswer(
      df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
        joinType = "inner", tolerance = null, allowExactMatches = true, direction = "forward"),
      Seq(
        Row(1, "x", "a", 1, "v", 1),
        Row(5, "y", "b", 6, "y", 6)
      )
    )
  }

  test("as-of join - direction = \"nearest\"") {
    val (df1, df2) = prepareForAsOfJoin()
    checkAnswer(
      df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
        joinType = "inner", tolerance = null, allowExactMatches = true, direction = "nearest"),
      Seq(
        Row(1, "x", "a", 1, "v", 1),
        Row(5, "y", "b", 6, "y", 6),
        Row(10, "z", "c", 7, "z", 7)
      )
    )
  }

  test("as-of join - self") {
    val (df1, _) = prepareForAsOfJoin()
    checkAnswer(
      df1.joinAsOf(
        df1, df1.col("a"), df1.col("a"), usingColumns = Seq.empty,
        joinType = "left", tolerance = null, allowExactMatches = false, direction = "nearest"),
      Seq(
        Row(1, "x", "a", 5, "y", "b"),
        Row(5, "y", "b", 1, "x", "a"),
        Row(10, "z", "c", 5, "y", "b")
      )
    )
  }
}