[SPARK-11313][SQL] implement cogroup on DataSets (support 2 datasets)
A simpler version of https://github.com/apache/spark/pull/9279 that only supports 2 datasets. Author: Wenchen Fan <wenchen@databricks.com> Closes #9324 from cloud-fan/cogroup2.
This commit is contained in:
parent
5f1cee6f15
commit
075ce4914f
|
@ -591,6 +591,7 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
|
|||
build.append(java.lang.Long.toHexString(Platform.getLong(baseObject, baseOffset + i)));
|
||||
build.append(',');
|
||||
}
|
||||
build.deleteCharAt(build.length() - 1);
|
||||
build.append(']');
|
||||
return build.toString();
|
||||
}
|
||||
|
|
|
@ -513,3 +513,42 @@ case class MapGroups[K, T, U](
|
|||
override def missingInput: AttributeSet = AttributeSet.empty
|
||||
}
|
||||
|
||||
/**
 * Factory for constructing new `CoGroup` nodes.
 *
 * Callers pass only the function, the grouping attributes and the two child plans; an encoder
 * for each type parameter is obtained through `encoderFor` from the context bounds.
 */
object CoGroup {
  def apply[K : Encoder, Left : Encoder, Right : Encoder, R : Encoder](
      func: (K, Iterator[Left], Iterator[Right]) => Iterator[R],
      leftGroup: Seq[Attribute],
      rightGroup: Seq[Attribute],
      left: LogicalPlan,
      right: LogicalPlan): CoGroup[K, Left, Right, R] = {
    // The result encoder is used twice: as the node's `rEncoder` and as the source of the
    // node's output attributes.
    val resultEncoder = encoderFor[R]
    new CoGroup[K, Left, Right, R](
      func,
      encoderFor[K],
      encoderFor[Left],
      encoderFor[Right],
      resultEncoder,
      resultEncoder.schema.toAttributes,
      leftGroup,
      rightGroup,
      left,
      right)
  }
}
|
||||
|
||||
/**
 * A relation produced by applying `func` to each grouping key and the values associated with
 * that key in the left and right children.
 *
 * @param func       invoked with a key plus one iterator of left values and one of right values
 * @param kEncoder   encoder for the grouping key type `K`
 * @param leftEnc    encoder for the element type of the left child
 * @param rightEnc   encoder for the element type of the right child
 * @param rEncoder   encoder for the result type `R`
 * @param output     attributes produced by this node
 * @param leftGroup  grouping attributes of the left child
 * @param rightGroup grouping attributes of the right child
 * @param left       left child plan
 * @param right      right child plan
 */
case class CoGroup[K, Left, Right, R](
    func: (K, Iterator[Left], Iterator[Right]) => Iterator[R],
    kEncoder: ExpressionEncoder[K],
    leftEnc: ExpressionEncoder[Left],
    rightEnc: ExpressionEncoder[Right],
    rEncoder: ExpressionEncoder[R],
    output: Seq[Attribute],
    leftGroup: Seq[Attribute],
    rightGroup: Seq[Attribute],
    left: LogicalPlan,
    right: LogicalPlan)
  extends BinaryNode {

  // This node always reports an empty set of missing input attributes.
  // NOTE(review): presumably because `output` is supplied explicitly rather than derived from
  // the children -- confirm against the analyzer's use of `missingInput`.
  override def missingInput: AttributeSet = AttributeSet.empty
}
|
||||
|
|
|
@ -65,4 +65,24 @@ class GroupedDataset[K, T] private[sql](
|
|||
sqlContext,
|
||||
MapGroups(f, groupingAttributes, logicalPlan))
|
||||
}
|
||||
|
||||
/**
 * Co-groups this grouped data with `other` and applies `f` to every group. For each unique
 * grouping key, `f` receives the key together with two iterators: the elements of that group
 * from `this` [[Dataset]] and the elements of that group from `other`. Whatever elements `f`
 * returns, of an arbitrary type, make up the resulting [[Dataset]].
 */
def cogroup[U, R : Encoder](
    other: GroupedDataset[K, U])(
    f: (K, Iterator[T], Iterator[U]) => Iterator[R]): Dataset[R] = {
  // `CoGroup.apply` needs an implicit Encoder[U]; take it from the other grouped dataset.
  implicit def otherElementEncoder: Encoder[U] = other.tEncoder

  val cogroupedPlan = CoGroup(
    f,
    this.groupingAttributes,
    other.groupingAttributes,
    this.logicalPlan,
    other.logicalPlan)

  new Dataset[R](sqlContext, cogroupedPlan)
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution
|
||||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder, Attribute}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
|
||||
|
||||
/**
 * Iterates over [[GroupedIterator]]s and returns the cogrouped data, i.e. each record is a
 * grouping key with its associated values from all [[GroupedIterator]]s.
 * Note: we assume the output of each [[GroupedIterator]] is ordered by the grouping key.
 *
 * For a key that is present on one side only, the iterator for the other side is empty.
 *
 * @param left           groups of the left side, as (grouping key, rows of that group) pairs
 * @param right          groups of the right side, as (grouping key, rows of that group) pairs
 * @param groupingSchema attributes describing the grouping key; used to generate the ordering
 *                       that decides which side's key comes first
 */
class CoGroupedIterator(
    left: Iterator[(InternalRow, Iterator[InternalRow])],
    right: Iterator[(InternalRow, Iterator[InternalRow])],
    groupingSchema: Seq[Attribute])
  extends Iterator[(InternalRow, Iterator[InternalRow], Iterator[InternalRow])] {

  private val keyOrdering =
    GenerateOrdering.generate(groupingSchema.map(SortOrder(_, Ascending)), groupingSchema)

  // Look-ahead buffers holding the group most recently pulled from each side that has not been
  // returned yet. `null` means nothing is buffered for that side.
  private var currentLeftData: (InternalRow, Iterator[InternalRow]) = _
  private var currentRightData: (InternalRow, Iterator[InternalRow]) = _

  /**
   * Returns true while any group is still pending, either in the underlying iterators or in
   * the look-ahead buffers.
   *
   * Checking only `left.hasNext || right.hasNext` is not enough: a group that was pulled for
   * comparison but not returned lives only in the buffer. For example, with left keys 1, 3 and
   * right key 2, key 3 is buffered when both inputs run dry and would otherwise be dropped.
   */
  override def hasNext: Boolean = {
    fillBuffers()
    currentLeftData.ne(null) || currentRightData.ne(null)
  }

  /**
   * Returns the group with the smallest pending key, merging both sides when their keys match.
   *
   * @throws NoSuchElementException if no group is left
   */
  override def next(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
    if (!hasNext) {
      throw new NoSuchElementException("next on empty CoGroupedIterator")
    }

    if (currentLeftData.eq(null)) {
      // left is null, right is not null, consume the right data.
      rightOnly()
    } else if (currentRightData.eq(null)) {
      // left is not null, right is null, consume the left data.
      leftOnly()
    } else if (currentLeftData._1 == currentRightData._1) {
      // left and right have the same grouping key, consume both of them.
      val result = (currentLeftData._1, currentLeftData._2, currentRightData._2)
      currentLeftData = null
      currentRightData = null
      result
    } else {
      val compare = keyOrdering.compare(currentLeftData._1, currentRightData._1)
      assert(compare != 0)
      if (compare < 0) {
        // the grouping key of left is smaller, consume the left data.
        leftOnly()
      } else {
        // the grouping key of right is smaller, consume the right data.
        rightOnly()
      }
    }
  }

  /** Pulls the next group from each side whose look-ahead buffer is currently empty. */
  private def fillBuffers(): Unit = {
    if (currentLeftData.eq(null) && left.hasNext) {
      currentLeftData = left.next()
    }
    if (currentRightData.eq(null) && right.hasNext) {
      currentRightData = right.next()
    }
  }

  /** Emits the buffered left group with an empty right side and clears the left buffer. */
  private def leftOnly(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
    val result = (currentLeftData._1, currentLeftData._2, Iterator.empty)
    currentLeftData = null
    result
  }

  /** Emits the buffered right group with an empty left side and clears the right buffer. */
  private def rightOnly(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
    val result = (currentRightData._1, Iterator.empty, currentRightData._2)
    currentRightData = null
    result
  }
}
|
|
@ -393,6 +393,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil
|
||||
case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) =>
|
||||
execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil
|
||||
case logical.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output,
|
||||
leftGroup, rightGroup, left, right) =>
|
||||
execution.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output, leftGroup, rightGroup,
|
||||
planLater(left), planLater(right)) :: Nil
|
||||
|
||||
case logical.Repartition(numPartitions, shuffle, child) =>
|
||||
if (shuffle) {
|
||||
|
|
|
@ -390,3 +390,44 @@ case class MapGroups[K, T, U](
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Physical operator that co-groups the rows of its left and right children and invokes `func`
 * once per group with the key and two iterators holding that group's elements from the left
 * and from the right side. The elements returned by `func` are encoded with `rEncoder` and
 * flattened into the output.
 */
case class CoGroup[K, Left, Right, R](
    func: (K, Iterator[Left], Iterator[Right]) => Iterator[R],
    kEncoder: ExpressionEncoder[K],
    leftEnc: ExpressionEncoder[Left],
    rightEnc: ExpressionEncoder[Right],
    rEncoder: ExpressionEncoder[R],
    output: Seq[Attribute],
    leftGroup: Seq[Attribute],
    rightGroup: Seq[Attribute],
    left: SparkPlan,
    right: SparkPlan)
  extends BinaryNode {

  // Each child must be clustered by its own grouping attributes.
  override def requiredChildDistribution: Seq[Distribution] =
    List(ClusteredDistribution(leftGroup), ClusteredDistribution(rightGroup))

  // Each child must be sorted ascending on its own grouping attributes.
  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
    List(leftGroup, rightGroup).map { group =>
      group.map(attribute => SortOrder(attribute, Ascending))
    }

  override protected def doExecute(): RDD[InternalRow] = {
    val leftRdd = left.execute()
    val rightRdd = right.execute()

    leftRdd.zipPartitions(rightRdd) { (leftRows, rightRows) =>
      val leftGroups = GroupedIterator(leftRows, leftGroup, left.output)
      val rightGroups = GroupedIterator(rightRows, rightGroup, right.output)
      // Keys are decoded with the key encoder bound to the left grouping attributes.
      val keyDecoder = kEncoder.bind(leftGroup)
      val cogroups = new CoGroupedIterator(leftGroups, rightGroups, leftGroup)

      cogroups.flatMap { group =>
        val (keyRow, leftGroupRows, rightGroupRows) = group
        val produced = func(
          keyDecoder.fromRow(keyRow),
          leftGroupRows.map(leftEnc.fromRow),
          rightGroupRows.map(rightEnc.fromRow))
        produced.map(rEncoder.toRow)
      }
    }
  }
}
|
||||
|
|
|
@ -202,4 +202,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
|
|||
agged,
|
||||
("a", 30), ("b", 3), ("c", 1))
|
||||
}
|
||||
|
||||
// Cogroups two keyed datasets and concatenates, per key, the string values seen on the left
// and on the right, separated by '#'. Keys present on one side only yield an empty half.
test("cogroup") {
  val leftGrouped =
    Seq((1, "a"), (3, "abc"), (5, "hello"), (3, "foo")).toDS().groupBy(_._1)
  val rightGrouped =
    Seq((2, "q"), (3, "w"), (5, "e"), (5, "r")).toDS().groupBy(_._1)

  val cogrouped = leftGrouped.cogroup(rightGrouped) { (key, leftRows, rightRows) =>
    val leftValues = leftRows.map(_._2).mkString
    val rightValues = rightRows.map(_._2).mkString
    Iterator((key, leftValues + "#" + rightValues))
  }

  checkAnswer(
    cogrouped,
    (1, "a#"), (2, "#q"), (3, "abcfoo#w"), (5, "hello#er"))
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper
|
||||
|
||||
/** Unit tests for [[CoGroupedIterator]]. */
class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper {

  // Keys 1 and 2 exist on both sides; key 3 exists on the right side only.
  test("basic") {
    val leftRows = Iterator(create_row(1, "a"), create_row(1, "b"), create_row(2, "c"))
    val rightRows = Iterator(create_row(1, 2L), create_row(2, 3L), create_row(3, 4L))

    val cogrouped = new CoGroupedIterator(
      GroupedIterator(leftRows, Seq('i.int.at(0)), Seq('i.int, 's.string)),
      GroupedIterator(rightRows, Seq('i.int.at(0)), Seq('i.int, 'l.long)),
      Seq('i.int))

    val actual = cogrouped.map { group =>
      val (key, leftData, rightData) = group
      assert(key.numFields == 1)
      (key.getInt(0), leftData.toSeq, rightData.toSeq)
    }.toSeq

    val expected = List(
      (1, Seq(create_row(1, "a"), create_row(1, "b")), Seq(create_row(1, 2L))),
      (2, Seq(create_row(2, "c")), Seq(create_row(2, 3L))),
      (3, Seq.empty, Seq(create_row(3, 4L))))

    assert(actual == expected)
  }
}
|
Loading…
Reference in a new issue