[SPARK-17495][SQL] Add more tests for hive hash
## What changes were proposed in this pull request?

This PR adds tests for hive-hash by comparing the outputs generated against Hive 1.2.1. The following datatypes are covered by this PR:

- null
- boolean
- byte
- short
- int
- long
- float
- double
- string
- array
- map
- struct

Datatypes that I have _NOT_ covered but will work on separately are:

- Decimal (handled separately in https://github.com/apache/spark/pull/17056)
- TimestampType
- DateType
- CalendarIntervalType

## How was this patch tested?

NA

Author: Tejas Patil <tejasp@fb.com>

Closes #17049 from tejasapatil/SPARK-17495_remaining_types.
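For readers skimming the diff, here is a minimal sketch of the Hive 1.2.1 semantics that the new integral-type expectations encode. The helper names `hiveHashInt`/`hiveHashLong` are illustrative only, not Spark or Hive APIs; the behavior is taken from the test expectations in the suite below.

```scala
// A minimal sketch (not a Spark/Hive API) of the semantics the integral-type
// tests below encode, per Hive 1.2.1's ObjectInspectorUtils#hashcode().
object HiveHashPrimitivesSketch {
  // byte/short/int hash to their own int value
  def hiveHashInt(v: Int): Int = v

  // longs fold the upper 32 bits into the lower 32 via XOR
  def hiveHashLong(v: Long): Int = ((v >>> 32) ^ v).toInt

  def main(args: Array[String]): Unit = {
    assert(hiveHashInt(-1) == -1)
    assert(hiveHashLong(-1L) == 0)                     // matches the long test below
    assert(hiveHashLong(Long.MaxValue) == -2147483648) // ditto
  }
}
```

The random-input checks in the suite assert exactly these identities.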
parent a920a43694
commit 3e40f6c3d6
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions;
 import org.apache.spark.unsafe.Platform;
 
 /**
- * Simulates Hive's hashing function at
+ * Simulates Hive's hashing function from Hive v1.2.1
  * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode()
  */
 public class HiveHasher {
@@ -573,10 +573,9 @@ object XxHash64Function extends InterpretedHashFunction {
   }
 }
 
 /**
- * Simulates Hive's hashing function at
- * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode() in Hive
+ * Simulates Hive's hashing function from Hive v1.2.1 at
+ * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode()
  *
  * We should use this hash function for both shuffle and bucket of Hive tables, so that
  * we can guarantee shuffle and bucketing have same data distribution
@@ -595,7 +594,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
   override protected def hasherClassName: String = classOf[HiveHasher].getName
 
   override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = {
-    HiveHashFunction.hash(value, dataType, seed).toInt
+    HiveHashFunction.hash(value, dataType, this.seed).toInt
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
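A note on the `seed` vs `this.seed` change above: Hive's `ObjectInspectorUtils#hashcode()` takes no seed parameter, so `HiveHash` apparently cannot propagate a per-call running seed the way Murmur3/xxHash do; the expression's own constant seed is used instead. The next hunk applies the same reasoning inside `HiveHashFunction`.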
@@ -781,12 +780,12 @@ object HiveHashFunction extends InterpretedHashFunction {
         var i = 0
         val length = struct.numFields
         while (i < length) {
-          result = (31 * result) + hash(struct.get(i, types(i)), types(i), seed + 1).toInt
+          result = (31 * result) + hash(struct.get(i, types(i)), types(i), 0).toInt
           i += 1
         }
         result
 
-      case _ => super.hash(value, dataType, seed)
+      case _ => super.hash(value, dataType, 0)
     }
   }
 }
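The `(31 * result) + hash(...)` fold above is the classic Java-style hash combiner. A small self-contained sketch (with an illustrative `combine` helper and input, not the Spark API) shows how it reproduces the struct expectation in the test suite below:

```scala
// Sketch of the 31-based fold used for struct fields above; per this hunk,
// each field hash is now computed with a constant seed of 0.
object StructHashSketch {
  def combine(fieldHashes: Seq[Int]): Int =
    fieldHashes.foldLeft(0)((result, h) => 31 * result + h)

  def main(args: Array[String]): Unit = {
    // A struct of ints (1, 2, 3), where each int hashes to itself:
    // 31 * (31 * (31 * 0 + 1) + 2) + 3 = 1026, the expected value in the
    // "hive-hash for struct" test below.
    assert(combine(Seq(1, 2, 3)) == 1026)
  }
}
```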
@@ -19,16 +19,20 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.nio.charset.StandardCharsets
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.commons.codec.digest.DigestUtils
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
+import org.apache.spark.sql.types.{ArrayType, StructType, _}
 import org.apache.spark.unsafe.types.UTF8String
 
 class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+  val random = new scala.util.Random
 
   test("md5") {
     checkEvaluation(Md5(Literal("ABC".getBytes(StandardCharsets.UTF_8))),
@@ -71,6 +75,247 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
   }
 
+  def checkHiveHash(input: Any, dataType: DataType, expected: Long): Unit = {
+    // Note: all expected hashes need to be computed using Hive 1.2.1
+    val actual = HiveHashFunction.hash(input, dataType, seed = 0)
+
+    withClue(s"hash mismatch for input = `$input` of type `$dataType`.") {
+      assert(actual == expected)
+    }
+  }
+
+  def checkHiveHashForIntegralType(dataType: DataType): Unit = {
+    // corner cases
+    checkHiveHash(null, dataType, 0)
+    checkHiveHash(1, dataType, 1)
+    checkHiveHash(0, dataType, 0)
+    checkHiveHash(-1, dataType, -1)
+    checkHiveHash(Int.MaxValue, dataType, Int.MaxValue)
+    checkHiveHash(Int.MinValue, dataType, Int.MinValue)
+
+    // random values
+    for (_ <- 0 until 10) {
+      val input = random.nextInt()
+      checkHiveHash(input, dataType, input)
+    }
+  }
+
test("hive-hash for null") {
|
||||
checkHiveHash(null, NullType, 0)
|
||||
}
|
||||
|
||||
test("hive-hash for boolean") {
|
||||
checkHiveHash(true, BooleanType, 1)
|
||||
checkHiveHash(false, BooleanType, 0)
|
||||
}
|
||||
|
||||
test("hive-hash for byte") {
|
||||
checkHiveHashForIntegralType(ByteType)
|
||||
}
|
||||
|
||||
test("hive-hash for short") {
|
||||
checkHiveHashForIntegralType(ShortType)
|
||||
}
|
||||
|
||||
test("hive-hash for int") {
|
||||
checkHiveHashForIntegralType(IntegerType)
|
||||
}
|
||||
|
||||
test("hive-hash for long") {
|
||||
checkHiveHash(1L, LongType, 1L)
|
||||
checkHiveHash(0L, LongType, 0L)
|
||||
checkHiveHash(-1L, LongType, 0L)
|
||||
checkHiveHash(Long.MaxValue, LongType, -2147483648)
|
||||
// Hive's fails to parse this.. but the hashing function itself can handle this input
|
||||
checkHiveHash(Long.MinValue, LongType, -2147483648)
|
||||
|
||||
for (_ <- 0 until 10) {
|
||||
val input = random.nextLong()
|
||||
checkHiveHash(input, LongType, ((input >>> 32) ^ input).toInt)
|
||||
}
|
||||
}
|
||||
|
||||
test("hive-hash for float") {
|
||||
checkHiveHash(0F, FloatType, 0)
|
||||
checkHiveHash(0.0F, FloatType, 0)
|
||||
checkHiveHash(1.1F, FloatType, 1066192077L)
|
||||
checkHiveHash(-1.1F, FloatType, -1081291571)
|
||||
checkHiveHash(99999999.99999999999F, FloatType, 1287568416L)
|
||||
checkHiveHash(Float.MaxValue, FloatType, 2139095039)
|
||||
checkHiveHash(Float.MinValue, FloatType, -8388609)
|
||||
}
|
||||
|
||||
test("hive-hash for double") {
|
||||
checkHiveHash(0, DoubleType, 0)
|
||||
checkHiveHash(0.0, DoubleType, 0)
|
||||
checkHiveHash(1.1, DoubleType, -1503133693)
|
||||
checkHiveHash(-1.1, DoubleType, 644349955)
|
||||
checkHiveHash(1000000000.000001, DoubleType, 1104006509)
|
||||
checkHiveHash(1000000000.0000000000000000000000001, DoubleType, 1104006501)
|
||||
checkHiveHash(9999999999999999999.9999999999999999999, DoubleType, 594568676)
|
||||
checkHiveHash(Double.MaxValue, DoubleType, -2146435072)
|
||||
checkHiveHash(Double.MinValue, DoubleType, 1048576)
|
||||
}
|
||||
|
||||
test("hive-hash for string") {
|
||||
checkHiveHash(UTF8String.fromString("apache spark"), StringType, 1142704523L)
|
||||
checkHiveHash(UTF8String.fromString("!@#$%^&*()_+=-"), StringType, -613724358L)
|
||||
checkHiveHash(UTF8String.fromString("abcdefghijklmnopqrstuvwxyz"), StringType, 958031277L)
|
||||
checkHiveHash(UTF8String.fromString("AbCdEfGhIjKlMnOpQrStUvWxYz012"), StringType, -648013852L)
|
||||
// scalastyle:off nonascii
|
||||
checkHiveHash(UTF8String.fromString("数据砖头"), StringType, -898686242L)
|
||||
checkHiveHash(UTF8String.fromString("नमस्ते"), StringType, 2006045948L)
|
||||
// scalastyle:on nonascii
|
||||
}
|
||||
|
||||
test("hive-hash for array") {
|
||||
// empty array
|
||||
checkHiveHash(
|
||||
input = new GenericArrayData(Array[Int]()),
|
||||
dataType = ArrayType(IntegerType, containsNull = false),
|
||||
expected = 0)
|
||||
|
||||
// basic case
|
||||
checkHiveHash(
|
||||
input = new GenericArrayData(Array(1, 10000, Int.MaxValue)),
|
||||
dataType = ArrayType(IntegerType, containsNull = false),
|
||||
expected = -2147172688L)
|
||||
|
||||
// with negative values
|
||||
checkHiveHash(
|
||||
input = new GenericArrayData(Array(-1L, 0L, 999L, Int.MinValue.toLong)),
|
||||
dataType = ArrayType(LongType, containsNull = false),
|
||||
expected = -2147452680L)
|
||||
|
||||
// with nulls only
|
||||
val arrayTypeWithNull = ArrayType(IntegerType, containsNull = true)
|
||||
checkHiveHash(
|
||||
input = new GenericArrayData(Array(null, null)),
|
||||
dataType = arrayTypeWithNull,
|
||||
expected = 0)
|
||||
|
||||
// mix with null
|
||||
checkHiveHash(
|
||||
input = new GenericArrayData(Array(-12221, 89, null, 767)),
|
||||
dataType = arrayTypeWithNull,
|
||||
expected = -363989515)
|
||||
|
||||
// nested with array
|
||||
checkHiveHash(
|
||||
input = new GenericArrayData(
|
||||
Array(
|
||||
new GenericArrayData(Array(1234L, -9L, 67L)),
|
||||
new GenericArrayData(Array(null, null)),
|
||||
new GenericArrayData(Array(55L, -100L, -2147452680L))
|
||||
)),
|
||||
dataType = ArrayType(ArrayType(LongType)),
|
||||
expected = -1007531064)
|
||||
|
||||
// nested with map
|
||||
checkHiveHash(
|
||||
input = new GenericArrayData(
|
||||
Array(
|
||||
new ArrayBasedMapData(
|
||||
new GenericArrayData(Array(-99, 1234)),
|
||||
new GenericArrayData(Array(UTF8String.fromString("sql"), null))),
|
||||
new ArrayBasedMapData(
|
||||
new GenericArrayData(Array(67)),
|
||||
new GenericArrayData(Array(UTF8String.fromString("apache spark"))))
|
||||
)),
|
||||
dataType = ArrayType(MapType(IntegerType, StringType)),
|
||||
expected = 1139205955)
|
||||
}
|
||||
|
||||
test("hive-hash for map") {
|
||||
val mapType = MapType(IntegerType, StringType)
|
||||
|
||||
// empty map
|
||||
checkHiveHash(
|
||||
input = new ArrayBasedMapData(new GenericArrayData(Array()), new GenericArrayData(Array())),
|
||||
dataType = mapType,
|
||||
expected = 0)
|
||||
|
||||
// basic case
|
||||
checkHiveHash(
|
||||
input = new ArrayBasedMapData(
|
||||
new GenericArrayData(Array(1, 2)),
|
||||
new GenericArrayData(Array(UTF8String.fromString("foo"), UTF8String.fromString("bar")))),
|
||||
dataType = mapType,
|
||||
expected = 198872)
|
||||
|
||||
// with null value
|
||||
checkHiveHash(
|
||||
input = new ArrayBasedMapData(
|
||||
new GenericArrayData(Array(55, -99)),
|
||||
new GenericArrayData(Array(UTF8String.fromString("apache spark"), null))),
|
||||
dataType = mapType,
|
||||
expected = 1142704473)
|
||||
|
||||
// nesting (only values can be nested as keys have to be primitive datatype)
|
||||
val nestedMapType = MapType(IntegerType, MapType(IntegerType, StringType))
|
||||
checkHiveHash(
|
||||
input = new ArrayBasedMapData(
|
||||
new GenericArrayData(Array(1, -100)),
|
||||
new GenericArrayData(
|
||||
Array(
|
||||
new ArrayBasedMapData(
|
||||
new GenericArrayData(Array(-99, 1234)),
|
||||
new GenericArrayData(Array(UTF8String.fromString("sql"), null))),
|
||||
new ArrayBasedMapData(
|
||||
new GenericArrayData(Array(67)),
|
||||
new GenericArrayData(Array(UTF8String.fromString("apache spark"))))
|
||||
))),
|
||||
dataType = nestedMapType,
|
||||
expected = -1142817416)
|
||||
}
|
||||
|
||||
test("hive-hash for struct") {
|
||||
// basic
|
||||
val row = new GenericInternalRow(Array[Any](1, 2, 3))
|
||||
checkHiveHash(
|
||||
input = row,
|
||||
dataType =
|
||||
new StructType()
|
||||
.add("col1", IntegerType)
|
||||
.add("col2", IntegerType)
|
||||
.add("col3", IntegerType),
|
||||
expected = 1026)
|
||||
|
||||
// mix of several datatypes
|
||||
val structType = new StructType()
|
||||
.add("null", NullType)
|
||||
.add("boolean", BooleanType)
|
||||
.add("byte", ByteType)
|
||||
.add("short", ShortType)
|
||||
.add("int", IntegerType)
|
||||
.add("long", LongType)
|
||||
.add("arrayOfString", arrayOfString)
|
||||
.add("mapOfString", mapOfString)
|
||||
|
||||
val rowValues = new ArrayBuffer[Any]()
|
||||
rowValues += null
|
||||
rowValues += true
|
||||
rowValues += 1
|
||||
rowValues += 2
|
||||
rowValues += Int.MaxValue
|
||||
rowValues += Long.MinValue
|
||||
rowValues += new GenericArrayData(Array(
|
||||
UTF8String.fromString("apache spark"),
|
||||
UTF8String.fromString("hello world")
|
||||
))
|
||||
rowValues += new ArrayBasedMapData(
|
||||
new GenericArrayData(Array(UTF8String.fromString("project"), UTF8String.fromString("meta"))),
|
||||
new GenericArrayData(Array(UTF8String.fromString("apache spark"), null))
|
||||
)
|
||||
|
||||
val row2 = new GenericInternalRow(rowValues.toArray)
|
||||
checkHiveHash(
|
||||
input = row2,
|
||||
dataType = structType,
|
||||
expected = -2119012447)
|
||||
}
|
||||
|
||||
private val structOfString = new StructType().add("str", StringType)
|
||||
private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false)
|
||||
private val arrayOfString = ArrayType(StringType)
|
||||
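As a closing note on where the float, double, and string expectations come from: assuming Hive 1.2.1's `ObjectInspectorUtils` semantics, floats hash to their raw IEEE-754 bits, doubles to their raw bits folded like a long, and strings to a Java-style 31-based hash over their UTF-8 bytes. A small sketch with illustrative helper names (not Spark or Hive APIs):

```scala
// Sketch only: reproduces a few expected values from the suite above,
// assuming Hive 1.2.1's hashing of primitive values.
object HiveHashFloatingSketch {
  // floats hash to their raw IEEE-754 bit pattern
  def hiveHashFloat(v: Float): Int = java.lang.Float.floatToIntBits(v)

  // doubles hash to their raw bits, folded like a long
  def hiveHashDouble(v: Double): Int = {
    val bits = java.lang.Double.doubleToLongBits(v)
    ((bits >>> 32) ^ bits).toInt
  }

  // strings hash like java.lang.String#hashCode, but over UTF-8 bytes
  def hiveHashString(s: String): Int =
    s.getBytes("UTF-8").foldLeft(0)((h, b) => 31 * h + b)

  def main(args: Array[String]): Unit = {
    assert(hiveHashFloat(1.1F) == 1066192077)             // matches the float test
    assert(hiveHashDouble(1.1) == -1503133693)            // matches the double test
    assert(hiveHashString("apache spark") == 1142704523L) // matches the string test
  }
}
```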