[SPARK-5287][SQL] Add defaultSizeOf to every data type.

JIRA: https://issues.apache.org/jira/browse/SPARK-5287

This PR only adds `defaultSize` to the data types and makes those internal type classes `protected[sql]`. I will clean up the type hierarchy of the data types in a follow-up PR.
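For context, a minimal sketch of the API this change introduces (the two-type hierarchy below is illustrative, not the full set of Spark SQL data types; see the diff for the real definitions):

    // Every DataType now answers defaultSize directly, so size estimation
    // no longer needs the NativeType.defaultSizeOf lookup map.
    abstract class DataType {
      /** The default size, in bytes, of a value of this data type. */
      def defaultSize: Int
    }

    case object IntegerType extends DataType { override def defaultSize: Int = 4 }
    case object StringType extends DataType { override def defaultSize: Int = 4096 }

    // The estimated row width is the sum over the columns of a schema:
    val rowWidth = Seq(IntegerType, StringType).map(_.defaultSize).sum // 4 + 4096 = 4100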

Author: Yin Huai <yhuai@databricks.com>

Closes #4081 from yhuai/SPARK-5287 and squashes the following commits:

90cec75 [Yin Huai] Update unit test.
e1c600c [Yin Huai] Make internal classes protected[sql].
7eaba68 [Yin Huai] Add `defaultSize` method to data types.
fd425e0 [Yin Huai] Add all native types to NativeType.defaultSizeOf.
Yin Huai authored on 2015-01-20 13:26:36 -08:00; committed by Reynold Xin
parent 23e25543be
commit bc20a52b34
5 changed files with 199 additions and 46 deletions


@@ -24,7 +24,13 @@ import org.apache.spark.sql.types.DataType
 /**
  * The data type representing [[DynamicRow]] values.
  */
-case object DynamicType extends DataType
+case object DynamicType extends DataType {
+
+  /**
+   * The default size of a value of the DynamicType is 4096 bytes.
+   */
+  override def defaultSize: Int = 4096
+}

 /**
  * Wrap a [[Row]] as a [[DynamicRow]].


@@ -238,16 +238,11 @@ case class Rollup(
 case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
   override def output = child.output

-  override lazy val statistics: Statistics =
-    if (output.forall(_.dataType.isInstanceOf[NativeType])) {
-      val limit = limitExpr.eval(null).asInstanceOf[Int]
-      val sizeInBytes = (limit: Long) * output.map { a =>
-        NativeType.defaultSizeOf(a.dataType.asInstanceOf[NativeType])
-      }.sum
-      Statistics(sizeInBytes = sizeInBytes)
-    } else {
-      Statistics(sizeInBytes = children.map(_.statistics).map(_.sizeInBytes).product)
-    }
+  override lazy val statistics: Statistics = {
+    val limit = limitExpr.eval(null).asInstanceOf[Int]
+    val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
+    Statistics(sizeInBytes = sizeInBytes)
+  }
 }

 case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
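With defaultSize defined on every DataType, the Limit estimate no longer needs the NativeType-only fast path: it is simply limit × (sum of the output columns' default sizes) for any schema. A quick sketch of the arithmetic, using a hypothetical (Int, String) schema:

    // LIMIT 10 over columns (IntegerType, StringType):
    val limit = 10
    val columnDefaultSizes = Seq(4, 4096) // defaultSize of IntegerType and StringType
    val sizeInBytes = limit.toLong * columnDefaultSizes.sum // 10 * 4100 = 41000 bytes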


@@ -215,6 +215,9 @@ abstract class DataType {
     case _ => false
   }

+  /** The default size of a value of this data type. */
+  def defaultSize: Int
+
   def isPrimitive: Boolean = false

   def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase
@@ -235,33 +238,25 @@ abstract class DataType {
  * @group dataType
  */
 @DeveloperApi
-case object NullType extends DataType
+case object NullType extends DataType {
+  override def defaultSize: Int = 1
+}

-object NativeType {
+protected[sql] object NativeType {
   val all = Seq(
     IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)

   def unapply(dt: DataType): Boolean = all.contains(dt)
-
-  val defaultSizeOf: Map[NativeType, Int] = Map(
-    IntegerType -> 4,
-    BooleanType -> 1,
-    LongType -> 8,
-    DoubleType -> 8,
-    FloatType -> 4,
-    ShortType -> 2,
-    ByteType -> 1,
-    StringType -> 4096)
 }

-trait PrimitiveType extends DataType {
+protected[sql] trait PrimitiveType extends DataType {
   override def isPrimitive = true
 }

-object PrimitiveType {
+protected[sql] object PrimitiveType {
   private val nonDecimals = Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all
   private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap
@@ -276,7 +271,7 @@ object PrimitiveType {
   }
 }

-abstract class NativeType extends DataType {
+protected[sql] abstract class NativeType extends DataType {
   private[sql] type JvmType
   @transient private[sql] val tag: TypeTag[JvmType]
   private[sql] val ordering: Ordering[JvmType]
@@ -300,6 +295,11 @@ case object StringType extends NativeType with PrimitiveType {
   private[sql] type JvmType = String
   @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
   private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+  /**
+   * The default size of a value of the StringType is 4096 bytes.
+   */
+  override def defaultSize: Int = 4096
 }
@@ -324,6 +324,11 @@ case object BinaryType extends NativeType with PrimitiveType {
       x.length - y.length
     }
   }
+
+  /**
+   * The default size of a value of the BinaryType is 4096 bytes.
+   */
+  override def defaultSize: Int = 4096
 }
@@ -339,6 +344,11 @@ case object BooleanType extends NativeType with PrimitiveType {
   private[sql] type JvmType = Boolean
   @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
   private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+  /**
+   * The default size of a value of the BooleanType is 1 byte.
+   */
+  override def defaultSize: Int = 1
 }
@@ -359,6 +369,11 @@ case object TimestampType extends NativeType {
   private[sql] val ordering = new Ordering[JvmType] {
     def compare(x: Timestamp, y: Timestamp) = x.compareTo(y)
   }
+
+  /**
+   * The default size of a value of the TimestampType is 8 bytes.
+   */
+  override def defaultSize: Int = 8
 }
@@ -379,10 +394,15 @@ case object DateType extends NativeType {
   private[sql] val ordering = new Ordering[JvmType] {
     def compare(x: Date, y: Date) = x.compareTo(y)
   }
+
+  /**
+   * The default size of a value of the DateType is 8 bytes.
+   */
+  override def defaultSize: Int = 8
 }

-abstract class NumericType extends NativeType with PrimitiveType {
+protected[sql] abstract class NumericType extends NativeType with PrimitiveType {
   // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for
   // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a
   // type parameter and add a numeric annotation (i.e., [JvmType : Numeric]). This gets
@@ -392,13 +412,13 @@ abstract class NumericType extends NativeType with PrimitiveType {
 }

-object NumericType {
+protected[sql] object NumericType {
   def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType]
 }

 /** Matcher for any expressions that evaluate to [[IntegralType]]s */
-object IntegralType {
+protected[sql] object IntegralType {
   def unapply(a: Expression): Boolean = a match {
     case e: Expression if e.dataType.isInstanceOf[IntegralType] => true
     case _ => false
@@ -406,7 +426,7 @@ object IntegralType {
   }
 }

-sealed abstract class IntegralType extends NumericType {
+protected[sql] sealed abstract class IntegralType extends NumericType {
   private[sql] val integral: Integral[JvmType]
 }
@@ -425,6 +445,11 @@ case object LongType extends IntegralType {
   private[sql] val numeric = implicitly[Numeric[Long]]
   private[sql] val integral = implicitly[Integral[Long]]
   private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+  /**
+   * The default size of a value of the LongType is 8 bytes.
+   */
+  override def defaultSize: Int = 8
 }
@@ -442,6 +467,11 @@ case object IntegerType extends IntegralType {
   private[sql] val numeric = implicitly[Numeric[Int]]
   private[sql] val integral = implicitly[Integral[Int]]
   private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+  /**
+   * The default size of a value of the IntegerType is 4 bytes.
+   */
+  override def defaultSize: Int = 4
 }
@@ -459,6 +489,11 @@ case object ShortType extends IntegralType {
   private[sql] val numeric = implicitly[Numeric[Short]]
   private[sql] val integral = implicitly[Integral[Short]]
   private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+  /**
+   * The default size of a value of the ShortType is 2 bytes.
+   */
+  override def defaultSize: Int = 2
 }
@@ -476,11 +511,16 @@ case object ByteType extends IntegralType {
   private[sql] val numeric = implicitly[Numeric[Byte]]
   private[sql] val integral = implicitly[Integral[Byte]]
   private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+  /**
+   * The default size of a value of the ByteType is 1 byte.
+   */
+  override def defaultSize: Int = 1
 }

 /** Matcher for any expressions that evaluate to [[FractionalType]]s */
-object FractionalType {
+protected[sql] object FractionalType {
   def unapply(a: Expression): Boolean = a match {
     case e: Expression if e.dataType.isInstanceOf[FractionalType] => true
     case _ => false
@@ -488,7 +528,7 @@ object FractionalType {
   }
 }

-sealed abstract class FractionalType extends NumericType {
+protected[sql] sealed abstract class FractionalType extends NumericType {
   private[sql] val fractional: Fractional[JvmType]
   private[sql] val asIntegral: Integral[JvmType]
 }
@@ -530,6 +570,11 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType {
     case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)"
     case None => "DecimalType()"
   }
+
+  /**
+   * The default size of a value of the DecimalType is 4096 bytes.
+   */
+  override def defaultSize: Int = 4096
 }
@@ -580,6 +625,11 @@ case object DoubleType extends FractionalType {
   private[sql] val fractional = implicitly[Fractional[Double]]
   private[sql] val ordering = implicitly[Ordering[JvmType]]
   private[sql] val asIntegral = DoubleAsIfIntegral
+
+  /**
+   * The default size of a value of the DoubleType is 8 bytes.
+   */
+  override def defaultSize: Int = 8
 }
@@ -598,6 +648,11 @@ case object FloatType extends FractionalType {
   private[sql] val fractional = implicitly[Fractional[Float]]
   private[sql] val ordering = implicitly[Ordering[JvmType]]
   private[sql] val asIntegral = FloatAsIfIntegral
+
+  /**
+   * The default size of a value of the FloatType is 4 bytes.
+   */
+  override def defaultSize: Int = 4
 }
@@ -636,6 +691,12 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType {
     ("type" -> typeName) ~
       ("elementType" -> elementType.jsonValue) ~
       ("containsNull" -> containsNull)
+
+  /**
+   * The default size of a value of the ArrayType is 100 * the default size of the element type.
+   * (We assume that there are 100 elements).
+   */
+  override def defaultSize: Int = 100 * elementType.defaultSize
 }
@@ -805,6 +866,11 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] {
   override def length: Int = fields.length

   override def iterator: Iterator[StructField] = fields.iterator
+
+  /**
+   * The default size of a value of the StructType is the total default sizes of all field types.
+   */
+  override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum
 }
@@ -848,6 +914,13 @@ case class MapType(
     ("keyType" -> keyType.jsonValue) ~
     ("valueType" -> valueType.jsonValue) ~
     ("valueContainsNull" -> valueContainsNull)
+
+  /**
+   * The default size of a value of the MapType is
+   * 100 * (the default size of the key type + the default size of the value type).
+   * (We assume that there are 100 elements).
+   */
+  override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize)
 }
@@ -896,4 +969,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
    * Class object for the UserType
    */
   def userClass: java.lang.Class[UserType]
+
+  /**
+   * The default size of a value of the UserDefinedType is 4096 bytes.
+   */
+  override def defaultSize: Int = 4096
 }
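Under the stated 100-element assumption, the composite defaults compose recursively. A quick check of the numbers the test suite below expects (my arithmetic, just spelling out the formulas above):

    val arraySize  = 100 * 8               // ArrayType(DoubleType)                 = 800
    val mapSize    = 100 * (4 + arraySize) // MapType(IntegerType, Array[Double])   = 80400
    val structSize = 4 + arraySize + 8     // StructType(Int, Array[Double], Double) = 812
    assert(arraySize == 800 && mapSize == 80400 && structSize == 812)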


@@ -62,6 +62,7 @@ class DataTypeSuite extends FunSuite {
     }
   }

+  checkDataTypeJsonRepr(NullType)
   checkDataTypeJsonRepr(BooleanType)
   checkDataTypeJsonRepr(ByteType)
   checkDataTypeJsonRepr(ShortType)
@@ -69,7 +70,9 @@ class DataTypeSuite extends FunSuite {
   checkDataTypeJsonRepr(LongType)
   checkDataTypeJsonRepr(FloatType)
   checkDataTypeJsonRepr(DoubleType)
+  checkDataTypeJsonRepr(DecimalType(10, 5))
   checkDataTypeJsonRepr(DecimalType.Unlimited)
+  checkDataTypeJsonRepr(DateType)
   checkDataTypeJsonRepr(TimestampType)
   checkDataTypeJsonRepr(StringType)
   checkDataTypeJsonRepr(BinaryType)
@@ -77,12 +80,39 @@ class DataTypeSuite extends FunSuite {
   checkDataTypeJsonRepr(ArrayType(StringType, false))
   checkDataTypeJsonRepr(MapType(IntegerType, StringType, true))
   checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false))

   val metadata = new MetadataBuilder()
     .putString("name", "age")
     .build()
-  checkDataTypeJsonRepr(
-    StructType(Seq(
-      StructField("a", IntegerType, nullable = true),
-      StructField("b", ArrayType(DoubleType), nullable = false),
-      StructField("c", DoubleType, nullable = false, metadata))))
+  val structType = StructType(Seq(
+    StructField("a", IntegerType, nullable = true),
+    StructField("b", ArrayType(DoubleType), nullable = false),
+    StructField("c", DoubleType, nullable = false, metadata)))
+  checkDataTypeJsonRepr(structType)
+
+  def checkDefaultSize(dataType: DataType, expectedDefaultSize: Int): Unit = {
+    test(s"Check the default size of ${dataType}") {
+      assert(dataType.defaultSize === expectedDefaultSize)
+    }
+  }
+
+  checkDefaultSize(NullType, 1)
+  checkDefaultSize(BooleanType, 1)
+  checkDefaultSize(ByteType, 1)
+  checkDefaultSize(ShortType, 2)
+  checkDefaultSize(IntegerType, 4)
+  checkDefaultSize(LongType, 8)
+  checkDefaultSize(FloatType, 4)
+  checkDefaultSize(DoubleType, 8)
+  checkDefaultSize(DecimalType(10, 5), 4096)
+  checkDefaultSize(DecimalType.Unlimited, 4096)
+  checkDefaultSize(DateType, 8)
+  checkDefaultSize(TimestampType, 8)
+  checkDefaultSize(StringType, 4096)
+  checkDefaultSize(BinaryType, 4096)
+  checkDefaultSize(ArrayType(DoubleType, true), 800)
+  checkDefaultSize(ArrayType(StringType, false), 409600)
+  checkDefaultSize(MapType(IntegerType, StringType, true), 410000)
+  checkDefaultSize(MapType(IntegerType, ArrayType(DoubleType), false), 80400)
+  checkDefaultSize(structType, 812)
 }


@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
 import org.apache.spark.sql.test.TestSQLContext._
 import org.apache.spark.sql.test.TestSQLContext.planner._
+import org.apache.spark.sql.types._

 class PlannerSuite extends FunSuite {
   test("unions are collapsed") {
@@ -60,19 +61,62 @@ class PlannerSuite extends FunSuite {
   }

   test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
+    def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = {
+      setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold.toString)
+      val fields = fieldTypes.zipWithIndex.map {
+        case (dataType, index) => StructField(s"c${index}", dataType, true)
+      } :+ StructField("key", IntegerType, true)
+      val schema = StructType(fields)
+      val row = Row.fromSeq(Seq.fill(fields.size)(null))
+      val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil)
+      applySchema(rowRDD, schema).registerTempTable("testLimit")
+
+      val planned = sql(
+        """
+          |SELECT l.a, l.b
+          |FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key)
+        """.stripMargin).queryExecution.executedPlan
+
+      val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
+      val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join }
+
+      assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
+      assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
+
+      dropTempTable("testLimit")
+    }
+
     val origThreshold = conf.autoBroadcastJoinThreshold
-    setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920.toString)

-    // Using a threshold that is definitely larger than the small testing table (b) below
-    val a = testData.as('a)
-    val b = testData.limit(3).as('b)
-    val planned = a.join(b, Inner, Some("a.key".attr === "b.key".attr)).queryExecution.executedPlan
+    val simpleTypes =
+      NullType ::
+      BooleanType ::
+      ByteType ::
+      ShortType ::
+      IntegerType ::
+      LongType ::
+      FloatType ::
+      DoubleType ::
+      DecimalType(10, 5) ::
+      DecimalType.Unlimited ::
+      DateType ::
+      TimestampType ::
+      StringType ::
+      BinaryType :: Nil

-    val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
-    val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join }
+    checkPlan(simpleTypes, newThreshold = 16434)

-    assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
-    assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
+    val complexTypes =
+      ArrayType(DoubleType, true) ::
+      ArrayType(StringType, false) ::
+      MapType(IntegerType, StringType, true) ::
+      MapType(IntegerType, ArrayType(DoubleType), false) ::
+      StructType(Seq(
+        StructField("a", IntegerType, nullable = true),
+        StructField("b", ArrayType(DoubleType), nullable = false),
+        StructField("c", DoubleType, nullable = false))) :: Nil
+
+    checkPlan(complexTypes, newThreshold = 901617)

     setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString)
   }
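A note on where the two thresholds come from (my arithmetic from the defaultSize values above, not stated in the patch): each checkPlan call registers a single-row table, so LIMIT 1 makes the estimated sizeInBytes equal to one row's width, i.e. the sum of the column defaults plus 4 bytes for the extra IntegerType key column.

    // Hypothetical re-derivation of the two newThreshold values used above:
    val simpleRow  = Seq(1, 1, 1, 2, 4, 8, 4, 8, 4096, 4096, 8, 8, 4096, 4096).sum + 4 // 16433
    val complexRow = Seq(800, 409600, 410000, 80400, 812).sum + 4                      // 901616
    // Each threshold sits one byte above its estimate (16434 and 901617),
    // so the broadcast hash join is chosen by the narrowest possible margin.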