[SPARK-35541][SQL] Simplify OptimizeSkewedJoin
### What changes were proposed in this pull request?
Various small code simplifications and cleanups for OptimizeSkewedJoin.

### Why are the changes needed?
Code refactoring.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

Closes #32685 from cloud-fan/skew-join.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
This commit is contained in:
parent
f98a063a4b
commit
29ed1a2de4
|
@ -55,17 +55,14 @@ object OptimizeSkewedJoin extends CustomShuffleReaderRule {
|
|||
|
||||
private val ensureRequirements = EnsureRequirements
|
||||
|
||||
private val supportedJoinTypes =
|
||||
Inner :: Cross :: LeftSemi :: LeftAnti :: LeftOuter :: RightOuter :: Nil
|
||||
|
||||
/**
|
||||
* A partition is considered as a skewed partition if its size is larger than the median
|
||||
* partition size * SKEW_JOIN_SKEWED_PARTITION_FACTOR and also larger than
|
||||
* SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.
|
||||
* SKEW_JOIN_SKEWED_PARTITION_THRESHOLD. Thus we pick the larger one as the skew threshold.
|
||||
*/
|
||||
private def isSkewed(size: Long, medianSize: Long): Boolean = {
|
||||
size > medianSize * conf.getConf(SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR) &&
|
||||
size > conf.getConf(SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD)
|
||||
def getSkewThreshold(medianSize: Long): Long = {
|
||||
conf.getConf(SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD).max(
|
||||
medianSize * conf.getConf(SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR))
|
||||
}
|
||||
|
||||
private def medianSize(sizes: Array[Long]): Long = {
|
||||
|
@ -83,9 +80,9 @@ object OptimizeSkewedJoin extends CustomShuffleReaderRule {
|
|||
* to split skewed partitions is the average size of non-skewed partition, or the
|
||||
* advisory partition size if avg size is smaller than it.
|
||||
*/
|
||||
private def targetSize(sizes: Array[Long], medianSize: Long): Long = {
|
||||
private def targetSize(sizes: Array[Long], skewThreshold: Long): Long = {
|
||||
val advisorySize = conf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES)
|
||||
val nonSkewSizes = sizes.filterNot(isSkewed(_, medianSize))
|
||||
val nonSkewSizes = sizes.filter(_ <= skewThreshold)
|
||||
if (nonSkewSizes.isEmpty) {
|
||||
advisorySize
|
||||
} else {
|
||||
|
@ -158,6 +155,10 @@ object OptimizeSkewedJoin extends CustomShuffleReaderRule {
|
|||
left: ShuffleQueryStageExec,
|
||||
right: ShuffleQueryStageExec,
|
||||
joinType: JoinType): Option[(SparkPlan, SparkPlan)] = {
|
||||
val canSplitLeft = canSplitLeftSide(joinType)
|
||||
val canSplitRight = canSplitRightSide(joinType)
|
||||
if (!canSplitLeft && !canSplitRight) return None
|
||||
|
||||
val leftSizes = left.mapStats.get.bytesByPartitionId
|
||||
val rightSizes = right.mapStats.get.bytesByPartitionId
|
||||
assert(leftSizes.length == rightSizes.length)
|
||||
|
@ -174,10 +175,10 @@ object OptimizeSkewedJoin extends CustomShuffleReaderRule {
|
|||
|${getSizeInfo(rightMedSize, rightSizes)}
|
||||
""".stripMargin)
|
||||
|
||||
val canSplitLeft = canSplitLeftSide(joinType)
|
||||
val canSplitRight = canSplitRightSide(joinType)
|
||||
val leftTargetSize = targetSize(leftSizes, leftMedSize)
|
||||
val rightTargetSize = targetSize(rightSizes, rightMedSize)
|
||||
val leftSkewThreshold = getSkewThreshold(leftMedSize)
|
||||
val rightSkewThreshold = getSkewThreshold(rightMedSize)
|
||||
val leftTargetSize = targetSize(leftSizes, leftSkewThreshold)
|
||||
val rightTargetSize = targetSize(rightSizes, rightSkewThreshold)
|
||||
|
||||
val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
|
||||
val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
|
||||
|
@ -185,9 +186,9 @@ object OptimizeSkewedJoin extends CustomShuffleReaderRule {
|
|||
var numSkewedRight = 0
|
||||
for (partitionIndex <- 0 until numPartitions) {
|
||||
val leftSize = leftSizes(partitionIndex)
|
||||
val isLeftSkew = isSkewed(leftSize, leftMedSize) && canSplitLeft
|
||||
val isLeftSkew = canSplitLeft && leftSize > leftSkewThreshold
|
||||
val rightSize = rightSizes(partitionIndex)
|
||||
val isRightSkew = isSkewed(rightSize, rightMedSize) && canSplitRight
|
||||
val isRightSkew = canSplitRight && rightSize > rightSkewThreshold
|
||||
val noSkewPartitionSpec = Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1))
|
||||
|
||||
val leftParts = if (isLeftSkew) {
|
||||
|
@ -238,8 +239,7 @@ object OptimizeSkewedJoin extends CustomShuffleReaderRule {
|
|||
def optimizeSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp {
|
||||
case smj @ SortMergeJoinExec(_, _, joinType, _,
|
||||
s1 @ SortExec(_, _, ShuffleStage(left: ShuffleQueryStageExec), _),
|
||||
s2 @ SortExec(_, _, ShuffleStage(right: ShuffleQueryStageExec), _), isSkewJoin)
|
||||
if !isSkewJoin && supportedJoinTypes.contains(joinType) =>
|
||||
s2 @ SortExec(_, _, ShuffleStage(right: ShuffleQueryStageExec), _), false) =>
|
||||
val newChildren = tryOptimizeJoinChildren(left, right, joinType)
|
||||
if (newChildren.isDefined) {
|
||||
val (newLeft, newRight) = newChildren.get
|
||||
|
@ -251,8 +251,7 @@ object OptimizeSkewedJoin extends CustomShuffleReaderRule {
|
|||
|
||||
case shj @ ShuffledHashJoinExec(_, _, joinType, _, _,
|
||||
ShuffleStage(left: ShuffleQueryStageExec),
|
||||
ShuffleStage(right: ShuffleQueryStageExec), isSkewJoin)
|
||||
if !isSkewJoin && supportedJoinTypes.contains(joinType) =>
|
||||
ShuffleStage(right: ShuffleQueryStageExec), false) =>
|
||||
val newChildren = tryOptimizeJoinChildren(left, right, joinType)
|
||||
if (newChildren.isDefined) {
|
||||
val (newLeft, newRight) = newChildren.get
|
||||
|
|
Loading…
Reference in a new issue