[SPARK-11410][SQL] Add APIs to provide functionality similar to Hive's DISTRIBUTE BY and SORT BY.

DISTRIBUTE BY allows the user to hash partition the data by the specified exprs. It also allows
for optional sorting within each resulting partition. There is no required relationship between
the exprs used for partitioning and sorting (i.e. one does not need to be a prefix of the other).

This patch adds two APIs to DataFrame which can be used together to provide this functionality:
  1. distributeBy(), which partitions the data frame into a specified number of partitions using
     the partitioning exprs.
  2. localSort(), which sorts each partition using the provided sorting exprs.

To get the DISTRIBUTE BY functionality, the user simply does: df.distributeBy(...).localSort(...)
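For example, a minimal sketch against the new APIs (the DataFrame `df`, the column names, and the
partition count below are illustrative, not part of this patch):

    import org.apache.spark.sql.Column

    // Hash partition the rows of `df` by "key" into 8 partitions, then sort the rows
    // within each partition by "value" in descending order.
    val partitioned = df.distributeBy(Column("key") :: Nil, 8)
      .localSort(Column("value").desc)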

Author: Nong Li <nongli@gmail.com>

Closes #9364 from nongli/spark-11410.
commit 046e32ed84 (parent dc7e399fc0)
Authored by Nong Li on 2015-11-01 14:32:21 -08:00, committed by Yin Huai
4 changed files with 186 additions and 19 deletions


@@ -31,10 +31,19 @@ case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan)
   extends RedistributeData
 
 /**
- * This method repartitions data using [[Expression]]s, and receives information about the
- * number of partitions during execution. Used when a specific ordering or distribution is
- * expected by the consumer of the query result. Use [[Repartition]] for RDD-like
+ * This method repartitions data using [[Expression]]s into `numPartitions`, and receives
+ * information about the number of partitions during execution. Used when a specific ordering or
+ * distribution is expected by the consumer of the query result. Use [[Repartition]] for RDD-like
  * `coalesce` and `repartition`.
+ * If `numPartitions` is not specified, the number of partitions will be the number set by
+ * `spark.sql.shuffle.partitions`.
  */
-case class RepartitionByExpression(partitionExpressions: Seq[Expression], child: LogicalPlan)
-  extends RedistributeData
+case class RepartitionByExpression(
+    partitionExpressions: Seq[Expression],
+    child: LogicalPlan,
+    numPartitions: Option[Int] = None) extends RedistributeData {
+  numPartitions match {
+    case Some(n) => require(n > 0, "numPartitions must be greater than 0.")
+    case None => // Ok
+  }
+}
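As the new doc comment above notes, an unspecified `numPartitions` falls back to
`spark.sql.shuffle.partitions`. A sketch of the resulting behavior (the config value and the
DataFrame `df` are illustrative):

    // Planned with the value of spark.sql.shuffle.partitions (12 here):
    sqlContext.setConf("spark.sql.shuffle.partitions", "12")
    val byKey = df.distributeBy(Column("key") :: Nil)
    // Planned with an explicit partition count:
    val fixed = df.distributeBy(Column("key") :: Nil, 3)
    // The require(n > 0) check rejects a non-positive explicit count when the plan node is built:
    // df.distributeBy(Column("key") :: Nil, 0)  // throws IllegalArgumentException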


@@ -241,6 +241,18 @@ class DataFrame private[sql](
     sb.toString()
   }
 
+  private[sql] def sortInternal(global: Boolean, sortExprs: Seq[Column]): DataFrame = {
+    val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
+      col.expr match {
+        case expr: SortOrder =>
+          expr
+        case expr: Expression =>
+          SortOrder(expr, Ascending)
+      }
+    }
+    Sort(sortOrder, global = global, logicalPlan)
+  }
+
   override def toString: String = {
     try {
       schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]")
@@ -633,15 +645,7 @@
    */
   @scala.annotation.varargs
   def sort(sortExprs: Column*): DataFrame = {
-    val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
-      col.expr match {
-        case expr: SortOrder =>
-          expr
-        case expr: Expression =>
-          SortOrder(expr, Ascending)
-      }
-    }
-    Sort(sortOrder, global = true, logicalPlan)
+    sortInternal(true, sortExprs)
   }
 
   /**
@@ -662,6 +666,44 @@
   @scala.annotation.varargs
   def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs : _*)
 
+  /**
+   * Returns a new [[DataFrame]] partitioned by the given partitioning expressions into
+   * `numPartitions`. The resulting DataFrame is hash partitioned.
+   * @group dfops
+   * @since 1.6.0
+   */
+  def distributeBy(partitionExprs: Seq[Column], numPartitions: Int): DataFrame = {
+    RepartitionByExpression(partitionExprs.map { _.expr }, logicalPlan, Some(numPartitions))
+  }
+
+  /**
+   * Returns a new [[DataFrame]] partitioned by the given partitioning expressions preserving
+   * the existing number of partitions. The resulting DataFrame is hash partitioned.
+   * @group dfops
+   * @since 1.6.0
+   */
+  def distributeBy(partitionExprs: Seq[Column]): DataFrame = {
+    RepartitionByExpression(partitionExprs.map { _.expr }, logicalPlan, None)
+  }
+
+  /**
+   * Returns a new [[DataFrame]] with each partition sorted by the given column names.
+   * @group dfops
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def localSort(sortCol: String, sortCols: String*): DataFrame =
+    localSort((sortCol +: sortCols).map(Column(_)) : _*)
+
+  /**
+   * Returns a new [[DataFrame]] with each partition sorted by the given expressions.
+   * @group dfops
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def localSort(sortExprs: Column*): DataFrame = {
+    sortInternal(false, sortExprs)
+  }
+
   /**
    * Selects column based on the column name and return it as a [[Column]].
    * Note that the column name can also reference to a nested column like `a.b`.

@@ -27,8 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
 import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _}
 import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand}
-import org.apache.spark.sql.types._
-import org.apache.spark.sql.{SQLContext, Strategy, execution}
+import org.apache.spark.sql.{Strategy, execution}
 
 private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   self: SparkPlanner =>
@@ -455,8 +454,9 @@
         generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
       case logical.OneRowRelation =>
         execution.PhysicalRDD(Nil, singleRowRdd, "OneRowRelation") :: Nil
-      case logical.RepartitionByExpression(expressions, child) =>
-        execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil
+      case logical.RepartitionByExpression(expressions, child, nPartitions) =>
+        execution.Exchange(HashPartitioning(
+          expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
       case e @ EvaluatePython(udf, child, _) =>
         BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
       case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "PhysicalRDD") :: Nil


@@ -24,10 +24,14 @@ import scala.util.Random
 import org.scalatest.Matchers._
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
+import org.apache.spark.sql.execution.Exchange
+import org.apache.spark.sql.execution.aggregate.TungstenAggregate
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SQLTestData.TestData2
+import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext}
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SharedSQLContext}
 
 class DataFrameSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -997,4 +1001,116 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  /**
+   * Verifies that there is no Exchange between the Aggregations for `df`
+   */
+  private def verifyNonExchangingAgg(df: DataFrame) = {
+    var atFirstAgg: Boolean = false
+    df.queryExecution.executedPlan.foreach {
+      case agg: TungstenAggregate => {
+        atFirstAgg = !atFirstAgg
+      }
+      case _ => {
+        if (atFirstAgg) {
+          fail("Should not have operators between the two aggregations")
+        }
+      }
+    }
+  }
+
+  /**
+   * Verifies that there is an Exchange between the Aggregations for `df`
+   */
+  private def verifyExchangingAgg(df: DataFrame) = {
+    var atFirstAgg: Boolean = false
+    df.queryExecution.executedPlan.foreach {
+      case agg: TungstenAggregate => {
+        if (atFirstAgg) {
+          fail("Should not have back to back Aggregates")
+        }
+        atFirstAgg = true
+      }
+      case e: Exchange => atFirstAgg = false
+      case _ =>
+    }
+  }
+
+  test("distributeBy and localSort") {
+    val original = testData.repartition(1)
+    assert(original.rdd.partitions.length == 1)
+    val df = original.distributeBy(Column("key") :: Nil, 5)
+    assert(df.rdd.partitions.length == 5)
+    checkAnswer(original.select(), df.select())
+
+    val df2 = original.distributeBy(Column("key") :: Nil, 10)
+    assert(df2.rdd.partitions.length == 10)
+    checkAnswer(original.select(), df2.select())
+
+    // Group by the column we are distributed by. This should generate a plan with no exchange
+    // between the aggregates.
+    val df3 = testData.distributeBy(Column("key") :: Nil).groupBy("key").count()
+    verifyNonExchangingAgg(df3)
+    verifyNonExchangingAgg(testData.distributeBy(Column("key") :: Column("value") :: Nil)
+      .groupBy("key", "value").count())
+
+    // Grouping by just the first distributeBy expr needs to exchange.
+    verifyExchangingAgg(testData.distributeBy(Column("key") :: Column("value") :: Nil)
+      .groupBy("key").count())
+
+    val data = sqlContext.sparkContext.parallelize(
+      (1 to 100).map(i => TestData2(i % 10, i))).toDF()
+
+    // Distribute and order by.
+    val df4 = data.distributeBy(Column("a") :: Nil).localSort($"b".desc)
+    // Walk each partition and verify that it is sorted descending and does not contain all
+    // the values.
+    df4.rdd.foreachPartition(p => {
+      var previousValue: Int = -1
+      var allSequential: Boolean = true
+      p.foreach(r => {
+        val v: Int = r.getInt(1)
+        if (previousValue != -1) {
+          if (previousValue < v) throw new SparkException("Partition is not ordered.")
+          if (v + 1 != previousValue) allSequential = false
+        }
+        previousValue = v
+      })
+      if (allSequential) throw new SparkException("Partition should not be globally ordered")
+    })
+
+    // Distribute and order by with multiple order bys.
+    val df5 = data.distributeBy(Column("a") :: Nil, 2).localSort($"b".asc, $"a".asc)
+    // Walk each partition and verify that it is sorted ascending.
+    df5.rdd.foreachPartition(p => {
+      var previousValue: Int = -1
+      var allSequential: Boolean = true
+      p.foreach(r => {
+        val v: Int = r.getInt(1)
+        if (previousValue != -1) {
+          if (previousValue > v) throw new SparkException("Partition is not ordered.")
+          if (v - 1 != previousValue) allSequential = false
+        }
+        previousValue = v
+      })
+      if (allSequential) throw new SparkException("Partition should not be all sequential")
+    })
+
+    // Distribute into one partition and order by. This partition should contain all the values.
+    val df6 = data.distributeBy(Column("a") :: Nil, 1).localSort($"b".asc)
+    // Walk each partition and verify that it is sorted ascending and contains all the values.
+    df6.rdd.foreachPartition(p => {
+      var previousValue: Int = -1
+      var allSequential: Boolean = true
+      p.foreach(r => {
+        val v: Int = r.getInt(1)
+        if (previousValue != -1) {
+          if (previousValue > v) throw new SparkException("Partition is not ordered.")
+          if (v - 1 != previousValue) allSequential = false
+        }
+        previousValue = v
+      })
+      if (!allSequential) throw new SparkException("Partition should contain all sequential values")
+    })
+  }
 }