[SPARK-17495][SQL] Add more tests for hive hash

## What changes were proposed in this pull request?

This PR adds tests hive-hash by comparing the outputs generated against Hive 1.2.1. Following datatypes are covered by this PR:
- null
- boolean
- byte
- short
- int
- long
- float
- double
- string
- array
- map
- struct

Datatypes that I have _NOT_ covered but I will work on separately are:
- Decimal (handled separately in https://github.com/apache/spark/pull/17056)
- TimestampType
- DateType
- CalendarIntervalType

## How was this patch tested?

NA

Author: Tejas Patil <tejasp@fb.com>

Closes #17049 from tejasapatil/SPARK-17495_remaining_types.
This commit is contained in:
Tejas Patil 2017-02-24 09:46:42 -08:00 committed by Reynold Xin
parent a920a43694
commit 3e40f6c3d6
3 changed files with 252 additions and 8 deletions

View file

@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions;
import org.apache.spark.unsafe.Platform;
/**
* Simulates Hive's hashing function at
* Simulates Hive's hashing function from Hive v1.2.1
* org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode()
*/
public class HiveHasher {

View file

@ -573,10 +573,9 @@ object XxHash64Function extends InterpretedHashFunction {
}
}
/**
* Simulates Hive's hashing function at
* org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode() in Hive
* Simulates Hive's hashing function from Hive v1.2.1 at
* org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode()
*
* We should use this hash function for both shuffle and bucket of Hive tables, so that
* we can guarantee shuffle and bucketing have same data distribution
@ -595,7 +594,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
override protected def hasherClassName: String = classOf[HiveHasher].getName
override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = {
HiveHashFunction.hash(value, dataType, seed).toInt
HiveHashFunction.hash(value, dataType, this.seed).toInt
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@ -781,12 +780,12 @@ object HiveHashFunction extends InterpretedHashFunction {
var i = 0
val length = struct.numFields
while (i < length) {
result = (31 * result) + hash(struct.get(i, types(i)), types(i), seed + 1).toInt
result = (31 * result) + hash(struct.get(i, types(i)), types(i), 0).toInt
i += 1
}
result
case _ => super.hash(value, dataType, seed)
case _ => super.hash(value, dataType, 0)
}
}
}

View file

@ -19,16 +19,20 @@ package org.apache.spark.sql.catalyst.expressions
import java.nio.charset.StandardCharsets
import scala.collection.mutable.ArrayBuffer
import org.apache.commons.codec.digest.DigestUtils
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types.{ArrayType, StructType, _}
import org.apache.spark.unsafe.types.UTF8String
class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val random = new scala.util.Random
test("md5") {
checkEvaluation(Md5(Literal("ABC".getBytes(StandardCharsets.UTF_8))),
@ -71,6 +75,247 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
}
def checkHiveHash(input: Any, dataType: DataType, expected: Long): Unit = {
// Note : All expected hashes need to be computed using Hive 1.2.1
val actual = HiveHashFunction.hash(input, dataType, seed = 0)
withClue(s"hash mismatch for input = `$input` of type `$dataType`.") {
assert(actual == expected)
}
}
def checkHiveHashForIntegralType(dataType: DataType): Unit = {
// corner cases
checkHiveHash(null, dataType, 0)
checkHiveHash(1, dataType, 1)
checkHiveHash(0, dataType, 0)
checkHiveHash(-1, dataType, -1)
checkHiveHash(Int.MaxValue, dataType, Int.MaxValue)
checkHiveHash(Int.MinValue, dataType, Int.MinValue)
// random values
for (_ <- 0 until 10) {
val input = random.nextInt()
checkHiveHash(input, dataType, input)
}
}
test("hive-hash for null") {
checkHiveHash(null, NullType, 0)
}
test("hive-hash for boolean") {
checkHiveHash(true, BooleanType, 1)
checkHiveHash(false, BooleanType, 0)
}
test("hive-hash for byte") {
checkHiveHashForIntegralType(ByteType)
}
test("hive-hash for short") {
checkHiveHashForIntegralType(ShortType)
}
test("hive-hash for int") {
checkHiveHashForIntegralType(IntegerType)
}
test("hive-hash for long") {
checkHiveHash(1L, LongType, 1L)
checkHiveHash(0L, LongType, 0L)
checkHiveHash(-1L, LongType, 0L)
checkHiveHash(Long.MaxValue, LongType, -2147483648)
// Hive's fails to parse this.. but the hashing function itself can handle this input
checkHiveHash(Long.MinValue, LongType, -2147483648)
for (_ <- 0 until 10) {
val input = random.nextLong()
checkHiveHash(input, LongType, ((input >>> 32) ^ input).toInt)
}
}
test("hive-hash for float") {
checkHiveHash(0F, FloatType, 0)
checkHiveHash(0.0F, FloatType, 0)
checkHiveHash(1.1F, FloatType, 1066192077L)
checkHiveHash(-1.1F, FloatType, -1081291571)
checkHiveHash(99999999.99999999999F, FloatType, 1287568416L)
checkHiveHash(Float.MaxValue, FloatType, 2139095039)
checkHiveHash(Float.MinValue, FloatType, -8388609)
}
test("hive-hash for double") {
checkHiveHash(0, DoubleType, 0)
checkHiveHash(0.0, DoubleType, 0)
checkHiveHash(1.1, DoubleType, -1503133693)
checkHiveHash(-1.1, DoubleType, 644349955)
checkHiveHash(1000000000.000001, DoubleType, 1104006509)
checkHiveHash(1000000000.0000000000000000000000001, DoubleType, 1104006501)
checkHiveHash(9999999999999999999.9999999999999999999, DoubleType, 594568676)
checkHiveHash(Double.MaxValue, DoubleType, -2146435072)
checkHiveHash(Double.MinValue, DoubleType, 1048576)
}
test("hive-hash for string") {
checkHiveHash(UTF8String.fromString("apache spark"), StringType, 1142704523L)
checkHiveHash(UTF8String.fromString("!@#$%^&*()_+=-"), StringType, -613724358L)
checkHiveHash(UTF8String.fromString("abcdefghijklmnopqrstuvwxyz"), StringType, 958031277L)
checkHiveHash(UTF8String.fromString("AbCdEfGhIjKlMnOpQrStUvWxYz012"), StringType, -648013852L)
// scalastyle:off nonascii
checkHiveHash(UTF8String.fromString("数据砖头"), StringType, -898686242L)
checkHiveHash(UTF8String.fromString("नमस्ते"), StringType, 2006045948L)
// scalastyle:on nonascii
}
test("hive-hash for array") {
// empty array
checkHiveHash(
input = new GenericArrayData(Array[Int]()),
dataType = ArrayType(IntegerType, containsNull = false),
expected = 0)
// basic case
checkHiveHash(
input = new GenericArrayData(Array(1, 10000, Int.MaxValue)),
dataType = ArrayType(IntegerType, containsNull = false),
expected = -2147172688L)
// with negative values
checkHiveHash(
input = new GenericArrayData(Array(-1L, 0L, 999L, Int.MinValue.toLong)),
dataType = ArrayType(LongType, containsNull = false),
expected = -2147452680L)
// with nulls only
val arrayTypeWithNull = ArrayType(IntegerType, containsNull = true)
checkHiveHash(
input = new GenericArrayData(Array(null, null)),
dataType = arrayTypeWithNull,
expected = 0)
// mix with null
checkHiveHash(
input = new GenericArrayData(Array(-12221, 89, null, 767)),
dataType = arrayTypeWithNull,
expected = -363989515)
// nested with array
checkHiveHash(
input = new GenericArrayData(
Array(
new GenericArrayData(Array(1234L, -9L, 67L)),
new GenericArrayData(Array(null, null)),
new GenericArrayData(Array(55L, -100L, -2147452680L))
)),
dataType = ArrayType(ArrayType(LongType)),
expected = -1007531064)
// nested with map
checkHiveHash(
input = new GenericArrayData(
Array(
new ArrayBasedMapData(
new GenericArrayData(Array(-99, 1234)),
new GenericArrayData(Array(UTF8String.fromString("sql"), null))),
new ArrayBasedMapData(
new GenericArrayData(Array(67)),
new GenericArrayData(Array(UTF8String.fromString("apache spark"))))
)),
dataType = ArrayType(MapType(IntegerType, StringType)),
expected = 1139205955)
}
test("hive-hash for map") {
val mapType = MapType(IntegerType, StringType)
// empty map
checkHiveHash(
input = new ArrayBasedMapData(new GenericArrayData(Array()), new GenericArrayData(Array())),
dataType = mapType,
expected = 0)
// basic case
checkHiveHash(
input = new ArrayBasedMapData(
new GenericArrayData(Array(1, 2)),
new GenericArrayData(Array(UTF8String.fromString("foo"), UTF8String.fromString("bar")))),
dataType = mapType,
expected = 198872)
// with null value
checkHiveHash(
input = new ArrayBasedMapData(
new GenericArrayData(Array(55, -99)),
new GenericArrayData(Array(UTF8String.fromString("apache spark"), null))),
dataType = mapType,
expected = 1142704473)
// nesting (only values can be nested as keys have to be primitive datatype)
val nestedMapType = MapType(IntegerType, MapType(IntegerType, StringType))
checkHiveHash(
input = new ArrayBasedMapData(
new GenericArrayData(Array(1, -100)),
new GenericArrayData(
Array(
new ArrayBasedMapData(
new GenericArrayData(Array(-99, 1234)),
new GenericArrayData(Array(UTF8String.fromString("sql"), null))),
new ArrayBasedMapData(
new GenericArrayData(Array(67)),
new GenericArrayData(Array(UTF8String.fromString("apache spark"))))
))),
dataType = nestedMapType,
expected = -1142817416)
}
test("hive-hash for struct") {
// basic
val row = new GenericInternalRow(Array[Any](1, 2, 3))
checkHiveHash(
input = row,
dataType =
new StructType()
.add("col1", IntegerType)
.add("col2", IntegerType)
.add("col3", IntegerType),
expected = 1026)
// mix of several datatypes
val structType = new StructType()
.add("null", NullType)
.add("boolean", BooleanType)
.add("byte", ByteType)
.add("short", ShortType)
.add("int", IntegerType)
.add("long", LongType)
.add("arrayOfString", arrayOfString)
.add("mapOfString", mapOfString)
val rowValues = new ArrayBuffer[Any]()
rowValues += null
rowValues += true
rowValues += 1
rowValues += 2
rowValues += Int.MaxValue
rowValues += Long.MinValue
rowValues += new GenericArrayData(Array(
UTF8String.fromString("apache spark"),
UTF8String.fromString("hello world")
))
rowValues += new ArrayBasedMapData(
new GenericArrayData(Array(UTF8String.fromString("project"), UTF8String.fromString("meta"))),
new GenericArrayData(Array(UTF8String.fromString("apache spark"), null))
)
val row2 = new GenericInternalRow(rowValues.toArray)
checkHiveHash(
input = row2,
dataType = structType,
expected = -2119012447)
}
private val structOfString = new StructType().add("str", StringType)
private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false)
private val arrayOfString = ArrayType(StringType)