[SPARK-35052][SQL] Use static bits for AttributeReference and Literal

### What changes were proposed in this pull request?

- Share a single static `ImmutableBitSet` as the `treePatternBits` of all `AttributeReference` instances (a minimal sketch of the pattern follows this list).
- Share three static `ImmutableBitSet`s as the `treePatternBits` of the three kinds of `Literal`: null literals, true/false literals, and all other literals.
- Add `ImmutableBitSet`, a subclass of `BitSet` that rejects every mutating operation.
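The core idea, as a minimal self-contained sketch (toy names and a hypothetical pattern id; the real classes appear in the diff below):

```scala
import org.apache.spark.util.collection.{BitSet, ImmutableBitSet}

// Hypothetical stand-in for a leaf expression whose tree-pattern bits never vary:
// every instance returns one pre-built, immutable singleton instead of allocating
// and populating its own BitSet.
object FakeAttributeReferenceTreeBits {
  // 64 bits, with bit 1 standing in for ATTRIBUTE_REFERENCE.id.
  val bits: BitSet = new ImmutableBitSet(64, 1)
}

class FakeAttributeReference {
  // Shared by every instance; immutability is what makes the sharing safe.
  lazy val treePatternBits: BitSet = FakeAttributeReferenceTreeBits.bits
}
```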

### Why are the changes needed?

Reduce the additional memory usage caused by `treePatternBits`.
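As a rough, assumption-laden illustration rather than a measurement from this PR: Spark's `BitSet(numBits)` is backed by an `Array[Long]` of `((numBits - 1) >> 6) + 1` words, so every expression node that materializes its own `treePatternBits` pays for a `BitSet` object plus a small long array, i.e. a few dozen bytes each on a typical 64-bit JVM. Large analyzed plans can contain many thousands of `AttributeReference` and `Literal` nodes, so replacing those per-node allocations with a handful of shared immutable singletons removes that overhead for these leaf expressions.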

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes #32157 from sigmod/leaf.

Authored-by: Yingyi Bu <yingyi.bu@databricks.com>
Signed-off-by: Gengliang Wang <ltnwgl@gmail.com>
Yingyi Bu 2021-04-20 13:13:16 +08:00 committed by Gengliang Wang
parent bad4b6f025
commit f4926d1c8b
7 changed files with 245 additions and 4 deletions

@@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util.collection

private object ErrorMessage {
  final val msg: String = "mutable operation is not supported"
}

// An immutable BitSet that initializes set bits in its constructor.
class ImmutableBitSet(val numBits: Int, val bitsToSet: Int*) extends BitSet(numBits) {

  // Initialize the set bits.
  {
    val bitsIterator = bitsToSet.iterator
    while (bitsIterator.hasNext) {
      super.set(bitsIterator.next)
    }
  }

  override def clear(): Unit = {
    throw new UnsupportedOperationException(ErrorMessage.msg)
  }

  override def clearUntil(bitIndex: Int): Unit = {
    throw new UnsupportedOperationException(ErrorMessage.msg)
  }

  override def set(index: Int): Unit = {
    throw new UnsupportedOperationException(ErrorMessage.msg)
  }

  override def setUntil(bitIndex: Int): Unit = {
    throw new UnsupportedOperationException(ErrorMessage.msg)
  }

  override def unset(index: Int): Unit = {
    throw new UnsupportedOperationException(ErrorMessage.msg)
  }

  override def union(other: BitSet): Unit = {
    throw new UnsupportedOperationException(ErrorMessage.msg)
  }
}
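One way to see why every mutator above throws rather than being merely discouraged: once a single bit set is shared by every instance of an expression class, any in-place update would silently change the pattern bits of unrelated plan nodes. A minimal sketch of that hazard using the existing mutable `BitSet` (hypothetical toy, not code from this patch):

```scala
import org.apache.spark.util.collection.BitSet

object SharedMutableBitsHazard {
  // One *mutable* BitSet shared by every "node" (stand-in for treePatternBits).
  private val shared = new BitSet(64)
  shared.set(3)

  def main(args: Array[String]): Unit = {
    val nodeA = shared
    val nodeB = shared
    nodeA.set(7)             // caller thinks it is updating only nodeA...
    println(nodeB.get(7))    // ...but prints true: nodeB sees the change too
  }
}
```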

@@ -0,0 +1,140 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util.collection

import org.apache.spark.SparkFunSuite

class ImmutableBitSetSuite extends SparkFunSuite {

  test("basic get") {
    val bitset = new ImmutableBitSet(100, 0, 9, 1, 10, 90, 96)
    val setBits = Seq(0, 9, 1, 10, 90, 96)
    for (i <- 0 until 100) {
      if (setBits.contains(i)) {
        assert(bitset.get(i))
      } else {
        assert(!bitset.get(i))
      }
    }
    assert(bitset.cardinality() === setBits.size)
  }

  test("nextSetBit") {
    val bitset = new ImmutableBitSet(100, 0, 9, 1, 10, 90, 96)
    assert(bitset.nextSetBit(0) === 0)
    assert(bitset.nextSetBit(1) === 1)
    assert(bitset.nextSetBit(2) === 9)
    assert(bitset.nextSetBit(9) === 9)
    assert(bitset.nextSetBit(10) === 10)
    assert(bitset.nextSetBit(11) === 90)
    assert(bitset.nextSetBit(80) === 90)
    assert(bitset.nextSetBit(91) === 96)
    assert(bitset.nextSetBit(96) === 96)
    assert(bitset.nextSetBit(97) === -1)
  }

  test("xor len(bitsetX) < len(bitsetY)") {
    val bitsetX = new ImmutableBitSet(60, 0, 2, 3, 37, 41)
    val bitsetY = new ImmutableBitSet(100, 0, 1, 3, 37, 38, 41, 85)
    val bitsetXor = bitsetX ^ bitsetY
    assert(bitsetXor.nextSetBit(0) === 1)
    assert(bitsetXor.nextSetBit(1) === 1)
    assert(bitsetXor.nextSetBit(2) === 2)
    assert(bitsetXor.nextSetBit(3) === 38)
    assert(bitsetXor.nextSetBit(38) === 38)
    assert(bitsetXor.nextSetBit(39) === 85)
    assert(bitsetXor.nextSetBit(42) === 85)
    assert(bitsetXor.nextSetBit(85) === 85)
    assert(bitsetXor.nextSetBit(86) === -1)
  }

  test("xor len(bitsetX) > len(bitsetY)") {
    val bitsetX = new ImmutableBitSet(100, 0, 1, 3, 37, 38, 41, 85)
    val bitsetY = new ImmutableBitSet(60, 0, 2, 3, 37, 41)
    val bitsetXor = bitsetX ^ bitsetY
    assert(bitsetXor.nextSetBit(0) === 1)
    assert(bitsetXor.nextSetBit(1) === 1)
    assert(bitsetXor.nextSetBit(2) === 2)
    assert(bitsetXor.nextSetBit(3) === 38)
    assert(bitsetXor.nextSetBit(38) === 38)
    assert(bitsetXor.nextSetBit(39) === 85)
    assert(bitsetXor.nextSetBit(42) === 85)
    assert(bitsetXor.nextSetBit(85) === 85)
    assert(bitsetXor.nextSetBit(86) === -1)
  }

  test("andNot len(bitsetX) < len(bitsetY)") {
    val bitsetX = new ImmutableBitSet(60, 0, 2, 3, 37, 41, 48)
    val bitsetY = new ImmutableBitSet(100, 0, 1, 3, 37, 38, 41, 85)
    val bitsetDiff = bitsetX.andNot(bitsetY)
    assert(bitsetDiff.nextSetBit(0) === 2)
    assert(bitsetDiff.nextSetBit(1) === 2)
    assert(bitsetDiff.nextSetBit(2) === 2)
    assert(bitsetDiff.nextSetBit(3) === 48)
    assert(bitsetDiff.nextSetBit(48) === 48)
    assert(bitsetDiff.nextSetBit(49) === -1)
    assert(bitsetDiff.nextSetBit(65) === -1)
  }

  test("andNot len(bitsetX) > len(bitsetY)") {
    val bitsetX = new ImmutableBitSet(100, 0, 1, 3, 37, 38, 41, 85)
    val bitsetY = new ImmutableBitSet(60, 0, 2, 3, 37, 41, 48)
    val bitsetDiff = bitsetX.andNot(bitsetY)
    assert(bitsetDiff.nextSetBit(0) === 1)
    assert(bitsetDiff.nextSetBit(1) === 1)
    assert(bitsetDiff.nextSetBit(2) === 38)
    assert(bitsetDiff.nextSetBit(3) === 38)
    assert(bitsetDiff.nextSetBit(38) === 38)
    assert(bitsetDiff.nextSetBit(39) === 85)
    assert(bitsetDiff.nextSetBit(85) === 85)
    assert(bitsetDiff.nextSetBit(86) === -1)
  }

  test("immutability") {
    val bitset = new ImmutableBitSet(100)
    intercept[UnsupportedOperationException] {
      bitset.set(1)
    }
    intercept[UnsupportedOperationException] {
      bitset.setUntil(10)
    }
    intercept[UnsupportedOperationException] {
      bitset.unset(1)
    }
    intercept[UnsupportedOperationException] {
      bitset.clear()
    }
    intercept[UnsupportedOperationException] {
      bitset.clearUntil(10)
    }
    intercept[UnsupportedOperationException] {
      bitset.union(new ImmutableBitSet(100))
    }
  }
}

@@ -41,6 +41,8 @@ import org.json4s.JsonAST._
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.trees.TreePattern
import org.apache.spark.sql.catalyst.trees.TreePattern.{LITERAL, NULL_LITERAL, TRUE_OR_FALSE_LITERAL}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeUtils.instantToMicros
import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE
@@ -50,6 +52,8 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.BitSet
import org.apache.spark.util.collection.ImmutableBitSet

object Literal {
  val TrueLiteral: Literal = Literal(true, BooleanType)
@@ -296,6 +300,18 @@ object DecimalLiteral {
  def smallerThanSmallestLong(v: Decimal): Boolean = v < Decimal(Long.MinValue)
}

object LiteralTreeBits {
  // Singleton tree pattern BitSet for all Literals that are not true, false, or null.
  val literalBits: BitSet = new ImmutableBitSet(TreePattern.maxId, LITERAL.id)

  // Singleton tree pattern BitSet for all Literals that are true or false.
  val booleanLiteralBits: BitSet = new ImmutableBitSet(
    TreePattern.maxId, LITERAL.id, TRUE_OR_FALSE_LITERAL.id)

  // Singleton tree pattern BitSet for all Literals that are nulls.
  val nullLiteralBits: BitSet = new ImmutableBitSet(TreePattern.maxId, LITERAL.id, NULL_LITERAL.id)
}

/**
 * In order to do type checking, use Literal.create() instead of constructor
 */
@@ -308,6 +324,14 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
  private def timeZoneId = DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone)

  override lazy val treePatternBits: BitSet = {
    value match {
      case null => LiteralTreeBits.nullLiteralBits
      case true | false => LiteralTreeBits.booleanLiteralBits
      case _ => LiteralTreeBits.literalBits
    }
  }

  override def toString: String = value match {
    case null => "null"
    case binary: Array[Byte] => s"0x" + DatatypeConverter.printHexBinary(binary)
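A quick illustrative check of the override above (assumes `spark-catalyst` on the classpath; not a test from this patch): same-kind literals now report their patterns through one shared object rather than per-instance bit sets.

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.trees.TreePattern.{LITERAL, NULL_LITERAL}
import org.apache.spark.sql.types.{IntegerType, StringType}

val a = Literal(null, StringType)
val b = Literal(null, IntegerType)

// The pattern bits are reported as before...
assert(a.containsAnyPattern(LITERAL, NULL_LITERAL))

// ...but the two null literals now share one ImmutableBitSet instance
// (reference equality, not merely equal contents).
assert(a.treePatternBits eq b.treePatternBits)
```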

@@ -23,9 +23,13 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.catalyst.trees.TreePattern
import org.apache.spark.sql.catalyst.trees.TreePattern.ATTRIBUTE_REFERENCE
import org.apache.spark.sql.catalyst.util.quoteIfNeeded
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.BitSet
import org.apache.spark.util.collection.ImmutableBitSet

object NamedExpression {
  private val curId = new java.util.concurrent.atomic.AtomicLong()
@@ -231,6 +235,11 @@ case class Alias(child: Expression, name: String)(
    copy(child = newChild)(exprId, qualifier, explicitMetadata, nonInheritableMetadataKeys)
}

// Singleton tree pattern BitSet for all AttributeReference instances.
object AttributeReferenceTreeBits {
  val bits: BitSet = new ImmutableBitSet(TreePattern.maxId, ATTRIBUTE_REFERENCE.id)
}

/**
 * A reference to an attribute produced by another operator in the tree.
 *
@@ -253,6 +262,8 @@ case class AttributeReference(
    val qualifier: Seq[String] = Seq.empty[String])
  extends Attribute with Unevaluable {

  override lazy val treePatternBits: BitSet = AttributeReferenceTreeBits.bits

  /**
   * Returns true iff the expression id is the same for both attributes.
   */

@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter,
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.logical.{DeleteAction, DeleteFromTable, Filter, InsertAction, InsertStarAction, Join, LogicalPlan, MergeAction, MergeIntoTable, UpdateAction, UpdateStarAction, UpdateTable}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{NULL_LITERAL, TRUE_OR_FALSE_LITERAL}
import org.apache.spark.sql.types.BooleanType
import org.apache.spark.util.Utils
@@ -49,7 +50,8 @@ import org.apache.spark.util.Utils
 */
object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsAnyPattern(NULL_LITERAL, TRUE_OR_FALSE_LITERAL), ruleId) {
    case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond))
    case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(replaceNullWithFalse(cond)))
    case d @ DeleteFromTable(_, Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond)))
@@ -59,7 +61,8 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
        mergeCondition = replaceNullWithFalse(mergeCond),
        matchedActions = replaceNullWithFalse(matchedActions),
        notMatchedActions = replaceNullWithFalse(notMatchedActions))
-    case p: LogicalPlan => p transformExpressions {
+    case p: LogicalPlan => p.transformExpressionsWithPruning(
+      _.containsAnyPattern(NULL_LITERAL, TRUE_OR_FALSE_LITERAL), ruleId) {
      // For `EqualNullSafe` with a `TrueLiteral`, whether the other side is null or false has no
      // difference, as `null <=> true` and `false <=> true` both return false.
      case EqualNullSafe(left, TrueLiteral) =>
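For context on the rewrite above: `transformWithPruning` and `transformExpressionsWithPruning` only visit subtrees whose `treePatternBits` satisfy the supplied predicate, so this rule is now skipped outright for plans that contain no null or boolean literals, and the registered `ruleId` additionally lets the framework skip subtrees on which the rule has already been applied without effect. A simplified stand-in for the `containsAnyPattern` check (hypothetical helper, not the actual TreeNode code):

```scala
import org.apache.spark.util.collection.BitSet

object PruningSketch {
  // A node "matches" if any requested pattern id is set in its (possibly shared) bit set.
  def containsAnyPattern(bits: BitSet, patternIds: Int*): Boolean =
    patternIds.exists(bits.get)
}
```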

@@ -51,7 +51,8 @@ object RuleIdCollection {
      "org.apache.spark.sql.catalyst.optimizer.PushDownLeftSemiAntiJoin" ::
      "org.apache.spark.sql.catalyst.optimizer.PushExtraPredicateThroughJoin" ::
      "org.apache.spark.sql.catalyst.optimizer.PushLeftSemiLeftAntiThroughJoin" ::
-      "org.apache.spark.sql.catalyst.optimizer.ReorderJoin" :: Nil
+      "org.apache.spark.sql.catalyst.optimizer.ReorderJoin" ::
+      "org.apache.spark.sql.catalyst.optimizer.ReplaceNullWithFalseInPredicate" :: Nil
  }

  // Maps rule names to ids. Rule ids are continuous natural numbers starting from 0.

@@ -23,8 +23,12 @@ object TreePattern extends Enumeration {
  // Enum Ids start from 0.
  // Expression patterns (alphabetically ordered)
-  val EXPRESSION_WITH_RANDOM_SEED = Value(0)
+  val ATTRIBUTE_REFERENCE = Value(0)
+  val EXPRESSION_WITH_RANDOM_SEED = Value
  val IN: Value = Value
+  val LITERAL: Value = Value
+  val NULL_LITERAL: Value = Value
+  val TRUE_OR_FALSE_LITERAL: Value = Value
  val WINDOW_EXPRESSION: Value = Value

  // Logical plan patterns (alphabetically ordered)