[SPARK-35541][SQL] Simplify OptimizeSkewedJoin

### What changes were proposed in this pull request?

Various small code simplifications/cleanups for `OptimizeSkewedJoin`:
- fold the two conditions of `isSkewed` into a single precomputed threshold via the new `getSkewThreshold`, and reuse that threshold in `targetSize` and the per-partition skew checks;
- evaluate `canSplitLeftSide`/`canSplitRightSide` once at the start of `tryOptimizeJoinChildren` and return `None` early when neither side can be split;
- match the `isSkewJoin` flag as the literal `false` in the `SortMergeJoinExec`/`ShuffledHashJoinExec` patterns instead of binding it and filtering in a pattern guard; the join-type check is covered by the early return in `tryOptimizeJoinChildren`.
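
As a quick sanity check on the core simplification, here is a standalone sketch (plain Scala, not Spark code: the config values are passed as parameters instead of being read from `SQLConf`, and `SkewPredicateEquivalence` is a made-up name) showing that the old two-condition check and the new single-threshold check agree, since `size > median * factor && size > threshold` is equivalent to `size > max(threshold, median * factor)`:

```scala
object SkewPredicateEquivalence {
  // Old shape of the check: two comparisons evaluated for every partition.
  def isSkewedOld(size: Long, medianSize: Long, factor: Long, threshold: Long): Boolean =
    size > medianSize * factor && size > threshold

  // New shape: fold both comparisons into one precomputed threshold.
  def skewThreshold(medianSize: Long, factor: Long, threshold: Long): Long =
    threshold.max(medianSize * factor)

  def main(args: Array[String]): Unit = {
    val mb = 1024L * 1024
    for {
      size <- Seq(10 * mb, 300 * mb, 600 * mb, 5000 * mb)
      median <- Seq(20 * mb, 100 * mb)
    } {
      val oldAnswer = isSkewedOld(size, median, factor = 5, threshold = 256 * mb)
      val newAnswer = size > skewThreshold(median, factor = 5, threshold = 256 * mb)
      assert(oldAnswer == newAnswer, s"mismatch for size=$size, median=$median")
    }
    println("old and new skew checks agree on all sampled inputs")
  }
}
```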

### Why are the changes needed?

Code refactoring to make `OptimizeSkewedJoin` simpler to follow; no behavior change is intended.

### Does this PR introduce _any_ user-facing change?

no

### How was this patch tested?

existing tests

Closes #32685 from cloud-fan/skew-join.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>

@@ -55,17 +55,14 @@ object OptimizeSkewedJoin extends CustomShuffleReaderRule {
 
   private val ensureRequirements = EnsureRequirements
 
-  private val supportedJoinTypes =
-    Inner :: Cross :: LeftSemi :: LeftAnti :: LeftOuter :: RightOuter :: Nil
-
   /**
    * A partition is considered as a skewed partition if its size is larger than the median
    * partition size * SKEW_JOIN_SKEWED_PARTITION_FACTOR and also larger than
-   * SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.
+   * SKEW_JOIN_SKEWED_PARTITION_THRESHOLD. Thus we pick the larger one as the skew threshold.
    */
-  private def isSkewed(size: Long, medianSize: Long): Boolean = {
-    size > medianSize * conf.getConf(SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR) &&
-      size > conf.getConf(SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD)
+  def getSkewThreshold(medianSize: Long): Long = {
+    conf.getConf(SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD).max(
+      medianSize * conf.getConf(SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR))
   }
 
   private def medianSize(sizes: Array[Long]): Long = {
@@ -83,9 +80,9 @@ object OptimizeSkewedJoin extends CustomShuffleReaderRule {
    * to split skewed partitions is the average size of non-skewed partition, or the
    * advisory partition size if avg size is smaller than it.
    */
-  private def targetSize(sizes: Array[Long], medianSize: Long): Long = {
+  private def targetSize(sizes: Array[Long], skewThreshold: Long): Long = {
     val advisorySize = conf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES)
-    val nonSkewSizes = sizes.filterNot(isSkewed(_, medianSize))
+    val nonSkewSizes = sizes.filter(_ <= skewThreshold)
     if (nonSkewSizes.isEmpty) {
       advisorySize
     } else {
@@ -158,6 +155,10 @@ object OptimizeSkewedJoin extends CustomShuffleReaderRule {
       left: ShuffleQueryStageExec,
       right: ShuffleQueryStageExec,
       joinType: JoinType): Option[(SparkPlan, SparkPlan)] = {
+    val canSplitLeft = canSplitLeftSide(joinType)
+    val canSplitRight = canSplitRightSide(joinType)
+    if (!canSplitLeft && !canSplitRight) return None
+
     val leftSizes = left.mapStats.get.bytesByPartitionId
     val rightSizes = right.mapStats.get.bytesByPartitionId
     assert(leftSizes.length == rightSizes.length)
@@ -174,10 +175,10 @@ object OptimizeSkewedJoin extends CustomShuffleReaderRule {
          |${getSizeInfo(rightMedSize, rightSizes)}
        """.stripMargin)
 
-    val canSplitLeft = canSplitLeftSide(joinType)
-    val canSplitRight = canSplitRightSide(joinType)
-    val leftTargetSize = targetSize(leftSizes, leftMedSize)
-    val rightTargetSize = targetSize(rightSizes, rightMedSize)
+    val leftSkewThreshold = getSkewThreshold(leftMedSize)
+    val rightSkewThreshold = getSkewThreshold(rightMedSize)
+    val leftTargetSize = targetSize(leftSizes, leftSkewThreshold)
+    val rightTargetSize = targetSize(rightSizes, rightSkewThreshold)
 
     val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
     val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
@@ -185,9 +186,9 @@ object OptimizeSkewedJoin extends CustomShuffleReaderRule {
     var numSkewedRight = 0
     for (partitionIndex <- 0 until numPartitions) {
       val leftSize = leftSizes(partitionIndex)
-      val isLeftSkew = isSkewed(leftSize, leftMedSize) && canSplitLeft
+      val isLeftSkew = canSplitLeft && leftSize > leftSkewThreshold
       val rightSize = rightSizes(partitionIndex)
-      val isRightSkew = isSkewed(rightSize, rightMedSize) && canSplitRight
+      val isRightSkew = canSplitRight && rightSize > rightSkewThreshold
       val noSkewPartitionSpec = Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1))
 
       val leftParts = if (isLeftSkew) {
@@ -238,8 +239,7 @@ object OptimizeSkewedJoin extends CustomShuffleReaderRule {
   def optimizeSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp {
     case smj @ SortMergeJoinExec(_, _, joinType, _,
         s1 @ SortExec(_, _, ShuffleStage(left: ShuffleQueryStageExec), _),
-        s2 @ SortExec(_, _, ShuffleStage(right: ShuffleQueryStageExec), _), isSkewJoin)
-        if !isSkewJoin && supportedJoinTypes.contains(joinType) =>
+        s2 @ SortExec(_, _, ShuffleStage(right: ShuffleQueryStageExec), _), false) =>
       val newChildren = tryOptimizeJoinChildren(left, right, joinType)
       if (newChildren.isDefined) {
         val (newLeft, newRight) = newChildren.get
@@ -251,8 +251,7 @@ object OptimizeSkewedJoin extends CustomShuffleReaderRule {
 
     case shj @ ShuffledHashJoinExec(_, _, joinType, _, _,
         ShuffleStage(left: ShuffleQueryStageExec),
-        ShuffleStage(right: ShuffleQueryStageExec), isSkewJoin)
-        if !isSkewJoin && supportedJoinTypes.contains(joinType) =>
+        ShuffleStage(right: ShuffleQueryStageExec), false) =>
       val newChildren = tryOptimizeJoinChildren(left, right, joinType)
       if (newChildren.isDefined) {
         val (newLeft, newRight) = newChildren.get
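
For reference, a small self-contained sketch of how the precomputed skew threshold and the split target size interact end to end. This is illustrative only: `SkewJoinMath` is a made-up name, the hard-coded values stand in for the usual defaults of `SKEW_JOIN_SKEWED_PARTITION_FACTOR` (5), `SKEW_JOIN_SKEWED_PARTITION_THRESHOLD` (256MB) and `ADVISORY_PARTITION_SIZE_IN_BYTES` (64MB), and `targetSize` follows the behavior described in the rule's doc comment rather than quoting the real implementation.

```scala
object SkewJoinMath {
  // Hypothetical stand-ins for the SQLConf entries read by the rule.
  val skewedPartitionFactor: Long = 5
  val skewedPartitionThreshold: Long = 256L * 1024 * 1024 // 256MB
  val advisoryPartitionSize: Long = 64L * 1024 * 1024     // 64MB

  // A partition is skewed if it exceeds both criteria, i.e. the larger of the two.
  def skewThreshold(medianSize: Long): Long =
    skewedPartitionThreshold.max(medianSize * skewedPartitionFactor)

  // Split target: average of the non-skewed partitions, but never below the advisory size.
  def targetSize(sizes: Array[Long], threshold: Long): Long = {
    val nonSkewSizes = sizes.filter(_ <= threshold)
    if (nonSkewSizes.isEmpty) advisoryPartitionSize
    else advisoryPartitionSize.max(nonSkewSizes.sum / nonSkewSizes.length)
  }

  def main(args: Array[String]): Unit = {
    val mb = 1024L * 1024
    val sizes = Array(60 * mb, 70 * mb, 80 * mb, 2000 * mb) // one clearly skewed partition
    val median = 75 * mb

    val threshold = skewThreshold(median)
    println(s"skew threshold    = ${threshold / mb}MB")                    // 375MB
    println(s"skewed partitions = ${sizes.count(_ > threshold)}")          // 1
    println(s"target split size = ${targetSize(sizes, threshold) / mb}MB") // 70MB
  }
}
```

With these inputs the threshold is max(256MB, 5 * 75MB) = 375MB, so only the 2000MB partition counts as skewed, and the split target is max(64MB, avg(60MB, 70MB, 80MB)) = 70MB.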