[SPARK-11313][SQL] implement cogroup on DataSets (support 2 datasets)

A simpler version of https://github.com/apache/spark/pull/9279, only support 2 datasets.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9324 from cloud-fan/cogroup2.
This commit is contained in:
Wenchen Fan 2015-10-28 13:58:52 +01:00 committed by Michael Armbrust
parent 5f1cee6f15
commit 075ce4914f
8 changed files with 257 additions and 0 deletions

View file

@ -591,6 +591,7 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
build.append(java.lang.Long.toHexString(Platform.getLong(baseObject, baseOffset + i)));
build.append(',');
}
build.deleteCharAt(build.length() - 1);
build.append(']');
return build.toString();
}

View file

@ -513,3 +513,42 @@ case class MapGroups[K, T, U](
override def missingInput: AttributeSet = AttributeSet.empty
}
/** Factory for constructing new `CoGroup` nodes. */
object CoGroup {
def apply[K : Encoder, Left : Encoder, Right : Encoder, R : Encoder](
func: (K, Iterator[Left], Iterator[Right]) => Iterator[R],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
left: LogicalPlan,
right: LogicalPlan): CoGroup[K, Left, Right, R] = {
CoGroup(
func,
encoderFor[K],
encoderFor[Left],
encoderFor[Right],
encoderFor[R],
encoderFor[R].schema.toAttributes,
leftGroup,
rightGroup,
left,
right)
}
}
/**
* A relation produced by applying `func` to each grouping key and associated values from left and
* right children.
*/
case class CoGroup[K, Left, Right, R](
func: (K, Iterator[Left], Iterator[Right]) => Iterator[R],
kEncoder: ExpressionEncoder[K],
leftEnc: ExpressionEncoder[Left],
rightEnc: ExpressionEncoder[Right],
rEncoder: ExpressionEncoder[R],
output: Seq[Attribute],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
left: LogicalPlan,
right: LogicalPlan) extends BinaryNode {
override def missingInput: AttributeSet = AttributeSet.empty
}

View file

@ -65,4 +65,24 @@ class GroupedDataset[K, T] private[sql](
sqlContext,
MapGroups(f, groupingAttributes, logicalPlan))
}
/**
* Applies the given function to each cogrouped data. For each unique group, the function will
* be passed the grouping key and 2 iterators containing all elements in the group from
* [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an
* arbitrary type which will be returned as a new [[Dataset]].
*/
def cogroup[U, R : Encoder](
other: GroupedDataset[K, U])(
f: (K, Iterator[T], Iterator[U]) => Iterator[R]): Dataset[R] = {
implicit def uEnc: Encoder[U] = other.tEncoder
new Dataset[R](
sqlContext,
CoGroup(
f,
this.groupingAttributes,
other.groupingAttributes,
this.logicalPlan,
other.logicalPlan))
}
}

View file

@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder, Attribute}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
/**
* Iterates over [[GroupedIterator]]s and returns the cogrouped data, i.e. each record is a
* grouping key with its associated values from all [[GroupedIterator]]s.
* Note: we assume the output of each [[GroupedIterator]] is ordered by the grouping key.
*/
class CoGroupedIterator(
left: Iterator[(InternalRow, Iterator[InternalRow])],
right: Iterator[(InternalRow, Iterator[InternalRow])],
groupingSchema: Seq[Attribute])
extends Iterator[(InternalRow, Iterator[InternalRow], Iterator[InternalRow])] {
private val keyOrdering =
GenerateOrdering.generate(groupingSchema.map(SortOrder(_, Ascending)), groupingSchema)
private var currentLeftData: (InternalRow, Iterator[InternalRow]) = _
private var currentRightData: (InternalRow, Iterator[InternalRow]) = _
override def hasNext: Boolean = left.hasNext || right.hasNext
override def next(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
if (currentLeftData.eq(null) && left.hasNext) {
currentLeftData = left.next()
}
if (currentRightData.eq(null) && right.hasNext) {
currentRightData = right.next()
}
assert(currentLeftData.ne(null) || currentRightData.ne(null))
if (currentLeftData.eq(null)) {
// left is null, right is not null, consume the right data.
rightOnly()
} else if (currentRightData.eq(null)) {
// left is not null, right is null, consume the left data.
leftOnly()
} else if (currentLeftData._1 == currentRightData._1) {
// left and right have the same grouping key, consume both of them.
val result = (currentLeftData._1, currentLeftData._2, currentRightData._2)
currentLeftData = null
currentRightData = null
result
} else {
val compare = keyOrdering.compare(currentLeftData._1, currentRightData._1)
assert(compare != 0)
if (compare < 0) {
// the grouping key of left is smaller, consume the left data.
leftOnly()
} else {
// the grouping key of right is smaller, consume the right data.
rightOnly()
}
}
}
private def leftOnly(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
val result = (currentLeftData._1, currentLeftData._2, Iterator.empty)
currentLeftData = null
result
}
private def rightOnly(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
val result = (currentRightData._1, Iterator.empty, currentRightData._2)
currentRightData = null
result
}
}

View file

@ -393,6 +393,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil
case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) =>
execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil
case logical.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output,
leftGroup, rightGroup, left, right) =>
execution.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output, leftGroup, rightGroup,
planLater(left), planLater(right)) :: Nil
case logical.Repartition(numPartitions, shuffle, child) =>
if (shuffle) {

View file

@ -390,3 +390,44 @@ case class MapGroups[K, T, U](
}
}
}
/**
* Co-groups the data from left and right children, and calls the function with each group and 2
* iterators containing all elements in the group from left and right side.
* The result of this function is encoded and flattened before being output.
*/
case class CoGroup[K, Left, Right, R](
func: (K, Iterator[Left], Iterator[Right]) => Iterator[R],
kEncoder: ExpressionEncoder[K],
leftEnc: ExpressionEncoder[Left],
rightEnc: ExpressionEncoder[Right],
rEncoder: ExpressionEncoder[R],
output: Seq[Attribute],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
left: SparkPlan,
right: SparkPlan) extends BinaryNode {
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil
override protected def doExecute(): RDD[InternalRow] = {
left.execute().zipPartitions(right.execute()) { (leftData, rightData) =>
val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)
val groupKeyEncoder = kEncoder.bind(leftGroup)
new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
case (key, leftResult, rightResult) =>
val result = func(
groupKeyEncoder.fromRow(key),
leftResult.map(leftEnc.fromRow),
rightResult.map(rightEnc.fromRow))
result.map(rEncoder.toRow)
}
}
}
}

View file

@ -202,4 +202,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
agged,
("a", 30), ("b", 3), ("c", 1))
}
test("cogroup") {
val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS()
val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS()
val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { case (key, data1, data2) =>
Iterator(key -> (data1.map(_._2).mkString + "#" + data2.map(_._2).mkString))
}
checkAnswer(
cogrouped,
1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er")
}
}

View file

@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper
class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper {
test("basic") {
val leftInput = Seq(create_row(1, "a"), create_row(1, "b"), create_row(2, "c")).iterator
val rightInput = Seq(create_row(1, 2L), create_row(2, 3L), create_row(3, 4L)).iterator
val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string))
val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long))
val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int))
val result = cogrouped.map {
case (key, leftData, rightData) =>
assert(key.numFields == 1)
(key.getInt(0), leftData.toSeq, rightData.toSeq)
}.toSeq
assert(result ==
(1,
Seq(create_row(1, "a"), create_row(1, "b")),
Seq(create_row(1, 2L))) ::
(2,
Seq(create_row(2, "c")),
Seq(create_row(2, 3L))) ::
(3,
Seq.empty,
Seq(create_row(3, 4L))) ::
Nil
)
}
}