[SPARK-23599][SQL] Use RandomUUIDGenerator in Uuid expression

## What changes were proposed in this pull request?

As stated in Jira, there are problems with current `Uuid` expression which uses `java.util.UUID.randomUUID` for UUID generation.

This patch uses the newly added `RandomUUIDGenerator` for UUID generation. So we can make `Uuid` deterministic between retries.

## How was this patch tested?

Added unit tests.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #20861 from viirya/SPARK-23599-2.
This commit is contained in:
Liang-Chi Hsieh 2018-03-22 19:57:32 +01:00 committed by Herman van Hovell
parent 5c9eaa6b58
commit 4d37008c78
6 changed files with 136 additions and 9 deletions

View file

@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.analysis
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst._
@ -177,6 +178,7 @@ class Analyzer(
TimeWindowing ::
ResolveInlineTables(conf) ::
ResolveTimeZone(conf) ::
ResolvedUuidExpressions ::
TypeCoercion.typeCoercionRules(conf) ++
extendedResolutionRules : _*),
Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
@ -1994,6 +1996,20 @@ class Analyzer(
}
}
/**
* Set the seed for random number generation in Uuid expressions.
*/
object ResolvedUuidExpressions extends Rule[LogicalPlan] {
private lazy val random = new Random()
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case p if p.resolved => p
case p => p transformExpressionsUp {
case Uuid(None) => Uuid(Some(random.nextLong()))
}
}
}
/**
* Correctly handle null primitive inputs for UDF by adding extra [[If]] expression to do the
* null check. When user defines a UDF with primitive parameters, there is no way to tell if the

View file

@ -21,6 +21,7 @@ import java.util.UUID
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@ -122,18 +123,33 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable {
46707d92-02f4-4817-8116-a4c3b23e6266
""")
// scalastyle:on line.size.limit
case class Uuid() extends LeafExpression {
case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Nondeterministic {
override lazy val deterministic: Boolean = false
def this() = this(None)
override lazy val resolved: Boolean = randomSeed.isDefined
override def nullable: Boolean = false
override def dataType: DataType = StringType
override def eval(input: InternalRow): Any = UTF8String.fromString(UUID.randomUUID().toString)
@transient private[this] var randomGenerator: RandomUUIDGenerator = _
override protected def initializeInternal(partitionIndex: Int): Unit =
randomGenerator = RandomUUIDGenerator(randomSeed.get + partitionIndex)
override protected def evalInternal(input: InternalRow): Any =
randomGenerator.getNextUUIDUTF8String()
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
ev.copy(code = s"final UTF8String ${ev.value} = " +
s"UTF8String.fromString(java.util.UUID.randomUUID().toString());", isNull = "false")
val randomGen = ctx.freshName("randomGen")
ctx.addMutableState("org.apache.spark.sql.catalyst.util.RandomUUIDGenerator", randomGen,
forceInline = true,
useFreshName = false)
ctx.addPartitionInitializationStatement(s"$randomGen = " +
"new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" +
s"${randomSeed.get}L + partitionIndex);")
ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();",
isNull = "false")
}
}

View file

@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
/**
* Test suite for resolving Uuid expressions.
*/
class ResolvedUuidExpressionsSuite extends AnalysisTest {
private lazy val a = 'a.int
private lazy val r = LocalRelation(a)
private lazy val uuid1 = Uuid().as('_uuid1)
private lazy val uuid2 = Uuid().as('_uuid2)
private lazy val uuid3 = Uuid().as('_uuid3)
private lazy val uuid1Ref = uuid1.toAttribute
private val analyzer = getAnalyzer(caseSensitive = true)
private def getUuidExpressions(plan: LogicalPlan): Seq[Uuid] = {
plan.flatMap {
case p =>
p.expressions.flatMap(_.collect {
case u: Uuid => u
})
}
}
test("analyzed plan sets random seed for Uuid expression") {
val plan = r.select(a, uuid1)
val resolvedPlan = analyzer.executeAndCheck(plan)
getUuidExpressions(resolvedPlan).foreach { u =>
assert(u.resolved)
assert(u.randomSeed.isDefined)
}
}
test("Uuid expressions should have different random seeds") {
val plan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3)
val resolvedPlan = analyzer.executeAndCheck(plan)
assert(getUuidExpressions(resolvedPlan).map(_.randomSeed.get).distinct.length == 3)
}
test("Different analyzed plans should have different random seeds in Uuids") {
val plan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3)
val resolvedPlan1 = analyzer.executeAndCheck(plan)
val resolvedPlan2 = analyzer.executeAndCheck(plan)
val uuids1 = getUuidExpressions(resolvedPlan1)
val uuids2 = getUuidExpressions(resolvedPlan2)
assert(uuids1.distinct.length == 3)
assert(uuids2.distinct.length == 3)
assert(uuids1.intersect(uuids2).length == 0)
}
}

View file

@ -176,7 +176,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
}
}
private def evaluateWithGeneratedMutableProjection(
protected def evaluateWithGeneratedMutableProjection(
expression: Expression,
inputRow: InternalRow = EmptyRow): Any = {
val plan = generateProject(
@ -220,7 +220,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
}
}
private def evaluateWithUnsafeProjection(
protected def evaluateWithUnsafeProjection(
expression: Expression,
inputRow: InternalRow = EmptyRow,
factory: UnsafeProjectionCreator = UnsafeProjection): InternalRow = {
@ -233,6 +233,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
Alias(expression, s"Optimized($expression)2")() :: Nil),
expression)
plan.initialize(0)
plan(inputRow)
}

View file

@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions
import java.io.PrintStream
import scala.util.Random
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._
@ -42,8 +44,21 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("uuid") {
checkEvaluation(Length(Uuid()), 36)
assert(evaluateWithoutCodegen(Uuid()) !== evaluateWithoutCodegen(Uuid()))
checkEvaluation(Length(Uuid(Some(0))), 36)
val r = new Random()
val seed1 = Some(r.nextLong())
assert(evaluateWithoutCodegen(Uuid(seed1)) === evaluateWithoutCodegen(Uuid(seed1)))
assert(evaluateWithGeneratedMutableProjection(Uuid(seed1)) ===
evaluateWithGeneratedMutableProjection(Uuid(seed1)))
assert(evaluateWithUnsafeProjection(Uuid(seed1)) ===
evaluateWithUnsafeProjection(Uuid(seed1)))
val seed2 = Some(r.nextLong())
assert(evaluateWithoutCodegen(Uuid(seed1)) !== evaluateWithoutCodegen(Uuid(seed2)))
assert(evaluateWithGeneratedMutableProjection(Uuid(seed1)) !==
evaluateWithGeneratedMutableProjection(Uuid(seed2)))
assert(evaluateWithUnsafeProjection(Uuid(seed1)) !==
evaluateWithUnsafeProjection(Uuid(seed2)))
}
test("PrintToStderr") {

View file

@ -28,6 +28,7 @@ import org.scalatest.Matchers._
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.Uuid
import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union}
import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
@ -2264,4 +2265,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
checkAnswer(df, Row(0, 10) :: Nil)
assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
}
test("Uuid expressions should produce same results at retries in the same DataFrame") {
val df = spark.range(1).select($"id", new Column(Uuid()))
checkAnswer(df, df.collect())
}
}