e768eaa908
### What changes were proposed in this pull request?

This PR adds code-gen support for left outer (build right) and right outer (build left) broadcast nested loop joins. Reference: `BroadcastNestedLoopJoinExec.codegenInner()` and `BroadcastNestedLoopJoinExec.outerJoin()`

### Why are the changes needed?

Improve query CPU performance. Tested with a simple query:

```scala
val N = 20 << 20
val M = 1 << 4
val dim = broadcast(spark.range(M).selectExpr("id as k2"))
codegenBenchmark("left outer broadcast nested loop join", N) {
  val df = spark.range(N).selectExpr(s"id as k1").join(
    dim, col("k1") + 1 <= col("k2"), "left_outer")
  assert(df.queryExecution.sparkPlan.find(
    _.isInstanceOf[BroadcastNestedLoopJoinExec]).isDefined)
  df.noop()
}
```

The benchmark shows a 2x run-time improvement:

```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.15.7
Intel(R) Core(TM) i9-9980HK CPU 2.40GHz
left outer broadcast nested loop join:                Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------------------
left outer broadcast nested loop join wholestage off           3024           3698         953          6.9         144.2       1.0X
left outer broadcast nested loop join wholestage on            1512           1659         172         13.9          72.1       2.0X
```

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Changed existing unit tests in `OuterJoinSuite` to cover codegen use cases. Added a unit test in `WholeStageCodegenSuite.scala` to make sure code-gen for broadcast nested loop join takes effect, and to test the multiple-join case as well.
Example query: ```scala val df1 = spark.range(4).select($"id".as("k1")) val df2 = spark.range(3).select($"id".as("k2")) df1.join(df2, $"k1" + 1 <= $"k2", "left_outer").explain("codegen") ``` Example generated code (`bnlj_doConsume_0` method): ```java == Subtree 2 / 2 (maxMethodCodeSize:282; maxConstantPoolSize:210(0.32% used); numInnerClasses:0) == *(2) BroadcastNestedLoopJoin BuildRight, LeftOuter, ((k1#2L + 1) <= k2#6L) :- *(2) Project [id#0L AS k1#2L] : +- *(2) Range (0, 4, step=1, splits=16) +- BroadcastExchange IdentityBroadcastMode, [id=#22] +- *(1) Project [id#4L AS k2#6L] +- *(1) Range (0, 3, step=1, splits=16) Generated code: /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIteratorForCodegenStage2(references); /* 003 */ } /* 004 */ /* 005 */ // codegenStageId=2 /* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator { /* 007 */ private Object[] references; /* 008 */ private scala.collection.Iterator[] inputs; /* 009 */ private boolean range_initRange_0; /* 010 */ private long range_nextIndex_0; /* 011 */ private TaskContext range_taskContext_0; /* 012 */ private InputMetrics range_inputMetrics_0; /* 013 */ private long range_batchEnd_0; /* 014 */ private long range_numElementsTodo_0; /* 015 */ private InternalRow[] bnlj_buildRowArray_0; /* 016 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[4]; /* 017 */ /* 018 */ public GeneratedIteratorForCodegenStage2(Object[] references) { /* 019 */ this.references = references; /* 020 */ } /* 021 */ /* 022 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 023 */ partitionIndex = index; /* 024 */ this.inputs = inputs; /* 025 */ /* 026 */ range_taskContext_0 = TaskContext.get(); /* 027 */ range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics(); /* 028 */ 
range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0); /* 029 */ range_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0); /* 030 */ range_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0); /* 031 */ bnlj_buildRowArray_0 = (InternalRow[]) ((org.apache.spark.broadcast.TorrentBroadcast) references[1] /* broadcastTerm */).value(); /* 032 */ range_mutableStateArray_0[3] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0); /* 033 */ /* 034 */ } /* 035 */ /* 036 */ private void bnlj_doConsume_0(long bnlj_expr_0_0) throws java.io.IOException { /* 037 */ boolean bnlj_foundMatch_0 = false; /* 038 */ for (int bnlj_arrayIndex_0 = 0; bnlj_arrayIndex_0 < bnlj_buildRowArray_0.length; bnlj_arrayIndex_0++) { /* 039 */ UnsafeRow bnlj_buildRow_0 = (UnsafeRow) bnlj_buildRowArray_0[bnlj_arrayIndex_0]; /* 040 */ boolean bnlj_shouldOutputRow_0 = false; /* 041 */ /* 042 */ boolean bnlj_isNull_2 = true; /* 043 */ long bnlj_value_2 = -1L; /* 044 */ if (bnlj_buildRow_0 != null) { /* 045 */ long bnlj_value_1 = bnlj_buildRow_0.getLong(0); /* 046 */ bnlj_isNull_2 = false; /* 047 */ bnlj_value_2 = bnlj_value_1; /* 048 */ } /* 049 */ /* 050 */ long bnlj_value_4 = -1L; /* 051 */ /* 052 */ bnlj_value_4 = bnlj_expr_0_0 + 1L; /* 053 */ /* 054 */ boolean bnlj_value_3 = false; /* 055 */ bnlj_value_3 = bnlj_value_4 <= bnlj_value_2; /* 056 */ if (!(false || !bnlj_value_3)) /* 057 */ { /* 058 */ bnlj_shouldOutputRow_0 = true; /* 059 */ bnlj_foundMatch_0 = true; /* 060 */ } /* 061 */ if (bnlj_arrayIndex_0 == bnlj_buildRowArray_0.length - 1 && !bnlj_foundMatch_0) { /* 062 */ bnlj_buildRow_0 = null; /* 063 */ bnlj_shouldOutputRow_0 = true; /* 064 */ } /* 065 */ if (bnlj_shouldOutputRow_0) { /* 066 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[2] /* numOutputRows */).add(1); /* 067 */ /* 068 */ boolean 
bnlj_isNull_9 = true; /* 069 */ long bnlj_value_9 = -1L; /* 070 */ if (bnlj_buildRow_0 != null) { /* 071 */ long bnlj_value_8 = bnlj_buildRow_0.getLong(0); /* 072 */ bnlj_isNull_9 = false; /* 073 */ bnlj_value_9 = bnlj_value_8; /* 074 */ } /* 075 */ range_mutableStateArray_0[3].reset(); /* 076 */ /* 077 */ range_mutableStateArray_0[3].zeroOutNullBytes(); /* 078 */ /* 079 */ range_mutableStateArray_0[3].write(0, bnlj_expr_0_0); /* 080 */ /* 081 */ if (bnlj_isNull_9) { /* 082 */ range_mutableStateArray_0[3].setNullAt(1); /* 083 */ } else { /* 084 */ range_mutableStateArray_0[3].write(1, bnlj_value_9); /* 085 */ } /* 086 */ append((range_mutableStateArray_0[3].getRow()).copy()); /* 087 */ /* 088 */ } /* 089 */ } /* 090 */ /* 091 */ } /* 092 */ /* 093 */ private void initRange(int idx) { /* 094 */ java.math.BigInteger index = java.math.BigInteger.valueOf(idx); /* 095 */ java.math.BigInteger numSlice = java.math.BigInteger.valueOf(16L); /* 096 */ java.math.BigInteger numElement = java.math.BigInteger.valueOf(4L); /* 097 */ java.math.BigInteger step = java.math.BigInteger.valueOf(1L); /* 098 */ java.math.BigInteger start = java.math.BigInteger.valueOf(0L); /* 099 */ long partitionEnd; /* 100 */ /* 101 */ java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); /* 102 */ if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) { /* 103 */ range_nextIndex_0 = Long.MAX_VALUE; /* 104 */ } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) { /* 105 */ range_nextIndex_0 = Long.MIN_VALUE; /* 106 */ } else { /* 107 */ range_nextIndex_0 = st.longValue(); /* 108 */ } /* 109 */ range_batchEnd_0 = range_nextIndex_0; /* 110 */ /* 111 */ java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice) /* 112 */ .multiply(step).add(start); /* 113 */ if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) { /* 114 */ partitionEnd = Long.MAX_VALUE; /* 115 */ } else if 
(end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) { /* 116 */ partitionEnd = Long.MIN_VALUE; /* 117 */ } else { /* 118 */ partitionEnd = end.longValue(); /* 119 */ } /* 120 */ /* 121 */ java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract( /* 122 */ java.math.BigInteger.valueOf(range_nextIndex_0)); /* 123 */ range_numElementsTodo_0 = startToEnd.divide(step).longValue(); /* 124 */ if (range_numElementsTodo_0 < 0) { /* 125 */ range_numElementsTodo_0 = 0; /* 126 */ } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) { /* 127 */ range_numElementsTodo_0++; /* 128 */ } /* 129 */ } /* 130 */ /* 131 */ protected void processNext() throws java.io.IOException { /* 132 */ // initialize Range /* 133 */ if (!range_initRange_0) { /* 134 */ range_initRange_0 = true; /* 135 */ initRange(partitionIndex); /* 136 */ } /* 137 */ /* 138 */ while (true) { /* 139 */ if (range_nextIndex_0 == range_batchEnd_0) { /* 140 */ long range_nextBatchTodo_0; /* 141 */ if (range_numElementsTodo_0 > 1000L) { /* 142 */ range_nextBatchTodo_0 = 1000L; /* 143 */ range_numElementsTodo_0 -= 1000L; /* 144 */ } else { /* 145 */ range_nextBatchTodo_0 = range_numElementsTodo_0; /* 146 */ range_numElementsTodo_0 = 0; /* 147 */ if (range_nextBatchTodo_0 == 0) break; /* 148 */ } /* 149 */ range_batchEnd_0 += range_nextBatchTodo_0 * 1L; /* 150 */ } /* 151 */ /* 152 */ int range_localEnd_0 = (int)((range_batchEnd_0 - range_nextIndex_0) / 1L); /* 153 */ for (int range_localIdx_0 = 0; range_localIdx_0 < range_localEnd_0; range_localIdx_0++) { /* 154 */ long range_value_0 = ((long)range_localIdx_0 * 1L) + range_nextIndex_0; /* 155 */ /* 156 */ // common sub-expressions /* 157 */ /* 158 */ bnlj_doConsume_0(range_value_0); /* 159 */ /* 160 */ if (shouldStop()) { /* 161 */ range_nextIndex_0 = range_value_0 + 1L; /* 162 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localIdx_0 + 1); /* 
163 */ range_inputMetrics_0.incRecordsRead(range_localIdx_0 + 1); /* 164 */ return; /* 165 */ } /* 166 */ /* 167 */ } /* 168 */ range_nextIndex_0 = range_batchEnd_0; /* 169 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localEnd_0); /* 170 */ range_inputMetrics_0.incRecordsRead(range_localEnd_0); /* 171 */ range_taskContext_0.killTaskIfInterrupted(); /* 172 */ } /* 173 */ } /* 174 */ /* 175 */ } ``` Closes #31931 from linzebing/code-left-right-outer. Authored-by: linzebing <linzebing1995@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com> |
||
---|---|---|
.. | ||
benchmarks | ||
src | ||
pom.xml |