[SPARK-5287][SQL] Add defaultSizeOf to every data type.
JIRA: https://issues.apache.org/jira/browse/SPARK-5287 This PR only adds `defaultSize` to data types and makes those internal type classes `protected[sql]`. I will use another PR to clean up the type hierarchy of data types. Author: Yin Huai <yhuai@databricks.com> Closes #4081 from yhuai/SPARK-5287 and squashes the following commits: 90cec75 [Yin Huai] Update unit test. e1c600c [Yin Huai] Make internal classes protected[sql]. 7eaba68 [Yin Huai] Add `defaultSize` method to data types. fd425e0 [Yin Huai] Add all native types to NativeType.defaultSizeOf.
This commit is contained in:
parent
23e25543be
commit
bc20a52b34
|
@ -24,7 +24,13 @@ import org.apache.spark.sql.types.DataType
|
|||
/**
|
||||
* The data type representing [[DynamicRow]] values.
|
||||
*/
|
||||
case object DynamicType extends DataType
|
||||
case object DynamicType extends DataType {
|
||||
|
||||
/**
|
||||
* The default size of a value of the DynamicType is 4096 bytes.
|
||||
*/
|
||||
override def defaultSize: Int = 4096
|
||||
}
|
||||
|
||||
/**
|
||||
* Wrap a [[Row]] as a [[DynamicRow]].
|
||||
|
|
|
@ -238,16 +238,11 @@ case class Rollup(
|
|||
case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
|
||||
override def output = child.output
|
||||
|
||||
override lazy val statistics: Statistics =
|
||||
if (output.forall(_.dataType.isInstanceOf[NativeType])) {
|
||||
val limit = limitExpr.eval(null).asInstanceOf[Int]
|
||||
val sizeInBytes = (limit: Long) * output.map { a =>
|
||||
NativeType.defaultSizeOf(a.dataType.asInstanceOf[NativeType])
|
||||
}.sum
|
||||
Statistics(sizeInBytes = sizeInBytes)
|
||||
} else {
|
||||
Statistics(sizeInBytes = children.map(_.statistics).map(_.sizeInBytes).product)
|
||||
}
|
||||
override lazy val statistics: Statistics = {
|
||||
val limit = limitExpr.eval(null).asInstanceOf[Int]
|
||||
val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
|
||||
Statistics(sizeInBytes = sizeInBytes)
|
||||
}
|
||||
}
|
||||
|
||||
case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
|
||||
|
|
|
@ -215,6 +215,9 @@ abstract class DataType {
|
|||
case _ => false
|
||||
}
|
||||
|
||||
/** The default size of a value of this data type. */
|
||||
def defaultSize: Int
|
||||
|
||||
def isPrimitive: Boolean = false
|
||||
|
||||
def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase
|
||||
|
@ -235,33 +238,25 @@ abstract class DataType {
|
|||
* @group dataType
|
||||
*/
|
||||
@DeveloperApi
|
||||
case object NullType extends DataType
|
||||
case object NullType extends DataType {
|
||||
override def defaultSize: Int = 1
|
||||
}
|
||||
|
||||
|
||||
object NativeType {
|
||||
protected[sql] object NativeType {
|
||||
val all = Seq(
|
||||
IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)
|
||||
|
||||
def unapply(dt: DataType): Boolean = all.contains(dt)
|
||||
|
||||
val defaultSizeOf: Map[NativeType, Int] = Map(
|
||||
IntegerType -> 4,
|
||||
BooleanType -> 1,
|
||||
LongType -> 8,
|
||||
DoubleType -> 8,
|
||||
FloatType -> 4,
|
||||
ShortType -> 2,
|
||||
ByteType -> 1,
|
||||
StringType -> 4096)
|
||||
}
|
||||
|
||||
|
||||
trait PrimitiveType extends DataType {
|
||||
protected[sql] trait PrimitiveType extends DataType {
|
||||
override def isPrimitive = true
|
||||
}
|
||||
|
||||
|
||||
object PrimitiveType {
|
||||
protected[sql] object PrimitiveType {
|
||||
private val nonDecimals = Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all
|
||||
private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap
|
||||
|
||||
|
@ -276,7 +271,7 @@ object PrimitiveType {
|
|||
}
|
||||
}
|
||||
|
||||
abstract class NativeType extends DataType {
|
||||
protected[sql] abstract class NativeType extends DataType {
|
||||
private[sql] type JvmType
|
||||
@transient private[sql] val tag: TypeTag[JvmType]
|
||||
private[sql] val ordering: Ordering[JvmType]
|
||||
|
@ -300,6 +295,11 @@ case object StringType extends NativeType with PrimitiveType {
|
|||
private[sql] type JvmType = String
|
||||
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
|
||||
private[sql] val ordering = implicitly[Ordering[JvmType]]
|
||||
|
||||
/**
|
||||
* The default size of a value of the StringType is 4096 bytes.
|
||||
*/
|
||||
override def defaultSize: Int = 4096
|
||||
}
|
||||
|
||||
|
||||
|
@ -324,6 +324,11 @@ case object BinaryType extends NativeType with PrimitiveType {
|
|||
x.length - y.length
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The default size of a value of the BinaryType is 4096 bytes.
|
||||
*/
|
||||
override def defaultSize: Int = 4096
|
||||
}
|
||||
|
||||
|
||||
|
@ -339,6 +344,11 @@ case object BooleanType extends NativeType with PrimitiveType {
|
|||
private[sql] type JvmType = Boolean
|
||||
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
|
||||
private[sql] val ordering = implicitly[Ordering[JvmType]]
|
||||
|
||||
/**
|
||||
* The default size of a value of the BooleanType is 1 byte.
|
||||
*/
|
||||
override def defaultSize: Int = 1
|
||||
}
|
||||
|
||||
|
||||
|
@ -359,6 +369,11 @@ case object TimestampType extends NativeType {
|
|||
private[sql] val ordering = new Ordering[JvmType] {
|
||||
def compare(x: Timestamp, y: Timestamp) = x.compareTo(y)
|
||||
}
|
||||
|
||||
/**
|
||||
* The default size of a value of the TimestampType is 8 bytes.
|
||||
*/
|
||||
override def defaultSize: Int = 8
|
||||
}
|
||||
|
||||
|
||||
|
@ -379,10 +394,15 @@ case object DateType extends NativeType {
|
|||
private[sql] val ordering = new Ordering[JvmType] {
|
||||
def compare(x: Date, y: Date) = x.compareTo(y)
|
||||
}
|
||||
|
||||
/**
|
||||
* The default size of a value of the DateType is 8 bytes.
|
||||
*/
|
||||
override def defaultSize: Int = 8
|
||||
}
|
||||
|
||||
|
||||
abstract class NumericType extends NativeType with PrimitiveType {
|
||||
protected[sql] abstract class NumericType extends NativeType with PrimitiveType {
|
||||
// Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for
|
||||
// implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a
|
||||
// type parameter and add a numeric annotation (i.e., [JvmType : Numeric]). This gets
|
||||
|
@ -392,13 +412,13 @@ abstract class NumericType extends NativeType with PrimitiveType {
|
|||
}
|
||||
|
||||
|
||||
object NumericType {
|
||||
protected[sql] object NumericType {
|
||||
def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType]
|
||||
}
|
||||
|
||||
|
||||
/** Matcher for any expressions that evaluate to [[IntegralType]]s */
|
||||
object IntegralType {
|
||||
protected[sql] object IntegralType {
|
||||
def unapply(a: Expression): Boolean = a match {
|
||||
case e: Expression if e.dataType.isInstanceOf[IntegralType] => true
|
||||
case _ => false
|
||||
|
@ -406,7 +426,7 @@ object IntegralType {
|
|||
}
|
||||
|
||||
|
||||
sealed abstract class IntegralType extends NumericType {
|
||||
protected[sql] sealed abstract class IntegralType extends NumericType {
|
||||
private[sql] val integral: Integral[JvmType]
|
||||
}
|
||||
|
||||
|
@ -425,6 +445,11 @@ case object LongType extends IntegralType {
|
|||
private[sql] val numeric = implicitly[Numeric[Long]]
|
||||
private[sql] val integral = implicitly[Integral[Long]]
|
||||
private[sql] val ordering = implicitly[Ordering[JvmType]]
|
||||
|
||||
/**
|
||||
* The default size of a value of the LongType is 8 bytes.
|
||||
*/
|
||||
override def defaultSize: Int = 8
|
||||
}
|
||||
|
||||
|
||||
|
@ -442,6 +467,11 @@ case object IntegerType extends IntegralType {
|
|||
private[sql] val numeric = implicitly[Numeric[Int]]
|
||||
private[sql] val integral = implicitly[Integral[Int]]
|
||||
private[sql] val ordering = implicitly[Ordering[JvmType]]
|
||||
|
||||
/**
|
||||
* The default size of a value of the IntegerType is 4 bytes.
|
||||
*/
|
||||
override def defaultSize: Int = 4
|
||||
}
|
||||
|
||||
|
||||
|
@ -459,6 +489,11 @@ case object ShortType extends IntegralType {
|
|||
private[sql] val numeric = implicitly[Numeric[Short]]
|
||||
private[sql] val integral = implicitly[Integral[Short]]
|
||||
private[sql] val ordering = implicitly[Ordering[JvmType]]
|
||||
|
||||
/**
|
||||
* The default size of a value of the ShortType is 2 bytes.
|
||||
*/
|
||||
override def defaultSize: Int = 2
|
||||
}
|
||||
|
||||
|
||||
|
@ -476,11 +511,16 @@ case object ByteType extends IntegralType {
|
|||
private[sql] val numeric = implicitly[Numeric[Byte]]
|
||||
private[sql] val integral = implicitly[Integral[Byte]]
|
||||
private[sql] val ordering = implicitly[Ordering[JvmType]]
|
||||
|
||||
/**
|
||||
* The default size of a value of the ByteType is 1 byte.
|
||||
*/
|
||||
override def defaultSize: Int = 1
|
||||
}
|
||||
|
||||
|
||||
/** Matcher for any expressions that evaluate to [[FractionalType]]s */
|
||||
object FractionalType {
|
||||
protected[sql] object FractionalType {
|
||||
def unapply(a: Expression): Boolean = a match {
|
||||
case e: Expression if e.dataType.isInstanceOf[FractionalType] => true
|
||||
case _ => false
|
||||
|
@ -488,7 +528,7 @@ object FractionalType {
|
|||
}
|
||||
|
||||
|
||||
sealed abstract class FractionalType extends NumericType {
|
||||
protected[sql] sealed abstract class FractionalType extends NumericType {
|
||||
private[sql] val fractional: Fractional[JvmType]
|
||||
private[sql] val asIntegral: Integral[JvmType]
|
||||
}
|
||||
|
@ -530,6 +570,11 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
|
|||
case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)"
|
||||
case None => "DecimalType()"
|
||||
}
|
||||
|
||||
/**
|
||||
* The default size of a value of the DecimalType is 4096 bytes.
|
||||
*/
|
||||
override def defaultSize: Int = 4096
|
||||
}
|
||||
|
||||
|
||||
|
@ -580,6 +625,11 @@ case object DoubleType extends FractionalType {
|
|||
private[sql] val fractional = implicitly[Fractional[Double]]
|
||||
private[sql] val ordering = implicitly[Ordering[JvmType]]
|
||||
private[sql] val asIntegral = DoubleAsIfIntegral
|
||||
|
||||
/**
|
||||
* The default size of a value of the DoubleType is 8 bytes.
|
||||
*/
|
||||
override def defaultSize: Int = 8
|
||||
}
|
||||
|
||||
|
||||
|
@ -598,6 +648,11 @@ case object FloatType extends FractionalType {
|
|||
private[sql] val fractional = implicitly[Fractional[Float]]
|
||||
private[sql] val ordering = implicitly[Ordering[JvmType]]
|
||||
private[sql] val asIntegral = FloatAsIfIntegral
|
||||
|
||||
/**
|
||||
* The default size of a value of the FloatType is 4 bytes.
|
||||
*/
|
||||
override def defaultSize: Int = 4
|
||||
}
|
||||
|
||||
|
||||
|
@ -636,6 +691,12 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
|
|||
("type" -> typeName) ~
|
||||
("elementType" -> elementType.jsonValue) ~
|
||||
("containsNull" -> containsNull)
|
||||
|
||||
/**
|
||||
* The default size of a value of the ArrayType is 100 * the default size of the element type.
|
||||
* (We assume that there are 100 elements).
|
||||
*/
|
||||
override def defaultSize: Int = 100 * elementType.defaultSize
|
||||
}
|
||||
|
||||
|
||||
|
@ -805,6 +866,11 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
|
|||
override def length: Int = fields.length
|
||||
|
||||
override def iterator: Iterator[StructField] = fields.iterator
|
||||
|
||||
/**
|
||||
* The default size of a value of the StructType is the total default sizes of all field types.
|
||||
*/
|
||||
override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum
|
||||
}
|
||||
|
||||
|
||||
|
@ -848,6 +914,13 @@ case class MapType(
|
|||
("keyType" -> keyType.jsonValue) ~
|
||||
("valueType" -> valueType.jsonValue) ~
|
||||
("valueContainsNull" -> valueContainsNull)
|
||||
|
||||
/**
|
||||
* The default size of a value of the MapType is
|
||||
* 100 * (the default size of the key type + the default size of the value type).
|
||||
* (We assume that there are 100 elements).
|
||||
*/
|
||||
override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize)
|
||||
}
|
||||
|
||||
|
||||
|
@ -896,4 +969,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
|
|||
* Class object for the UserType
|
||||
*/
|
||||
def userClass: java.lang.Class[UserType]
|
||||
|
||||
/**
|
||||
* The default size of a value of the UserDefinedType is 4096 bytes.
|
||||
*/
|
||||
override def defaultSize: Int = 4096
|
||||
}
|
||||
|
|
|
@ -62,6 +62,7 @@ class DataTypeSuite extends FunSuite {
|
|||
}
|
||||
}
|
||||
|
||||
checkDataTypeJsonRepr(NullType)
|
||||
checkDataTypeJsonRepr(BooleanType)
|
||||
checkDataTypeJsonRepr(ByteType)
|
||||
checkDataTypeJsonRepr(ShortType)
|
||||
|
@ -69,7 +70,9 @@ class DataTypeSuite extends FunSuite {
|
|||
checkDataTypeJsonRepr(LongType)
|
||||
checkDataTypeJsonRepr(FloatType)
|
||||
checkDataTypeJsonRepr(DoubleType)
|
||||
checkDataTypeJsonRepr(DecimalType(10, 5))
|
||||
checkDataTypeJsonRepr(DecimalType.Unlimited)
|
||||
checkDataTypeJsonRepr(DateType)
|
||||
checkDataTypeJsonRepr(TimestampType)
|
||||
checkDataTypeJsonRepr(StringType)
|
||||
checkDataTypeJsonRepr(BinaryType)
|
||||
|
@ -77,12 +80,39 @@ class DataTypeSuite extends FunSuite {
|
|||
checkDataTypeJsonRepr(ArrayType(StringType, false))
|
||||
checkDataTypeJsonRepr(MapType(IntegerType, StringType, true))
|
||||
checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false))
|
||||
|
||||
val metadata = new MetadataBuilder()
|
||||
.putString("name", "age")
|
||||
.build()
|
||||
checkDataTypeJsonRepr(
|
||||
StructType(Seq(
|
||||
StructField("a", IntegerType, nullable = true),
|
||||
StructField("b", ArrayType(DoubleType), nullable = false),
|
||||
StructField("c", DoubleType, nullable = false, metadata))))
|
||||
val structType = StructType(Seq(
|
||||
StructField("a", IntegerType, nullable = true),
|
||||
StructField("b", ArrayType(DoubleType), nullable = false),
|
||||
StructField("c", DoubleType, nullable = false, metadata)))
|
||||
checkDataTypeJsonRepr(structType)
|
||||
|
||||
def checkDefaultSize(dataType: DataType, expectedDefaultSize: Int): Unit = {
|
||||
test(s"Check the default size of ${dataType}") {
|
||||
assert(dataType.defaultSize === expectedDefaultSize)
|
||||
}
|
||||
}
|
||||
|
||||
checkDefaultSize(NullType, 1)
|
||||
checkDefaultSize(BooleanType, 1)
|
||||
checkDefaultSize(ByteType, 1)
|
||||
checkDefaultSize(ShortType, 2)
|
||||
checkDefaultSize(IntegerType, 4)
|
||||
checkDefaultSize(LongType, 8)
|
||||
checkDefaultSize(FloatType, 4)
|
||||
checkDefaultSize(DoubleType, 8)
|
||||
checkDefaultSize(DecimalType(10, 5), 4096)
|
||||
checkDefaultSize(DecimalType.Unlimited, 4096)
|
||||
checkDefaultSize(DateType, 8)
|
||||
checkDefaultSize(TimestampType, 8)
|
||||
checkDefaultSize(StringType, 4096)
|
||||
checkDefaultSize(BinaryType, 4096)
|
||||
checkDefaultSize(ArrayType(DoubleType, true), 800)
|
||||
checkDefaultSize(ArrayType(StringType, false), 409600)
|
||||
checkDefaultSize(MapType(IntegerType, StringType, true), 410000)
|
||||
checkDefaultSize(MapType(IntegerType, ArrayType(DoubleType), false), 80400)
|
||||
checkDefaultSize(structType, 812)
|
||||
}
|
||||
|
|
|
@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans._
|
|||
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
|
||||
import org.apache.spark.sql.test.TestSQLContext._
|
||||
import org.apache.spark.sql.test.TestSQLContext.planner._
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
class PlannerSuite extends FunSuite {
|
||||
test("unions are collapsed") {
|
||||
|
@ -60,19 +61,62 @@ class PlannerSuite extends FunSuite {
|
|||
}
|
||||
|
||||
test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
|
||||
def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = {
|
||||
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold.toString)
|
||||
val fields = fieldTypes.zipWithIndex.map {
|
||||
case (dataType, index) => StructField(s"c${index}", dataType, true)
|
||||
} :+ StructField("key", IntegerType, true)
|
||||
val schema = StructType(fields)
|
||||
val row = Row.fromSeq(Seq.fill(fields.size)(null))
|
||||
val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil)
|
||||
applySchema(rowRDD, schema).registerTempTable("testLimit")
|
||||
|
||||
val planned = sql(
|
||||
"""
|
||||
|SELECT l.a, l.b
|
||||
|FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key)
|
||||
""".stripMargin).queryExecution.executedPlan
|
||||
|
||||
val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
|
||||
val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join }
|
||||
|
||||
assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
|
||||
assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
|
||||
|
||||
dropTempTable("testLimit")
|
||||
}
|
||||
|
||||
val origThreshold = conf.autoBroadcastJoinThreshold
|
||||
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920.toString)
|
||||
|
||||
// Using a threshold that is definitely larger than the small testing table (b) below
|
||||
val a = testData.as('a)
|
||||
val b = testData.limit(3).as('b)
|
||||
val planned = a.join(b, Inner, Some("a.key".attr === "b.key".attr)).queryExecution.executedPlan
|
||||
val simpleTypes =
|
||||
NullType ::
|
||||
BooleanType ::
|
||||
ByteType ::
|
||||
ShortType ::
|
||||
IntegerType ::
|
||||
LongType ::
|
||||
FloatType ::
|
||||
DoubleType ::
|
||||
DecimalType(10, 5) ::
|
||||
DecimalType.Unlimited ::
|
||||
DateType ::
|
||||
TimestampType ::
|
||||
StringType ::
|
||||
BinaryType :: Nil
|
||||
|
||||
val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
|
||||
val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join }
|
||||
checkPlan(simpleTypes, newThreshold = 16434)
|
||||
|
||||
assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
|
||||
assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
|
||||
val complexTypes =
|
||||
ArrayType(DoubleType, true) ::
|
||||
ArrayType(StringType, false) ::
|
||||
MapType(IntegerType, StringType, true) ::
|
||||
MapType(IntegerType, ArrayType(DoubleType), false) ::
|
||||
StructType(Seq(
|
||||
StructField("a", IntegerType, nullable = true),
|
||||
StructField("b", ArrayType(DoubleType), nullable = false),
|
||||
StructField("c", DoubleType, nullable = false))) :: Nil
|
||||
|
||||
checkPlan(complexTypes, newThreshold = 901617)
|
||||
|
||||
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue