bd28e8e179
### What changes were proposed in this pull request?

Currently, `FilterExec` determines output nullability and generates null checks in two different ways, so a nullable attribute can mistakenly be treated as not nullable.

In `FilterExec.output`, an attribute is marked non-nullable by looking up its `exprId` in `notNullAttributes`:
```
a.nullable && notNullAttributes.contains(a.exprId)
```
But in `FilterExec.doConsume`, whether a null check is generated for a predicate is decided by whether there is a semantically equal `IsNotNull` predicate:
```
val nullChecks = c.references.map { r =>
  val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)}
  if (idx != -1 && !generatedIsNotNullChecks(idx)) {
    generatedIsNotNullChecks(idx) = true
    // Use the child's output. The nullability is what the child produced.
    genPredicate(notNullPreds(idx), input, child.output)
  } else {
    ""
  }
}.mkString("\n").trim
```
An NPE occurs when running the SQL below:
```
sql("create table table1(x string)")
sql("create table table2(x bigint)")
sql("create table table3(x string)")
sql("insert into table2 select null as x")
sql(
  """
    |select t1.x
    |from (
    |  select x from table1) t1
    |left join (
    |  select x from (
    |    select x from table2
    |    union all
    |    select substr(x,5) x from table3
    |  ) a
    |  where length(x)>0
    |) t3
    |on t1.x=t3.x
  """.stripMargin).collect()
```
The NPE stack trace:
```
java.lang.NullPointerException
  at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(generated.java:40)
  at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
  at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:726)
  at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
  at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:135)
  at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
  at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:94)
  at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
  at org.apache.spark.scheduler.Task.run(Task.scala:127)
  at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:449)
  at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377)
  at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:452)
  at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
  at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
  at java.lang.Thread.run(Thread.java:748)
```
The generated code:
```
== Subtree 4 / 5 ==
*(2) Project [cast(x#7L as string) AS x#9]
+- *(2) Filter ((length(cast(x#7L as string)) > 0) AND isnotnull(cast(x#7L as string)))
   +- Scan hive default.table2 [x#7L], HiveTableRelation `default`.`table2`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [x#7L]

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage2(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=2
/* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private scala.collection.Iterator inputadapter_input_0;
/* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] filter_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2];
/* 011 */
/* 012 */   public GeneratedIteratorForCodegenStage2(Object[] references) {
/* 013 */     this.references = references;
/* 014 */   }
/* 015 */
/* 016 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 017 */     partitionIndex = index;
/* 018 */     this.inputs = inputs;
/* 019 */     inputadapter_input_0 = inputs[0];
/* 020 */     filter_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 021 */     filter_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32);
/* 022 */
/* 023 */   }
/* 024 */
/* 025 */   protected void processNext() throws java.io.IOException {
/* 026 */     while ( inputadapter_input_0.hasNext()) {
/* 027 */       InternalRow inputadapter_row_0 = (InternalRow) inputadapter_input_0.next();
/* 028 */
/* 029 */       do {
/* 030 */         boolean inputadapter_isNull_0 = inputadapter_row_0.isNullAt(0);
/* 031 */         long inputadapter_value_0 = inputadapter_isNull_0 ?
/* 032 */         -1L : (inputadapter_row_0.getLong(0));
/* 033 */
/* 034 */         boolean filter_isNull_2 = inputadapter_isNull_0;
/* 035 */         UTF8String filter_value_2 = null;
/* 036 */         if (!inputadapter_isNull_0) {
/* 037 */           filter_value_2 = UTF8String.fromString(String.valueOf(inputadapter_value_0));
/* 038 */         }
/* 039 */         int filter_value_1 = -1;
/* 040 */         filter_value_1 = (filter_value_2).numChars();
/* 041 */
/* 042 */         boolean filter_value_0 = false;
/* 043 */         filter_value_0 = filter_value_1 > 0;
/* 044 */         if (!filter_value_0) continue;
/* 045 */
/* 046 */         boolean filter_isNull_6 = inputadapter_isNull_0;
/* 047 */         UTF8String filter_value_6 = null;
/* 048 */         if (!inputadapter_isNull_0) {
/* 049 */           filter_value_6 = UTF8String.fromString(String.valueOf(inputadapter_value_0));
/* 050 */         }
/* 051 */         if (!(!filter_isNull_6)) continue;
/* 052 */
/* 053 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 054 */
/* 055 */         boolean project_isNull_0 = false;
/* 056 */         UTF8String project_value_0 = null;
/* 057 */         if (!false) {
/* 058 */           project_value_0 = UTF8String.fromString(String.valueOf(inputadapter_value_0));
/* 059 */         }
/* 060 */         filter_mutableStateArray_0[1].reset();
/* 061 */
/* 062 */         filter_mutableStateArray_0[1].zeroOutNullBytes();
/* 063 */
/* 064 */         if (project_isNull_0) {
/* 065 */           filter_mutableStateArray_0[1].setNullAt(0);
/* 066 */         } else {
/* 067 */           filter_mutableStateArray_0[1].write(0, project_value_0);
/* 068 */         }
/* 069 */         append((filter_mutableStateArray_0[1].getRow()));
/* 070 */
/* 071 */       } while(false);
/* 072 */       if (shouldStop()) return;
/* 073 */     }
/* 074 */   }
/* 075 */
/* 076 */ }
```
Note that `generated.java:40` in the stack trace is line `/* 040 */` above: `numChars()` is invoked on `filter_value_2` without any guard on `filter_isNull_2`, so the null row inserted into `table2` triggers the NPE.

This PR proposes to use semantic comparison for nullable attributes in both `FilterExec.output` and `FilterExec.doConsume`.
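For illustration, a minimal sketch of the idea (not the exact patch from this PR): inside `FilterExec`, reusing its existing `notNullPreds`, the `output` method would strip nullability only when an `IsNotNull` predicate's child is semantically equal to the attribute itself, which is the same condition `doConsume` uses when deciding whether to emit a null check:
```
// Sketch only: align FilterExec.output with the semantic check in doConsume.
// With this rule, isnotnull(cast(x as string)) no longer strips the
// nullability of x itself, so the null check for x is still generated.
override def output: Seq[Attribute] = {
  child.output.map { a =>
    val nonNullable = notNullPreds.exists {
      case IsNotNull(childExpr) => childExpr.semanticEquals(a)
      case _ => false
    }
    if (a.nullable && nonNullable) a.withNullability(false) else a
  }
}
```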
With this PR, the generated code is as follows:
```
== Subtree 2 / 5 ==
*(3) Project [substring(x#8, 5, 2147483647) AS x#5]
+- *(3) Filter ((length(substring(x#8, 5, 2147483647)) > 0) AND isnotnull(substring(x#8, 5, 2147483647)))
   +- Scan hive default.table3 [x#8], HiveTableRelation `default`.`table3`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [x#8]

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage3(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=3
/* 006 */ final class GeneratedIteratorForCodegenStage3 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private scala.collection.Iterator inputadapter_input_0;
/* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] filter_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2];
/* 011 */
/* 012 */   public GeneratedIteratorForCodegenStage3(Object[] references) {
/* 013 */     this.references = references;
/* 014 */   }
/* 015 */
/* 016 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 017 */     partitionIndex = index;
/* 018 */     this.inputs = inputs;
/* 019 */     inputadapter_input_0 = inputs[0];
/* 020 */     filter_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32);
/* 021 */     filter_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32);
/* 022 */
/* 023 */   }
/* 024 */
/* 025 */   protected void processNext() throws java.io.IOException {
/* 026 */     while ( inputadapter_input_0.hasNext()) {
/* 027 */       InternalRow inputadapter_row_0 = (InternalRow) inputadapter_input_0.next();
/* 028 */
/* 029 */       do {
/* 030 */         boolean inputadapter_isNull_0 = inputadapter_row_0.isNullAt(0);
/* 031 */         UTF8String inputadapter_value_0 = inputadapter_isNull_0 ?
/* 032 */         null : (inputadapter_row_0.getUTF8String(0));
/* 033 */
/* 034 */         boolean filter_isNull_0 = true;
/* 035 */         boolean filter_value_0 = false;
/* 036 */         boolean filter_isNull_2 = true;
/* 037 */         UTF8String filter_value_2 = null;
/* 038 */
/* 039 */         if (!inputadapter_isNull_0) {
/* 040 */           filter_isNull_2 = false; // resultCode could change nullability.
/* 041 */           filter_value_2 = inputadapter_value_0.substringSQL(5, 2147483647);
/* 042 */
/* 043 */         }
/* 044 */         boolean filter_isNull_1 = filter_isNull_2;
/* 045 */         int filter_value_1 = -1;
/* 046 */
/* 047 */         if (!filter_isNull_2) {
/* 048 */           filter_value_1 = (filter_value_2).numChars();
/* 049 */         }
/* 050 */         if (!filter_isNull_1) {
/* 051 */           filter_isNull_0 = false; // resultCode could change nullability.
/* 052 */           filter_value_0 = filter_value_1 > 0;
/* 053 */
/* 054 */         }
/* 055 */         if (filter_isNull_0 || !filter_value_0) continue;
/* 056 */         boolean filter_isNull_8 = true;
/* 057 */         UTF8String filter_value_8 = null;
/* 058 */
/* 059 */         if (!inputadapter_isNull_0) {
/* 060 */           filter_isNull_8 = false; // resultCode could change nullability.
/* 061 */           filter_value_8 = inputadapter_value_0.substringSQL(5, 2147483647);
/* 062 */
/* 063 */         }
/* 064 */         if (!(!filter_isNull_8)) continue;
/* 065 */
/* 066 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 067 */
/* 068 */         boolean project_isNull_0 = true;
/* 069 */         UTF8String project_value_0 = null;
/* 070 */
/* 071 */         if (!inputadapter_isNull_0) {
/* 072 */           project_isNull_0 = false; // resultCode could change nullability.
/* 073 */           project_value_0 = inputadapter_value_0.substringSQL(5, 2147483647);
/* 074 */
/* 075 */         }
/* 076 */         filter_mutableStateArray_0[1].reset();
/* 077 */
/* 078 */         filter_mutableStateArray_0[1].zeroOutNullBytes();
/* 079 */
/* 080 */         if (project_isNull_0) {
/* 081 */           filter_mutableStateArray_0[1].setNullAt(0);
/* 082 */         } else {
/* 083 */           filter_mutableStateArray_0[1].write(0, project_value_0);
/* 084 */         }
/* 085 */         append((filter_mutableStateArray_0[1].getRow()));
/* 086 */
/* 087 */       } while(false);
/* 088 */       if (shouldStop()) return;
/* 089 */     }
/* 090 */   }
/* 091 */
/* 092 */ }
```

### Why are the changes needed?

Fix an NPE bug in `FilterExec`.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

A new unit test (a sketch of such a test follows below).

Closes #25902 from wangshuo128/filter-codegen-npe.

Authored-by: Wang Shuo <wangshuo128@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
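For illustration only, here is a minimal sketch of what such a regression test could look like, reusing the repro query from above. The test name and the `withTable`/`checkAnswer` scaffolding (from Spark's `SQLTestUtils`/`QueryTest`) are assumptions, not the actual UT added by this PR:
```
// Hypothetical regression test: this query used to throw an NPE in the
// generated code; with the fix it should return an empty result, since
// table1 is empty.
test("FilterExec should not mistakenly mark nullable attributes as not nullable") {
  withTable("table1", "table2", "table3") {
    sql("create table table1(x string)")
    sql("create table table2(x bigint)")
    sql("create table table3(x string)")
    sql("insert into table2 select null as x")
    val df = sql(
      """
        |select t1.x
        |from (
        |  select x from table1) t1
        |left join (
        |  select x from (
        |    select x from table2
        |    union all
        |    select substr(x,5) x from table3
        |  ) a
        |  where length(x)>0
        |) t3
        |on t1.x=t3.x
      """.stripMargin)
    checkAnswer(df, Nil)
  }
}
```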