[SPARK-11645][SQL] Remove OpenHashSet for the old aggregate.

Author: Reynold Xin <rxin@databricks.com>

Closes #9621 from rxin/SPARK-11645.
Reynold Xin 2015-11-11 12:48:51 -08:00
parent df97df2b39
commit a9a6b80c71
5 changed files with 5 additions and 316 deletions


@@ -33,10 +33,6 @@ import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types._
// These classes are here to avoid issues with serialization and integration with quasiquotes.
class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int]
class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
/**
* Java source for evaluating an [[Expression]] given a [[InternalRow]] of input.
*
@@ -205,8 +201,6 @@ class CodeGenContext {
case _: StructType => "InternalRow"
case _: ArrayType => "ArrayData"
case _: MapType => "MapData"
case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
case udt: UserDefinedType[_] => javaType(udt.sqlType)
case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]"
case ObjectType(cls) => cls.getName


@@ -39,7 +39,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
case t: ArrayType if canSupport(t.elementType) => true
case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true
case dt: OpenHashSetUDT => false // it's not a standard UDT
case udt: UserDefinedType[_] => canSupport(udt.sqlType)
case _ => false
}
@@ -309,13 +308,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
in.map(BindReferences.bindReference(_, inputSchema))
def generate(
expressions: Seq[Expression],
subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
expressions: Seq[Expression],
subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
create(canonicalize(expressions), subexpressionEliminationEnabled)
}
protected def create(expressions: Seq[Expression]): UnsafeProjection = {
create(expressions, false)
create(expressions, subexpressionEliminationEnabled = false)
}
private def create(


@@ -1,194 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet
/** The data type for expressions returning an OpenHashSet as the result. */
private[sql] class OpenHashSetUDT(
val elementType: DataType) extends UserDefinedType[OpenHashSet[Any]] {
override def sqlType: DataType = ArrayType(elementType)
/** Since we are using OpenHashSet internally, usually it will not be called. */
override def serialize(obj: Any): Seq[Any] = {
obj.asInstanceOf[OpenHashSet[Any]].iterator.toSeq
}
/** Since we are using OpenHashSet internally, usually it will not be called. */
override def deserialize(datum: Any): OpenHashSet[Any] = {
val iterator = datum.asInstanceOf[Seq[Any]].iterator
val set = new OpenHashSet[Any]
while(iterator.hasNext) {
set.add(iterator.next())
}
set
}
override def userClass: Class[OpenHashSet[Any]] = classOf[OpenHashSet[Any]]
private[spark] override def asNullable: OpenHashSetUDT = this
}
/**
* Creates a new set of the specified type
*/
case class NewSet(elementType: DataType) extends LeafExpression with CodegenFallback {
override def nullable: Boolean = false
override def dataType: OpenHashSetUDT = new OpenHashSetUDT(elementType)
override def eval(input: InternalRow): Any = {
new OpenHashSet[Any]()
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
elementType match {
case IntegerType | LongType =>
ev.isNull = "false"
s"""
${ctx.javaType(dataType)} ${ev.value} = new ${ctx.javaType(dataType)}();
"""
case _ => super.genCode(ctx, ev)
}
}
override def toString: String = s"new Set($dataType)"
}
/**
* Adds an item to a set.
* For performance, this expression mutates its input during evaluation.
* Note: this expression is internal and created only by the GeneratedAggregate,
* we don't need to do type check for it.
*/
case class AddItemToSet(item: Expression, set: Expression)
extends Expression with CodegenFallback {
override def children: Seq[Expression] = item :: set :: Nil
override def nullable: Boolean = set.nullable
override def dataType: DataType = set.dataType
override def eval(input: InternalRow): Any = {
val itemEval = item.eval(input)
val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]]
if (itemEval != null) {
if (setEval != null) {
setEval.add(itemEval)
setEval
} else {
null
}
} else {
setEval
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType
elementType match {
case IntegerType | LongType =>
val itemEval = item.gen(ctx)
val setEval = set.gen(ctx)
val htype = ctx.javaType(dataType)
ev.isNull = "false"
ev.value = setEval.value
itemEval.code + setEval.code + s"""
if (!${itemEval.isNull} && !${setEval.isNull}) {
(($htype)${setEval.value}).add(${itemEval.value});
}
"""
case _ => super.genCode(ctx, ev)
}
}
override def toString: String = s"$set += $item"
}
/**
* Combines the elements of two sets.
* For performance, this expression mutates its left input set during evaluation.
* Note: this expression is internal and created only by the GeneratedAggregate,
* we don't need to do type check for it.
*/
case class CombineSets(left: Expression, right: Expression)
extends BinaryExpression with CodegenFallback {
override def nullable: Boolean = left.nullable
override def dataType: DataType = left.dataType
override def eval(input: InternalRow): Any = {
val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]]
if(leftEval != null) {
val rightEval = right.eval(input).asInstanceOf[OpenHashSet[Any]]
if (rightEval != null) {
val iterator = rightEval.iterator
while(iterator.hasNext) {
val rightValue = iterator.next()
leftEval.add(rightValue)
}
}
leftEval
} else {
null
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType
elementType match {
case IntegerType | LongType =>
val leftEval = left.gen(ctx)
val rightEval = right.gen(ctx)
val htype = ctx.javaType(dataType)
ev.isNull = leftEval.isNull
ev.value = leftEval.value
leftEval.code + rightEval.code + s"""
if (!${leftEval.isNull} && !${rightEval.isNull}) {
${leftEval.value}.union((${htype})${rightEval.value});
}
"""
case _ => super.genCode(ctx, ev)
}
}
}
/**
* Returns the number of elements in the input set.
* Note: this expression is internal and created only by the GeneratedAggregate,
* we don't need to do type check for it.
*/
case class CountSet(child: Expression) extends UnaryExpression with CodegenFallback {
override def dataType: DataType = LongType
protected override def nullSafeEval(input: Any): Any =
input.asInstanceOf[OpenHashSet[Any]].size.toLong
override def toString: String = s"$child.count()"
}
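
Note (not part of the commit): the four expressions deleted above were instantiated only by the old GeneratedAggregate path and together implement a set-based distinct aggregation — build a per-group set, merge partial sets, then count the result. Below is a minimal, dependency-free Scala sketch of that dataflow; it substitutes scala.collection.mutable.HashSet for Spark's OpenHashSet, and the object and method names are illustrative only.

import scala.collection.mutable

// Sketch only: mirrors the semantics of the removed expressions, not their codegen.
object RemovedSetExpressionsSketch {
  // NewSet: allocate an empty per-group buffer.
  def newSet(): mutable.HashSet[Any] = mutable.HashSet.empty[Any]

  // AddItemToSet: null items are skipped; the buffer is mutated in place for performance.
  def addItemToSet(set: mutable.HashSet[Any], item: Any): mutable.HashSet[Any] = {
    if (item != null) set += item
    set
  }

  // CombineSets: merge the right partial set into the left one (left is mutated).
  def combineSets(left: mutable.HashSet[Any], right: mutable.HashSet[Any]): mutable.HashSet[Any] = {
    left ++= right
    left
  }

  // CountSet: produce the final aggregate value, e.g. for COUNT(DISTINCT col).
  def countSet(set: mutable.HashSet[Any]): Long = set.size.toLong

  def main(args: Array[String]): Unit = {
    val partial1 = Seq[Any](1, 2, 2, 3).foldLeft(newSet())(addItemToSet)
    val partial2 = Seq[Any](3, 4, null).foldLeft(newSet())(addItemToSet)
    println(countSet(combineSets(partial1, partial2))) // prints 4
  }
}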


@@ -22,19 +22,16 @@ import java.util.{HashMap => JavaHashMap}
import scala.reflect.ClassTag
import com.clearspring.analytics.stream.cardinality.HyperLogLog
import com.esotericsoftware.kryo.io.{Input, Output}
import com.esotericsoftware.kryo.{Kryo, Serializer}
import com.twitter.chill.ResourcePool
import org.apache.spark.serializer.{KryoSerializer, SerializerInstance}
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHashSet}
import org.apache.spark.sql.types.Decimal
import org.apache.spark.util.MutablePair
import org.apache.spark.util.collection.OpenHashSet
import org.apache.spark.{SparkConf, SparkEnv}
private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
override def newKryo(): Kryo = {
val kryo = super.newKryo()
@@ -43,16 +40,9 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericInternalRow])
kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog],
new HyperLogLogSerializer)
kryo.register(classOf[java.math.BigDecimal], new JavaBigDecimalSerializer)
kryo.register(classOf[BigDecimal], new ScalaBigDecimalSerializer)
// Specific hashsets must come first TODO: Move to core.
kryo.register(classOf[IntegerHashSet], new IntegerHashSetSerializer)
kryo.register(classOf[LongHashSet], new LongHashSetSerializer)
kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]],
new OpenHashSetSerializer)
kryo.register(classOf[Decimal])
kryo.register(classOf[JavaHashMap[_, _]])
@@ -62,7 +52,7 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
}
private[execution] class KryoResourcePool(size: Int)
extends ResourcePool[SerializerInstance](size) {
extends ResourcePool[SerializerInstance](size) {
val ser: SparkSqlSerializer = {
val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
@@ -116,92 +106,3 @@ private[sql] class ScalaBigDecimalSerializer extends Serializer[BigDecimal] {
new java.math.BigDecimal(input.readString())
}
}
private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] {
def write(kryo: Kryo, output: Output, hyperLogLog: HyperLogLog) {
val bytes = hyperLogLog.getBytes()
output.writeInt(bytes.length)
output.writeBytes(bytes)
}
def read(kryo: Kryo, input: Input, tpe: Class[HyperLogLog]): HyperLogLog = {
val length = input.readInt()
val bytes = input.readBytes(length)
HyperLogLog.Builder.build(bytes)
}
}
private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] {
def write(kryo: Kryo, output: Output, hs: OpenHashSet[_]) {
val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]]
output.writeInt(hs.size)
val iterator = hs.iterator
while(iterator.hasNext) {
val row = iterator.next()
rowSerializer.write(kryo, output, row.asInstanceOf[GenericInternalRow].values)
}
}
def read(kryo: Kryo, input: Input, tpe: Class[OpenHashSet[_]]): OpenHashSet[_] = {
val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]]
val numItems = input.readInt()
val set = new OpenHashSet[Any](numItems + 1)
var i = 0
while (i < numItems) {
val row =
new GenericInternalRow(rowSerializer.read(
kryo,
input,
classOf[Array[Any]].asInstanceOf[Class[Any]]).asInstanceOf[Array[Any]])
set.add(row)
i += 1
}
set
}
}
private[sql] class IntegerHashSetSerializer extends Serializer[IntegerHashSet] {
def write(kryo: Kryo, output: Output, hs: IntegerHashSet) {
output.writeInt(hs.size)
val iterator = hs.iterator
while(iterator.hasNext) {
val value: Int = iterator.next()
output.writeInt(value)
}
}
def read(kryo: Kryo, input: Input, tpe: Class[IntegerHashSet]): IntegerHashSet = {
val numItems = input.readInt()
val set = new IntegerHashSet
var i = 0
while (i < numItems) {
val value = input.readInt()
set.add(value)
i += 1
}
set
}
}
private[sql] class LongHashSetSerializer extends Serializer[LongHashSet] {
def write(kryo: Kryo, output: Output, hs: LongHashSet) {
output.writeInt(hs.size)
val iterator = hs.iterator
while(iterator.hasNext) {
val value = iterator.next()
output.writeLong(value)
}
}
def read(kryo: Kryo, input: Input, tpe: Class[LongHashSet]): LongHashSet = {
val numItems = input.readInt()
val set = new LongHashSet
var i = 0
while (i < numItems) {
val value = input.readLong()
set.add(value)
i += 1
}
set
}
}


@@ -23,7 +23,6 @@ import scala.beans.{BeanInfo, BeanProperty}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.OpenHashSetUDT
import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
@@ -131,15 +130,6 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0)
}
test("OpenHashSetUDT") {
val openHashSetUDT = new OpenHashSetUDT(IntegerType)
val set = new OpenHashSet[Int]
(1 to 10).foreach(i => set.add(i))
val actual = openHashSetUDT.deserialize(openHashSetUDT.serialize(set))
assert(actual.iterator.toSet === set.iterator.toSet)
}
test("UDTs with JSON") {
val data = Seq(
"{\"id\":1,\"vec\":[1.1,2.2,3.3,4.4]}",
@@ -163,7 +153,6 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
test("SPARK-10472 UserDefinedType.typeName") {
assert(IntegerType.typeName === "integer")
assert(new MyDenseVectorUDT().typeName === "mydensevector")
assert(new OpenHashSetUDT(IntegerType).typeName === "openhashset")
}
test("Catalyst type converter null handling for UDTs") {