[SPARK-11645][SQL] Remove OpenHashSet for the old aggregate.
Author: Reynold Xin <rxin@databricks.com>

Closes #9621 from rxin/SPARK-11645.
parent df97df2b39
commit a9a6b80c71
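Removes code that only the old code-gen aggregate path still used: the IntegerHashSet and LongHashSet helper classes in the code generator, the OpenHashSetUDT type and the set expressions built on it (NewSet, AddItemToSet, CombineSets, CountSet), their Kryo serializers in SparkSqlSerializer, and the tests that covered them.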
@@ -33,10 +33,6 @@ import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.types._
 
-// These classes are here to avoid issues with serialization and integration with quasiquotes.
-class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int]
-class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
-
 /**
  * Java source for evaluating an [[Expression]] given a [[InternalRow]] of input.
  *
@@ -205,8 +201,6 @@ class CodeGenContext
     case _: StructType => "InternalRow"
     case _: ArrayType => "ArrayData"
     case _: MapType => "MapData"
-    case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
-    case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
     case udt: UserDefinedType[_] => javaType(udt.sqlType)
     case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]"
     case ObjectType(cls) => cls.getName
@@ -39,7 +39,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
     case t: ArrayType if canSupport(t.elementType) => true
     case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true
-    case dt: OpenHashSetUDT => false // it's not a standard UDT
     case udt: UserDefinedType[_] => canSupport(udt.sqlType)
     case _ => false
   }
@@ -309,13 +308,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     in.map(BindReferences.bindReference(_, inputSchema))

   def generate(
-    expressions: Seq[Expression],
-    subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
+      expressions: Seq[Expression],
+      subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
     create(canonicalize(expressions), subexpressionEliminationEnabled)
   }

   protected def create(expressions: Seq[Expression]): UnsafeProjection = {
-    create(expressions, false)
+    create(expressions, subexpressionEliminationEnabled = false)
   }

   private def create(
@@ -1,194 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.types._
-import org.apache.spark.util.collection.OpenHashSet
-
-/** The data type for expressions returning an OpenHashSet as the result. */
-private[sql] class OpenHashSetUDT(
-    val elementType: DataType) extends UserDefinedType[OpenHashSet[Any]] {
-
-  override def sqlType: DataType = ArrayType(elementType)
-
-  /** Since we are using OpenHashSet internally, usually it will not be called. */
-  override def serialize(obj: Any): Seq[Any] = {
-    obj.asInstanceOf[OpenHashSet[Any]].iterator.toSeq
-  }
-
-  /** Since we are using OpenHashSet internally, usually it will not be called. */
-  override def deserialize(datum: Any): OpenHashSet[Any] = {
-    val iterator = datum.asInstanceOf[Seq[Any]].iterator
-    val set = new OpenHashSet[Any]
-    while(iterator.hasNext) {
-      set.add(iterator.next())
-    }
-
-    set
-  }
-
-  override def userClass: Class[OpenHashSet[Any]] = classOf[OpenHashSet[Any]]
-
-  private[spark] override def asNullable: OpenHashSetUDT = this
-}
-
-/**
- * Creates a new set of the specified type
- */
-case class NewSet(elementType: DataType) extends LeafExpression with CodegenFallback {
-
-  override def nullable: Boolean = false
-
-  override def dataType: OpenHashSetUDT = new OpenHashSetUDT(elementType)
-
-  override def eval(input: InternalRow): Any = {
-    new OpenHashSet[Any]()
-  }
-
-  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    elementType match {
-      case IntegerType | LongType =>
-        ev.isNull = "false"
-        s"""
-          ${ctx.javaType(dataType)} ${ev.value} = new ${ctx.javaType(dataType)}();
-        """
-      case _ => super.genCode(ctx, ev)
-    }
-  }
-
-  override def toString: String = s"new Set($dataType)"
-}
-
-/**
- * Adds an item to a set.
- * For performance, this expression mutates its input during evaluation.
- * Note: this expression is internal and created only by the GeneratedAggregate,
- * we don't need to do type check for it.
- */
-case class AddItemToSet(item: Expression, set: Expression)
-  extends Expression with CodegenFallback {
-
-  override def children: Seq[Expression] = item :: set :: Nil
-
-  override def nullable: Boolean = set.nullable
-
-  override def dataType: DataType = set.dataType
-
-  override def eval(input: InternalRow): Any = {
-    val itemEval = item.eval(input)
-    val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]]
-
-    if (itemEval != null) {
-      if (setEval != null) {
-        setEval.add(itemEval)
-        setEval
-      } else {
-        null
-      }
-    } else {
-      setEval
-    }
-  }
-
-  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType
-    elementType match {
-      case IntegerType | LongType =>
-        val itemEval = item.gen(ctx)
-        val setEval = set.gen(ctx)
-        val htype = ctx.javaType(dataType)
-
-        ev.isNull = "false"
-        ev.value = setEval.value
-        itemEval.code + setEval.code + s"""
-          if (!${itemEval.isNull} && !${setEval.isNull}) {
-            (($htype)${setEval.value}).add(${itemEval.value});
-          }
-        """
-      case _ => super.genCode(ctx, ev)
-    }
-  }
-
-  override def toString: String = s"$set += $item"
-}
-
-/**
- * Combines the elements of two sets.
- * For performance, this expression mutates its left input set during evaluation.
- * Note: this expression is internal and created only by the GeneratedAggregate,
- * we don't need to do type check for it.
- */
-case class CombineSets(left: Expression, right: Expression)
-  extends BinaryExpression with CodegenFallback {
-
-  override def nullable: Boolean = left.nullable
-  override def dataType: DataType = left.dataType
-
-  override def eval(input: InternalRow): Any = {
-    val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]]
-    if(leftEval != null) {
-      val rightEval = right.eval(input).asInstanceOf[OpenHashSet[Any]]
-      if (rightEval != null) {
-        val iterator = rightEval.iterator
-        while(iterator.hasNext) {
-          val rightValue = iterator.next()
-          leftEval.add(rightValue)
-        }
-      }
-      leftEval
-    } else {
-      null
-    }
-  }
-
-  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType
-    elementType match {
-      case IntegerType | LongType =>
-        val leftEval = left.gen(ctx)
-        val rightEval = right.gen(ctx)
-        val htype = ctx.javaType(dataType)
-
-        ev.isNull = leftEval.isNull
-        ev.value = leftEval.value
-        leftEval.code + rightEval.code + s"""
-          if (!${leftEval.isNull} && !${rightEval.isNull}) {
-            ${leftEval.value}.union((${htype})${rightEval.value});
-          }
-        """
-      case _ => super.genCode(ctx, ev)
-    }
-  }
-}
-
-/**
- * Returns the number of elements in the input set.
- * Note: this expression is internal and created only by the GeneratedAggregate,
- * we don't need to do type check for it.
- */
-case class CountSet(child: Expression) extends UnaryExpression with CodegenFallback {
-
-  override def dataType: DataType = LongType
-
-  protected override def nullSafeEval(input: Any): Any =
-    input.asInstanceOf[OpenHashSet[Any]].size.toLong
-
-  override def toString: String = s"$child.count()"
-}
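The file deleted above defined the four expressions that the old GeneratedAggregate composed into set-based aggregates: NewSet allocated a per-group set, AddItemToSet folded rows into it, CombineSets merged partial sets, and CountSet took the final size. Below is a minimal, self-contained sketch of that runtime pattern, using scala.collection.mutable.HashSet as a stand-in for Spark's internal OpenHashSet (which is private[spark]); DistinctCountSketch and its method names are hypothetical, not part of the commit.

import scala.collection.mutable

object DistinctCountSketch {
  type S = mutable.HashSet[Any]

  // NewSet: allocate an empty set for a new aggregation group.
  def newSet(): S = mutable.HashSet.empty[Any]

  // AddItemToSet: mutate the set in place, ignoring null items (as eval() did).
  def addItem(set: S, item: Any): S = {
    if (set != null && item != null) set += item
    set
  }

  // CombineSets: merge the right-hand partial set into the left-hand one.
  def combine(left: S, right: S): S = {
    if (left != null && right != null) left ++= right
    left
  }

  // CountSet: the final distinct count.
  def count(set: S): Long = set.size.toLong

  def main(args: Array[String]): Unit = {
    val partition1 = Seq[Any](1, 2, 2, 3).foldLeft(newSet())(addItem)
    val partition2 = Seq[Any](3, 4, null).foldLeft(newSet())(addItem)
    println(count(combine(partition1, partition2))) // prints 4
  }
}

Running main prints 4, the number of distinct non-null values across the two simulated partitions.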
@@ -22,19 +22,16 @@ import java.util.{HashMap => JavaHashMap}
 import scala.reflect.ClassTag
 
-import com.clearspring.analytics.stream.cardinality.HyperLogLog
 import com.esotericsoftware.kryo.io.{Input, Output}
 import com.esotericsoftware.kryo.{Kryo, Serializer}
 import com.twitter.chill.ResourcePool
 
 import org.apache.spark.serializer.{KryoSerializer, SerializerInstance}
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHashSet}
 import org.apache.spark.sql.types.Decimal
 import org.apache.spark.util.MutablePair
-import org.apache.spark.util.collection.OpenHashSet
 import org.apache.spark.{SparkConf, SparkEnv}
 
 
 private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
   override def newKryo(): Kryo = {
     val kryo = super.newKryo()
@@ -43,16 +40,9 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericInternalRow])
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
-    kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog],
-      new HyperLogLogSerializer)
     kryo.register(classOf[java.math.BigDecimal], new JavaBigDecimalSerializer)
     kryo.register(classOf[BigDecimal], new ScalaBigDecimalSerializer)
 
-    // Specific hashsets must come first TODO: Move to core.
-    kryo.register(classOf[IntegerHashSet], new IntegerHashSetSerializer)
-    kryo.register(classOf[LongHashSet], new LongHashSetSerializer)
-    kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]],
-      new OpenHashSetSerializer)
     kryo.register(classOf[Decimal])
     kryo.register(classOf[JavaHashMap[_, _]])
 
@@ -62,7 +52,7 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
 }
 
 private[execution] class KryoResourcePool(size: Int)
-  extends ResourcePool[SerializerInstance](size) {
+    extends ResourcePool[SerializerInstance](size) {
 
   val ser: SparkSqlSerializer = {
     val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
@@ -116,92 +106,3 @@ private[sql] class ScalaBigDecimalSerializer extends Serializer[BigDecimal] {
     new java.math.BigDecimal(input.readString())
   }
 }
-
-private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] {
-  def write(kryo: Kryo, output: Output, hyperLogLog: HyperLogLog) {
-    val bytes = hyperLogLog.getBytes()
-    output.writeInt(bytes.length)
-    output.writeBytes(bytes)
-  }
-
-  def read(kryo: Kryo, input: Input, tpe: Class[HyperLogLog]): HyperLogLog = {
-    val length = input.readInt()
-    val bytes = input.readBytes(length)
-    HyperLogLog.Builder.build(bytes)
-  }
-}
-
-private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] {
-  def write(kryo: Kryo, output: Output, hs: OpenHashSet[_]) {
-    val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]]
-    output.writeInt(hs.size)
-    val iterator = hs.iterator
-    while(iterator.hasNext) {
-      val row = iterator.next()
-      rowSerializer.write(kryo, output, row.asInstanceOf[GenericInternalRow].values)
-    }
-  }
-
-  def read(kryo: Kryo, input: Input, tpe: Class[OpenHashSet[_]]): OpenHashSet[_] = {
-    val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]]
-    val numItems = input.readInt()
-    val set = new OpenHashSet[Any](numItems + 1)
-    var i = 0
-    while (i < numItems) {
-      val row =
-        new GenericInternalRow(rowSerializer.read(
-          kryo,
-          input,
-          classOf[Array[Any]].asInstanceOf[Class[Any]]).asInstanceOf[Array[Any]])
-      set.add(row)
-      i += 1
-    }
-    set
-  }
-}
-
-private[sql] class IntegerHashSetSerializer extends Serializer[IntegerHashSet] {
-  def write(kryo: Kryo, output: Output, hs: IntegerHashSet) {
-    output.writeInt(hs.size)
-    val iterator = hs.iterator
-    while(iterator.hasNext) {
-      val value: Int = iterator.next()
-      output.writeInt(value)
-    }
-  }
-
-  def read(kryo: Kryo, input: Input, tpe: Class[IntegerHashSet]): IntegerHashSet = {
-    val numItems = input.readInt()
-    val set = new IntegerHashSet
-    var i = 0
-    while (i < numItems) {
-      val value = input.readInt()
-      set.add(value)
-      i += 1
-    }
-    set
-  }
-}
-
-private[sql] class LongHashSetSerializer extends Serializer[LongHashSet] {
-  def write(kryo: Kryo, output: Output, hs: LongHashSet) {
-    output.writeInt(hs.size)
-    val iterator = hs.iterator
-    while(iterator.hasNext) {
-      val value = iterator.next()
-      output.writeLong(value)
-    }
-  }
-
-  def read(kryo: Kryo, input: Input, tpe: Class[LongHashSet]): LongHashSet = {
-    val numItems = input.readInt()
-    val set = new LongHashSet
-    var i = 0
-    while (i < numItems) {
-      val value = input.readLong()
-      set.add(value)
-      i += 1
-    }
-    set
-  }
-}
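All of the serializers removed above follow the same length-prefixed Kryo encoding: write the element count, then each element. Below is a sketch of that pattern against a plain scala.collection.mutable.HashSet[Int] instead of the deleted IntegerHashSet; it assumes the Kryo 2/3 Serializer API that Spark depended on at the time, and IntSetSerializer is a hypothetical name.

import scala.collection.mutable

import com.esotericsoftware.kryo.{Kryo, Serializer}
import com.esotericsoftware.kryo.io.{Input, Output}

class IntSetSerializer extends Serializer[mutable.HashSet[Int]] {
  override def write(kryo: Kryo, output: Output, set: mutable.HashSet[Int]): Unit = {
    // Length prefix first, then the elements, as the removed serializers did.
    output.writeInt(set.size)
    set.foreach(v => output.writeInt(v))
  }

  override def read(
      kryo: Kryo,
      input: Input,
      tpe: Class[mutable.HashSet[Int]]): mutable.HashSet[Int] = {
    val numItems = input.readInt()
    val set = mutable.HashSet.empty[Int]
    var i = 0
    while (i < numItems) {
      set += input.readInt()
      i += 1
    }
    set
  }
}

Registration would then mirror the removed newKryo() lines: kryo.register(classOf[mutable.HashSet[Int]], new IntSetSerializer).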
@@ -23,7 +23,6 @@ import scala.beans.{BeanInfo, BeanProperty}
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.expressions.OpenHashSetUDT
 import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
@@ -131,15 +130,6 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
     df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0)
   }
 
-  test("OpenHashSetUDT") {
-    val openHashSetUDT = new OpenHashSetUDT(IntegerType)
-    val set = new OpenHashSet[Int]
-    (1 to 10).foreach(i => set.add(i))
-
-    val actual = openHashSetUDT.deserialize(openHashSetUDT.serialize(set))
-    assert(actual.iterator.toSet === set.iterator.toSet)
-  }
-
   test("UDTs with JSON") {
     val data = Seq(
       "{\"id\":1,\"vec\":[1.1,2.2,3.3,4.4]}",
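The deleted test above only asserted that OpenHashSetUDT survives a serialize/deserialize round trip. The same property can be sketched for any UDT; HashSetUDT below is hypothetical, backed by a public collection type, and assumes the Spark 1.5-era public UserDefinedType API shown in the deleted expressions file.

import scala.collection.mutable

import org.apache.spark.sql.types._

// Hypothetical UDT mirroring the deleted OpenHashSetUDT, but backed by
// scala.collection.mutable.HashSet so the sketch is self-contained.
class HashSetUDT(val elementType: DataType) extends UserDefinedType[mutable.HashSet[Any]] {

  override def sqlType: DataType = ArrayType(elementType)

  override def serialize(obj: Any): Seq[Any] =
    obj.asInstanceOf[mutable.HashSet[Any]].toSeq

  override def deserialize(datum: Any): mutable.HashSet[Any] =
    mutable.HashSet(datum.asInstanceOf[Seq[Any]]: _*)

  override def userClass: Class[mutable.HashSet[Any]] = classOf[mutable.HashSet[Any]]
}

// The round-trip property the removed test asserted, restated for this sketch:
//   val udt = new HashSetUDT(IntegerType)
//   val set = mutable.HashSet[Any](1 to 10: _*)
//   assert(udt.deserialize(udt.serialize(set)) == set)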
@@ -163,7 +153,6 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
   test("SPARK-10472 UserDefinedType.typeName") {
     assert(IntegerType.typeName === "integer")
     assert(new MyDenseVectorUDT().typeName === "mydensevector")
-    assert(new OpenHashSetUDT(IntegerType).typeName === "openhashset")
   }
 
   test("Catalyst type converter null handling for UDTs") {