[SPARK-35052][SQL] Use static bits for AttributeReference and Literal
### What changes were proposed in this pull request?

- Share a static ImmutableBitSet for `treePatternBits` in all object instances of AttributeReference.
- Share three static ImmutableBitSets for `treePatternBits` in three kinds of Literals.
- Add an ImmutableBitSet as a subclass of BitSet.

### Why are the changes needed?

Reduce the additional memory usage caused by `treePatternBits`.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes #32157 from sigmod/leaf.

Authored-by: Yingyi Bu <yingyi.bu@databricks.com>
Signed-off-by: Gengliang Wang <ltnwgl@gmail.com>
This commit is contained in:
parent
bad4b6f025
commit
f4926d1c8b
|
@ -0,0 +1,58 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.util.collection
|
||||
|
||||
/**
 * Holds the message text used by the exceptions that [[ImmutableBitSet]] throws from its
 * overridden mutating operations.
 */
private object ErrorMessage {
  final val msg: String = "mutable operation is not supported"
}
|
||||
|
||||
/**
 * A [[BitSet]] whose contents are fixed while the instance is being constructed.
 *
 * Each index listed in `bitsToSet` is switched on during construction. After that, the
 * mutating operations overridden here (`clear`, `clearUntil`, `set`, `setUntil`, `unset` and
 * `union`) always fail with an `UnsupportedOperationException`.
 *
 * @param numBits   size handed to the parent [[BitSet]] constructor
 * @param bitsToSet indices of the bits that are set in this instance
 */
class ImmutableBitSet(val numBits: Int, val bitsToSet: Int*) extends BitSet(numBits) {

  // Go through the parent's `set`: the override declared in this class always throws.
  bitsToSet.foreach(bit => super.set(bit))

  /** Always throws: this bit set cannot be modified. */
  override def clear(): Unit =
    throw new UnsupportedOperationException(ErrorMessage.msg)

  /** Always throws: this bit set cannot be modified. */
  override def clearUntil(bitIndex: Int): Unit =
    throw new UnsupportedOperationException(ErrorMessage.msg)

  /** Always throws: this bit set cannot be modified. */
  override def set(index: Int): Unit =
    throw new UnsupportedOperationException(ErrorMessage.msg)

  /** Always throws: this bit set cannot be modified. */
  override def setUntil(bitIndex: Int): Unit =
    throw new UnsupportedOperationException(ErrorMessage.msg)

  /** Always throws: this bit set cannot be modified. */
  override def unset(index: Int): Unit =
    throw new UnsupportedOperationException(ErrorMessage.msg)

  /** Always throws: this bit set cannot be modified. */
  override def union(other: BitSet): Unit =
    throw new UnsupportedOperationException(ErrorMessage.msg)
}
|
|
@ -0,0 +1,140 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.util.collection
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
|
||||
/**
 * Tests for [[ImmutableBitSet]]: bits supplied to the constructor are readable, `^` and
 * `andNot` produce the expected results, and every mutating operation is rejected.
 */
class ImmutableBitSetSuite extends SparkFunSuite {

  /**
   * Asserts, for every `(fromIndex, expected)` pair in order, that
   * `bits.nextSetBit(fromIndex)` equals `expected`.
   */
  private def assertNextSetBits(bits: BitSet, expectations: (Int, Int)*): Unit = {
    expectations.foreach { case (fromIndex, expected) =>
      assert(bits.nextSetBit(fromIndex) === expected)
    }
  }

  test("basic get") {
    val expectedBits = Seq(0, 9, 1, 10, 90, 96)
    val bitset = new ImmutableBitSet(100, 0, 9, 1, 10, 90, 96)

    // Every index must report "set" exactly when it was passed to the constructor.
    (0 until 100).foreach { index =>
      assert(bitset.get(index) === expectedBits.contains(index))
    }
    assert(bitset.cardinality() === expectedBits.size)
  }

  test("nextSetBit") {
    val bitset = new ImmutableBitSet(100, 0, 9, 1, 10, 90, 96)

    assertNextSetBits(
      bitset,
      0 -> 0, 1 -> 1, 2 -> 9, 9 -> 9, 10 -> 10, 11 -> 90, 80 -> 90, 91 -> 96, 96 -> 96, 97 -> -1)
  }

  test("xor len(bitsetX) < len(bitsetY)") {
    val shorter = new ImmutableBitSet(60, 0, 2, 3, 37, 41)
    val longer = new ImmutableBitSet(100, 0, 1, 3, 37, 38, 41, 85)

    assertNextSetBits(
      shorter ^ longer,
      0 -> 1, 1 -> 1, 2 -> 2, 3 -> 38, 38 -> 38, 39 -> 85, 42 -> 85, 85 -> 85, 86 -> -1)
  }

  test("xor len(bitsetX) > len(bitsetY)") {
    val longer = new ImmutableBitSet(100, 0, 1, 3, 37, 38, 41, 85)
    val shorter = new ImmutableBitSet(60, 0, 2, 3, 37, 41)

    assertNextSetBits(
      longer ^ shorter,
      0 -> 1, 1 -> 1, 2 -> 2, 3 -> 38, 38 -> 38, 39 -> 85, 42 -> 85, 85 -> 85, 86 -> -1)
  }

  test("andNot len(bitsetX) < len(bitsetY)") {
    val shorter = new ImmutableBitSet(60, 0, 2, 3, 37, 41, 48)
    val longer = new ImmutableBitSet(100, 0, 1, 3, 37, 38, 41, 85)

    assertNextSetBits(
      shorter.andNot(longer),
      0 -> 2, 1 -> 2, 2 -> 2, 3 -> 48, 48 -> 48, 49 -> -1, 65 -> -1)
  }

  test("andNot len(bitsetX) > len(bitsetY)") {
    val longer = new ImmutableBitSet(100, 0, 1, 3, 37, 38, 41, 85)
    val shorter = new ImmutableBitSet(60, 0, 2, 3, 37, 41, 48)

    assertNextSetBits(
      longer.andNot(shorter),
      0 -> 1, 1 -> 1, 2 -> 38, 3 -> 38, 38 -> 38, 39 -> 85, 85 -> 85, 86 -> -1)
  }

  test("immutability") {
    val bitset = new ImmutableBitSet(100)

    // Each mutating operation, in turn, must be refused.
    val mutations = Seq[ImmutableBitSet => Unit](
      _.set(1),
      _.setUntil(10),
      _.unset(1),
      _.clear(),
      _.clearUntil(10),
      _.union(new ImmutableBitSet(100)))

    mutations.foreach { mutate =>
      intercept[UnsupportedOperationException] {
        mutate(bitset)
      }
    }
  }
}
|
|
@ -41,6 +41,8 @@ import org.json4s.JsonAST._
|
|||
|
||||
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.trees.TreePattern
|
||||
import org.apache.spark.sql.catalyst.trees.TreePattern.{LITERAL, NULL_LITERAL, TRUE_OR_FALSE_LITERAL}
|
||||
import org.apache.spark.sql.catalyst.util._
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils.instantToMicros
|
||||
import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE
|
||||
|
@ -50,6 +52,8 @@ import org.apache.spark.sql.internal.SQLConf
|
|||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types._
|
||||
import org.apache.spark.util.Utils
|
||||
import org.apache.spark.util.collection.BitSet
|
||||
import org.apache.spark.util.collection.ImmutableBitSet
|
||||
|
||||
object Literal {
|
||||
val TrueLiteral: Literal = Literal(true, BooleanType)
|
||||
|
@ -296,6 +300,18 @@ object DecimalLiteral {
|
|||
// Returns true when `v` is strictly less than `Decimal(Long.MinValue)`.
def smallerThanSmallestLong(v: Decimal): Boolean = v < Decimal(Long.MinValue)
|
||||
}
|
||||
|
||||
/**
 * Singleton tree pattern bit sets for [[Literal]]. Each one is an [[ImmutableBitSet]] sized
 * by `TreePattern.maxId`.
 */
object LiteralTreeBits {
  // Size used for every bit set below.
  private val patternCount: Int = TreePattern.maxId

  /** For all Literals that are not true, false, or null: only `LITERAL` is set. */
  val literalBits: BitSet = new ImmutableBitSet(patternCount, LITERAL.id)

  /** For all Literals that are true or false: `LITERAL` and `TRUE_OR_FALSE_LITERAL` are set. */
  val booleanLiteralBits: BitSet =
    new ImmutableBitSet(patternCount, LITERAL.id, TRUE_OR_FALSE_LITERAL.id)

  /** For all Literals that are nulls: `LITERAL` and `NULL_LITERAL` are set. */
  val nullLiteralBits: BitSet =
    new ImmutableBitSet(patternCount, LITERAL.id, NULL_LITERAL.id)
}
|
||||
|
||||
/**
|
||||
* In order to do type checking, use Literal.create() instead of constructor
|
||||
*/
|
||||
|
@ -308,6 +324,14 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
|
|||
|
||||
// Zone id obtained by passing the session-local time zone from the current SQLConf to
// DateTimeUtils.getZoneId. This is a `def`, so the conf is read on every call.
private def timeZoneId = DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone)
|
||||
|
||||
/**
 * Picks one of the bit sets held by `LiteralTreeBits` based on `value`: the null set for a
 * null value, the boolean set for true or false, and the plain literal set otherwise.
 */
override lazy val treePatternBits: BitSet = value match {
  case null => LiteralTreeBits.nullLiteralBits
  case _: Boolean => LiteralTreeBits.booleanLiteralBits
  case _ => LiteralTreeBits.literalBits
}
|
||||
|
||||
override def toString: String = value match {
|
||||
case null => "null"
|
||||
case binary: Array[Byte] => s"0x" + DatatypeConverter.printHexBinary(binary)
|
||||
|
|
|
@ -23,9 +23,13 @@ import org.apache.spark.sql.catalyst.InternalRow
|
|||
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
|
||||
import org.apache.spark.sql.catalyst.trees.TreePattern
|
||||
import org.apache.spark.sql.catalyst.trees.TreePattern.ATTRIBUTE_REFERENCE
|
||||
import org.apache.spark.sql.catalyst.util.quoteIfNeeded
|
||||
import org.apache.spark.sql.errors.QueryExecutionErrors
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.util.collection.BitSet
|
||||
import org.apache.spark.util.collection.ImmutableBitSet
|
||||
|
||||
object NamedExpression {
|
||||
private val curId = new java.util.concurrent.atomic.AtomicLong()
|
||||
|
@ -231,6 +235,11 @@ case class Alias(child: Expression, name: String)(
|
|||
copy(child = newChild)(exprId, qualifier, explicitMetadata, nonInheritableMetadataKeys)
|
||||
}
|
||||
|
||||
/**
 * Holder of the singleton tree pattern bit set for all `AttributeReference` instances.
 */
object AttributeReferenceTreeBits {
  /** An [[ImmutableBitSet]] sized by `TreePattern.maxId` with only `ATTRIBUTE_REFERENCE` set. */
  val bits: BitSet =
    new ImmutableBitSet(TreePattern.maxId, ATTRIBUTE_REFERENCE.id)
}
|
||||
|
||||
/**
|
||||
* A reference to an attribute produced by another operator in the tree.
|
||||
*
|
||||
|
@ -253,6 +262,8 @@ case class AttributeReference(
|
|||
val qualifier: Seq[String] = Seq.empty[String])
|
||||
extends Attribute with Unevaluable {
|
||||
|
||||
// Every instance points at the one bit set held by AttributeReferenceTreeBits.
override lazy val treePatternBits: BitSet = AttributeReferenceTreeBits.bits
|
||||
|
||||
/**
|
||||
* Returns true iff the expression id is the same for both attributes.
|
||||
*/
|
||||
|
|
|
@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter,
|
|||
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{DeleteAction, DeleteFromTable, Filter, InsertAction, InsertStarAction, Join, LogicalPlan, MergeAction, MergeIntoTable, UpdateAction, UpdateStarAction, UpdateTable}
|
||||
import org.apache.spark.sql.catalyst.rules.Rule
|
||||
import org.apache.spark.sql.catalyst.trees.TreePattern.{NULL_LITERAL, TRUE_OR_FALSE_LITERAL}
|
||||
import org.apache.spark.sql.types.BooleanType
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
|
@ -49,7 +50,8 @@ import org.apache.spark.util.Utils
|
|||
*/
|
||||
object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
|
||||
_.containsAnyPattern(NULL_LITERAL, TRUE_OR_FALSE_LITERAL), ruleId) {
|
||||
case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond))
|
||||
case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(replaceNullWithFalse(cond)))
|
||||
case d @ DeleteFromTable(_, Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond)))
|
||||
|
@ -59,7 +61,8 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
|
|||
mergeCondition = replaceNullWithFalse(mergeCond),
|
||||
matchedActions = replaceNullWithFalse(matchedActions),
|
||||
notMatchedActions = replaceNullWithFalse(notMatchedActions))
|
||||
case p: LogicalPlan => p transformExpressions {
|
||||
case p: LogicalPlan => p.transformExpressionsWithPruning(
|
||||
_.containsAnyPattern(NULL_LITERAL, TRUE_OR_FALSE_LITERAL), ruleId) {
|
||||
// For `EqualNullSafe` with a `TrueLiteral`, whether the other side is null or false has no
|
||||
// difference, as `null <=> true` and `false <=> true` both return false.
|
||||
case EqualNullSafe(left, TrueLiteral) =>
|
||||
|
|
|
@ -51,7 +51,8 @@ object RuleIdCollection {
|
|||
"org.apache.spark.sql.catalyst.optimizer.PushDownLeftSemiAntiJoin" ::
|
||||
"org.apache.spark.sql.catalyst.optimizer.PushExtraPredicateThroughJoin" ::
|
||||
"org.apache.spark.sql.catalyst.optimizer.PushLeftSemiLeftAntiThroughJoin" ::
|
||||
"org.apache.spark.sql.catalyst.optimizer.ReorderJoin" :: Nil
|
||||
"org.apache.spark.sql.catalyst.optimizer.ReorderJoin" ::
|
||||
"org.apache.spark.sql.catalyst.optimizer.ReplaceNullWithFalseInPredicate" :: Nil
|
||||
}
|
||||
|
||||
// Maps rule names to ids. Rule ids are continuous natural numbers starting from 0.
|
||||
|
|
|
@ -23,8 +23,12 @@ object TreePattern extends Enumeration {
|
|||
|
||||
// Enum Ids start from 0.
|
||||
// Expression patterns (alphabetically ordered)
|
||||
val EXPRESSION_WITH_RANDOM_SEED = Value(0)
|
||||
val ATTRIBUTE_REFERENCE = Value(0)
|
||||
val EXPRESSION_WITH_RANDOM_SEED = Value
|
||||
val IN: Value = Value
|
||||
val LITERAL: Value = Value
|
||||
val NULL_LITERAL: Value = Value
|
||||
val TRUE_OR_FALSE_LITERAL: Value = Value
|
||||
val WINDOW_EXPRESSION: Value = Value
|
||||
|
||||
// Logical plan patterns (alphabetically ordered)
|
||||
|
|
Loading…
Reference in a new issue