spark-instrumented-optimizer/sql/core
linzebing e768eaa908 [SPARK-34707][SQL] Code-gen broadcast nested loop join (left outer/right outer)
### What changes were proposed in this pull request?

This PR is to add code-gen support for left outer (build right) and right outer (build left). Reference: `BroadcastNestedLoopJoinExec.codegenInner()` and `BroadcastNestedLoopJoinExec.outerJoin()`

### Why are the changes needed?

Improve query CPU performance.
Tested with a simple query:
```scala
val N = 20 << 20
val M = 1 << 4

val dim = broadcast(spark.range(M).selectExpr("id as k2"))
codegenBenchmark("left outer broadcast nested loop join", N) {
   val df = spark.range(N).selectExpr(s"id as k1").join(
     dim, col("k1") + 1 <= col("k2"), "left_outer")
   assert(df.queryExecution.sparkPlan.find(
     _.isInstanceOf[BroadcastNestedLoopJoinExec]).isDefined)
   df.noop()
}
```
Seeing 2x run time improvement:
```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.15.7
Intel(R) Core(TM) i9-9980HK CPU  2.40GHz
left outer broadcast nested loop join:                Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------------------
left outer broadcast nested loop join wholestage off           3024           3698         953          6.9         144.2       1.0X
left outer broadcast nested loop join wholestage on            1512           1659         172         13.9          72.1       2.0X
```

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Changed existing unit tests in `OuterJoinSuite` to cover codegen use cases.
Added unit test in WholeStageCodegenSuite.scala to make sure code-gen for broadcast nested loop join is taking effect, and test for multiple join case as well.

Example query:
```scala
val df1 = spark.range(4).select($"id".as("k1"))
val df2 = spark.range(3).select($"id".as("k2"))
df1.join(df2, $"k1" + 1 <= $"k2", "left_outer").explain("codegen")
```
Example generated code (`bnlj_doConsume_0` method):
```java
== Subtree 2 / 2 (maxMethodCodeSize:282; maxConstantPoolSize:210(0.32% used); numInnerClasses:0) ==
*(2) BroadcastNestedLoopJoin BuildRight, LeftOuter, ((k1#2L + 1) <= k2#6L)
:- *(2) Project [id#0L AS k1#2L]
:  +- *(2) Range (0, 4, step=1, splits=16)
+- BroadcastExchange IdentityBroadcastMode, [id=#22]
   +- *(1) Project [id#4L AS k2#6L]
      +- *(1) Range (0, 3, step=1, splits=16)

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage2(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=2
/* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private boolean range_initRange_0;
/* 010 */   private long range_nextIndex_0;
/* 011 */   private TaskContext range_taskContext_0;
/* 012 */   private InputMetrics range_inputMetrics_0;
/* 013 */   private long range_batchEnd_0;
/* 014 */   private long range_numElementsTodo_0;
/* 015 */   private InternalRow[] bnlj_buildRowArray_0;
/* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[4];
/* 017 */
/* 018 */   public GeneratedIteratorForCodegenStage2(Object[] references) {
/* 019 */     this.references = references;
/* 020 */   }
/* 021 */
/* 022 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 023 */     partitionIndex = index;
/* 024 */     this.inputs = inputs;
/* 025 */
/* 026 */     range_taskContext_0 = TaskContext.get();
/* 027 */     range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
/* 028 */     range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 029 */     range_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 030 */     range_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 031 */     bnlj_buildRowArray_0 = (InternalRow[]) ((org.apache.spark.broadcast.TorrentBroadcast) references[1] /* broadcastTerm */).value();
/* 032 */     range_mutableStateArray_0[3] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0);
/* 033 */
/* 034 */   }
/* 035 */
/* 036 */   private void bnlj_doConsume_0(long bnlj_expr_0_0) throws java.io.IOException {
/* 037 */     boolean bnlj_foundMatch_0 = false;
/* 038 */     for (int bnlj_arrayIndex_0 = 0; bnlj_arrayIndex_0 < bnlj_buildRowArray_0.length; bnlj_arrayIndex_0++) {
/* 039 */       UnsafeRow bnlj_buildRow_0 = (UnsafeRow) bnlj_buildRowArray_0[bnlj_arrayIndex_0];
/* 040 */       boolean bnlj_shouldOutputRow_0 = false;
/* 041 */
/* 042 */       boolean bnlj_isNull_2 = true;
/* 043 */       long bnlj_value_2 = -1L;
/* 044 */       if (bnlj_buildRow_0 != null) {
/* 045 */         long bnlj_value_1 = bnlj_buildRow_0.getLong(0);
/* 046 */         bnlj_isNull_2 = false;
/* 047 */         bnlj_value_2 = bnlj_value_1;
/* 048 */       }
/* 049 */
/* 050 */       long bnlj_value_4 = -1L;
/* 051 */
/* 052 */       bnlj_value_4 = bnlj_expr_0_0 + 1L;
/* 053 */
/* 054 */       boolean bnlj_value_3 = false;
/* 055 */       bnlj_value_3 = bnlj_value_4 <= bnlj_value_2;
/* 056 */       if (!(false || !bnlj_value_3))
/* 057 */       {
/* 058 */         bnlj_shouldOutputRow_0 = true;
/* 059 */         bnlj_foundMatch_0 = true;
/* 060 */       }
/* 061 */       if (bnlj_arrayIndex_0 == bnlj_buildRowArray_0.length - 1 && !bnlj_foundMatch_0) {
/* 062 */         bnlj_buildRow_0 = null;
/* 063 */         bnlj_shouldOutputRow_0 = true;
/* 064 */       }
/* 065 */       if (bnlj_shouldOutputRow_0) {
/* 066 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[2] /* numOutputRows */).add(1);
/* 067 */
/* 068 */         boolean bnlj_isNull_9 = true;
/* 069 */         long bnlj_value_9 = -1L;
/* 070 */         if (bnlj_buildRow_0 != null) {
/* 071 */           long bnlj_value_8 = bnlj_buildRow_0.getLong(0);
/* 072 */           bnlj_isNull_9 = false;
/* 073 */           bnlj_value_9 = bnlj_value_8;
/* 074 */         }
/* 075 */         range_mutableStateArray_0[3].reset();
/* 076 */
/* 077 */         range_mutableStateArray_0[3].zeroOutNullBytes();
/* 078 */
/* 079 */         range_mutableStateArray_0[3].write(0, bnlj_expr_0_0);
/* 080 */
/* 081 */         if (bnlj_isNull_9) {
/* 082 */           range_mutableStateArray_0[3].setNullAt(1);
/* 083 */         } else {
/* 084 */           range_mutableStateArray_0[3].write(1, bnlj_value_9);
/* 085 */         }
/* 086 */         append((range_mutableStateArray_0[3].getRow()).copy());
/* 087 */
/* 088 */       }
/* 089 */     }
/* 090 */
/* 091 */   }
/* 092 */
/* 093 */   private void initRange(int idx) {
/* 094 */     java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 095 */     java.math.BigInteger numSlice = java.math.BigInteger.valueOf(16L);
/* 096 */     java.math.BigInteger numElement = java.math.BigInteger.valueOf(4L);
/* 097 */     java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 098 */     java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
/* 099 */     long partitionEnd;
/* 100 */
/* 101 */     java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 102 */     if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 103 */       range_nextIndex_0 = Long.MAX_VALUE;
/* 104 */     } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 105 */       range_nextIndex_0 = Long.MIN_VALUE;
/* 106 */     } else {
/* 107 */       range_nextIndex_0 = st.longValue();
/* 108 */     }
/* 109 */     range_batchEnd_0 = range_nextIndex_0;
/* 110 */
/* 111 */     java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 112 */     .multiply(step).add(start);
/* 113 */     if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 114 */       partitionEnd = Long.MAX_VALUE;
/* 115 */     } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 116 */       partitionEnd = Long.MIN_VALUE;
/* 117 */     } else {
/* 118 */       partitionEnd = end.longValue();
/* 119 */     }
/* 120 */
/* 121 */     java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
/* 122 */       java.math.BigInteger.valueOf(range_nextIndex_0));
/* 123 */     range_numElementsTodo_0  = startToEnd.divide(step).longValue();
/* 124 */     if (range_numElementsTodo_0 < 0) {
/* 125 */       range_numElementsTodo_0 = 0;
/* 126 */     } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
/* 127 */       range_numElementsTodo_0++;
/* 128 */     }
/* 129 */   }
/* 130 */
/* 131 */   protected void processNext() throws java.io.IOException {
/* 132 */     // initialize Range
/* 133 */     if (!range_initRange_0) {
/* 134 */       range_initRange_0 = true;
/* 135 */       initRange(partitionIndex);
/* 136 */     }
/* 137 */
/* 138 */     while (true) {
/* 139 */       if (range_nextIndex_0 == range_batchEnd_0) {
/* 140 */         long range_nextBatchTodo_0;
/* 141 */         if (range_numElementsTodo_0 > 1000L) {
/* 142 */           range_nextBatchTodo_0 = 1000L;
/* 143 */           range_numElementsTodo_0 -= 1000L;
/* 144 */         } else {
/* 145 */           range_nextBatchTodo_0 = range_numElementsTodo_0;
/* 146 */           range_numElementsTodo_0 = 0;
/* 147 */           if (range_nextBatchTodo_0 == 0) break;
/* 148 */         }
/* 149 */         range_batchEnd_0 += range_nextBatchTodo_0 * 1L;
/* 150 */       }
/* 151 */
/* 152 */       int range_localEnd_0 = (int)((range_batchEnd_0 - range_nextIndex_0) / 1L);
/* 153 */       for (int range_localIdx_0 = 0; range_localIdx_0 < range_localEnd_0; range_localIdx_0++) {
/* 154 */         long range_value_0 = ((long)range_localIdx_0 * 1L) + range_nextIndex_0;
/* 155 */
/* 156 */         // common sub-expressions
/* 157 */
/* 158 */         bnlj_doConsume_0(range_value_0);
/* 159 */
/* 160 */         if (shouldStop()) {
/* 161 */           range_nextIndex_0 = range_value_0 + 1L;
/* 162 */           ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localIdx_0 + 1);
/* 163 */           range_inputMetrics_0.incRecordsRead(range_localIdx_0 + 1);
/* 164 */           return;
/* 165 */         }
/* 166 */
/* 167 */       }
/* 168 */       range_nextIndex_0 = range_batchEnd_0;
/* 169 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localEnd_0);
/* 170 */       range_inputMetrics_0.incRecordsRead(range_localEnd_0);
/* 171 */       range_taskContext_0.killTaskIfInterrupted();
/* 172 */     }
/* 173 */   }
/* 174 */
/* 175 */ }
```

Closes #31931 from linzebing/code-left-right-outer.

Authored-by: linzebing <linzebing1995@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
2021-03-23 07:11:57 +00:00
..
benchmarks [SPARK-34815][SQL] Update CSVBenchmark 2021-03-22 10:49:53 +03:00
src [SPARK-34707][SQL] Code-gen broadcast nested loop join (left outer/right outer) 2021-03-23 07:11:57 +00:00
pom.xml [SPARK-33662][BUILD] Setting version to 3.2.0-SNAPSHOT 2020-12-04 14:10:42 -08:00