diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index c3a2675ee5..09864e3f83 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -36,9 +36,14 @@ import org.apache.spark.util.collection.OpenHashSet * When a class extends it, [[SizeEstimator]] will query the `estimatedSize` first. * If `estimatedSize` does not return [[None]], [[SizeEstimator]] will use the returned size * as the size of the object. Otherwise, [[SizeEstimator]] will do the estimation work. + * The difference between a [[KnownSizeEstimation]] and + * [[org.apache.spark.util.collection.SizeTracker]] is that, a + * [[org.apache.spark.util.collection.SizeTracker]] still uses [[SizeEstimator]] to + * estimate the size. However, a [[KnownSizeEstimation]] can provide a better estimation without + * using [[SizeEstimator]]. */ -private[spark] trait SizeEstimation { - def estimatedSize: Option[Long] +private[spark] trait KnownSizeEstimation { + def estimatedSize: Long } /** @@ -209,18 +214,15 @@ object SizeEstimator extends Logging { // the size estimator since it references the whole REPL. Do nothing in this case. In // general all ClassLoaders and Classes will be shared between objects anyway. } else { - val estimatedSize = obj match { - case s: SizeEstimation => s.estimatedSize - case _ => None - } - if (estimatedSize.isDefined) { - state.size += estimatedSize.get - } else { - val classInfo = getClassInfo(cls) - state.size += alignSize(classInfo.shellSize) - for (field <- classInfo.pointerFields) { - state.enqueue(field.get(obj)) - } + obj match { + case s: KnownSizeEstimation => + state.size += s.estimatedSize + case _ => + val classInfo = getClassInfo(cls) + state.size += alignSize(classInfo.shellSize) + for (field <- classInfo.pointerFields) { + state.enqueue(field.get(obj)) + } } } } diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 9b6261af12..101610e380 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -60,16 +60,10 @@ class DummyString(val arr: Array[Char]) { @transient val hash32: Int = 0 } -class DummyClass8 extends SizeEstimation { +class DummyClass8 extends KnownSizeEstimation { val x: Int = 0 - override def estimatedSize: Option[Long] = Some(2015) -} - -class DummyClass9 extends SizeEstimation { - val x: Int = 0 - - override def estimatedSize: Option[Long] = None + override def estimatedSize: Long = 2015 } class SizeEstimatorSuite @@ -231,9 +225,5 @@ class SizeEstimatorSuite // DummyClass8 provides its size estimation. assertResult(2015)(SizeEstimator.estimate(new DummyClass8)) assertResult(20206)(SizeEstimator.estimate(Array.fill(10)(new DummyClass8))) - - // DummyClass9 does not provide its size estimation. - assertResult(16)(SizeEstimator.estimate(new DummyClass9)) - assertResult(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass9))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 49ae09bf53..aebfea5832 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.memory.MemoryLocation -import org.apache.spark.util.{SizeEstimation, Utils} +import org.apache.spark.util.{SizeEstimator, KnownSizeEstimation, Utils} import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.{SparkConf, SparkEnv} @@ -190,7 +190,7 @@ private[execution] object HashedRelation { private[joins] final class UnsafeHashedRelation( private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) extends HashedRelation - with SizeEstimation + with KnownSizeEstimation with Externalizable { private[joins] def this() = this(null) // Needed for serialization @@ -217,8 +217,12 @@ private[joins] final class UnsafeHashedRelation( } } - override def estimatedSize: Option[Long] = { - Option(binaryMap).map(_.getTotalMemoryConsumption) + override def estimatedSize: Long = { + if (binaryMap != null) { + binaryMap.getTotalMemoryConsumption + } else { + SizeEstimator.estimate(hashTable) + } } override def get(key: InternalRow): Seq[InternalRow] = {