[SPARK-9373][SQL] Support StructType in Tungsten projection
This pull request updates GenerateUnsafeProjection to support StructType. If an input struct is already backed by an UnsafeRow, GenerateUnsafeProjection copies the bytes directly into its buffer space without any conversion. If the input is not an UnsafeRow, GenerateUnsafeProjection recursively runs generated code to convert the input into an UnsafeRow and then copies that into the buffer space.

This change also creates a TungstenProject operator that projects data directly into UnsafeRow. Note that I'm not sure if this is the way we want to structure Unsafe+codegen operators, but we can defer that decision to follow-up pull requests.

Author: Reynold Xin <rxin@databricks.com>

Closes #7689 from rxin/tungsten-struct-type and squashes the following commits:

9162f42 [Reynold Xin] Support IntervalType in UnsafeRow's getter.
be9f377 [Reynold Xin] Fixed tests.
10c4b7c [Reynold Xin] Format generated code.
77e8d0e [Reynold Xin] Fixed NondeterministicSuite.
ac4951d [Reynold Xin] Yay.
ac203bf [Reynold Xin] More comments.
9f36216 [Reynold Xin] Updated comment.
6b781fe [Reynold Xin] Reset the change in DataFrameSuite.
525b95b [Reynold Xin] Merged with master, more documentation & test cases.
321859a [Reynold Xin] [SPARK-9373][SQL] Support StructType in Tungsten projection [WIP]
commit 60f08c7c87
parent 63a492b931
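
The heart of the change is the two-path write for struct fields described above: a raw byte copy when the field is already Tungsten-backed, and recursively generated conversion code otherwise. Below is a minimal hand-written sketch of that decision, not the actual generated code (which GenerateUnsafeProjection emits as Java source); it uses only APIs that appear in the diffs that follow.

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.unsafe.PlatformDependent

    // Sketch: serialize a struct-typed field into a fresh byte buffer.
    def copyStructField(field: InternalRow): Array[Byte] = field match {
      case row: UnsafeRow =>
        // Fast path: already in Tungsten format, so copy the bytes verbatim.
        val bytes = new Array[Byte](row.getSizeInBytes)
        row.writeToMemory(bytes, PlatformDependent.BYTE_ARRAY_OFFSET)
        bytes
      case _ =>
        // Slow path stand-in: the real code runs recursively generated Java
        // (see createCodeForStruct below) to build an UnsafeRow first.
        throw new UnsupportedOperationException("handled by generated conversion code")
    }
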
@@ -265,6 +265,8 @@ public final class UnsafeRow extends MutableRow {
       return getBinary(ordinal);
     } else if (dataType instanceof StringType) {
       return getUTF8String(ordinal);
+    } else if (dataType instanceof IntervalType) {
+      return getInterval(ordinal);
     } else if (dataType instanceof StructType) {
       return getStruct(ordinal, ((StructType) dataType).size());
     } else {
@@ -17,6 +17,7 @@

 package org.apache.spark.sql.catalyst.expressions;

+import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.types.ByteArray;
@@ -81,6 +82,52 @@ public class UnsafeRowWriters {
     }
   }

+  /**
+   * Writer for struct type where the struct field is backed by an {@link UnsafeRow}.
+   *
+   * We throw UnsupportedOperationException for inputs that are not backed by {@link UnsafeRow}.
+   * Non-UnsafeRow struct fields are handled directly in
+   * {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}
+   * by generating the Java code needed to convert them into UnsafeRow.
+   */
+  public static class StructWriter {
+    public static int getSize(InternalRow input) {
+      int numBytes = 0;
+      if (input instanceof UnsafeRow) {
+        numBytes = ((UnsafeRow) input).getSizeInBytes();
+      } else {
+        // This is handled directly in GenerateUnsafeProjection.
+        throw new UnsupportedOperationException();
+      }
+      return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+    }
+
+    public static int write(UnsafeRow target, int ordinal, int cursor, InternalRow input) {
+      int numBytes = 0;
+      final long offset = target.getBaseOffset() + cursor;
+      if (input instanceof UnsafeRow) {
+        final UnsafeRow row = (UnsafeRow) input;
+        numBytes = row.getSizeInBytes();
+
+        // zero-out the padding bytes
+        if ((numBytes & 0x07) > 0) {
+          PlatformDependent.UNSAFE.putLong(
+            target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L);
+        }
+
+        // Write the struct row to the variable length portion.
+        row.writeToMemory(target.getBaseObject(), offset);
+
+        // Set the fixed length portion.
+        target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
+      } else {
+        // This is handled directly in GenerateUnsafeProjection.
+        throw new UnsupportedOperationException();
+      }
+      return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+    }
+  }
+
   /** Writer for interval type. */
   public static class IntervalWriter {

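
Two details of StructWriter are easy to miss. The 8-byte fixed-length slot for a struct column packs the cursor (the nested row's offset within the variable-length section) into the upper 32 bits and the byte size into the lower 32 bits, and sizes are rounded up to a multiple of 8 so every field stays word-aligned. A worked sketch with illustrative values:

    // Mirrors target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)).
    val cursor = 48      // example offset into the variable-length region
    val numBytes = 24    // example size of the nested UnsafeRow in bytes
    val slot: Long = (cursor.toLong << 32) | numBytes.toLong
    assert((slot >>> 32).toInt == cursor)
    assert((slot & 0xFFFFFFFFL).toInt == numBytes)

    // Assumed-equivalent form of ByteArrayMethods.roundNumberOfBytesToNearestWord.
    def roundToWord(n: Int): Int = (n + 7) & ~7
    assert(roundToWord(24) == 24 && roundToWord(25) == 32)
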
@@ -96,5 +143,4 @@ public class UnsafeRowWriters {
       return 16;
     }
   }
-
 }
@@ -50,7 +50,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
         case BinaryType => input.getBinary(ordinal)
         case IntervalType => input.getInterval(ordinal)
         case t: StructType => input.getStruct(ordinal, t.size)
-        case dataType => input.get(ordinal, dataType)
+        case _ => input.get(ordinal, dataType)
       }
     }
   }
@@ -64,10 +64,11 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
   override def exprId: ExprId = throw new UnsupportedOperationException

   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val javaType = ctx.javaType(dataType)
+    val value = ctx.getColumn("i", dataType, ordinal)
     s"""
       boolean ${ev.isNull} = i.isNullAt($ordinal);
-      ${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ?
-        ${ctx.defaultValue(dataType)} : (${ctx.getColumn("i", dataType, ordinal)});
+      $javaType ${ev.primitive} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
     """
   }
 }
@@ -34,11 +34,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] {
   private val StringWriter = classOf[UnsafeRowWriters.UTF8StringWriter].getName
   private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName
   private val IntervalWriter = classOf[UnsafeRowWriters.IntervalWriter].getName
+  private val StructWriter = classOf[UnsafeRowWriters.StructWriter].getName

   /** Returns true iff we support this data type. */
   def canSupport(dataType: DataType): Boolean = dataType match {
     case t: AtomicType if !t.isInstanceOf[DecimalType] => true
     case _: IntervalType => true
+    case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
     case NullType => true
     case _ => false
   }
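
To make the recursive canSupport rule concrete: a struct is supported iff every field type is (recursively) supported, so nested atomic fields pass while a field of an unsupported type (for example an array, or a decimal per the AtomicType guard above) rejects the whole struct. A sketch, assuming the object is visible from the calling code:

    import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
    import org.apache.spark.sql.types._

    val supported = new StructType()
      .add("a", IntegerType)
      .add("b", new StructType().add("c", StringType))   // nested struct: ok
    val unsupported = new StructType()
      .add("d", ArrayType(IntegerType))                  // array fields: not yet supported

    assert(GenerateUnsafeProjection.canSupport(supported))
    assert(!GenerateUnsafeProjection.canSupport(unsupported))
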
@@ -55,15 +57,22 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] {

     val ret = ev.primitive
     ctx.addMutableState("UnsafeRow", ret, s"$ret = new UnsafeRow();")
-    val bufferTerm = ctx.freshName("buffer")
-    ctx.addMutableState("byte[]", bufferTerm, s"$bufferTerm = new byte[64];")
-    val cursorTerm = ctx.freshName("cursor")
-    val numBytesTerm = ctx.freshName("numBytes")
+    val buffer = ctx.freshName("buffer")
+    ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
+    val cursor = ctx.freshName("cursor")
+    val numBytes = ctx.freshName("numBytes")

-    val exprs = expressions.map(_.gen(ctx))
+    val exprs = expressions.zipWithIndex.map { case (e, i) =>
+      e.dataType match {
+        case st: StructType =>
+          createCodeForStruct(ctx, e.gen(ctx), st)
+        case _ =>
+          e.gen(ctx)
+      }
+    }
     val allExprs = exprs.map(_.code).mkString("\n")
-    val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)

+    val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)
     val additionalSize = expressions.zipWithIndex.map { case (e, i) =>
       e.dataType match {
         case StringType =>
@@ -72,6 +81,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] {
           s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))"
         case IntervalType =>
           s" + (${exprs(i).isNull} ? 0 : 16)"
+        case _: StructType =>
+          s" + (${exprs(i).isNull} ? 0 : $StructWriter.getSize(${exprs(i).primitive}))"
         case _ => ""
       }
     }.mkString("")
@@ -81,11 +92,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] {
         case dt if ctx.isPrimitiveType(dt) =>
           s"${ctx.setColumn(ret, dt, i, exprs(i).primitive)}"
         case StringType =>
-          s"$cursorTerm += $StringWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
+          s"$cursor += $StringWriter.write($ret, $i, $cursor, ${exprs(i).primitive})"
         case BinaryType =>
-          s"$cursorTerm += $BinaryWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
+          s"$cursor += $BinaryWriter.write($ret, $i, $cursor, ${exprs(i).primitive})"
         case IntervalType =>
-          s"$cursorTerm += $IntervalWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
+          s"$cursor += $IntervalWriter.write($ret, $i, $cursor, ${exprs(i).primitive})"
+        case t: StructType =>
+          s"$cursor += $StructWriter.write($ret, $i, $cursor, ${exprs(i).primitive})"
         case NullType => ""
         case _ =>
           throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}")
@@ -99,24 +112,139 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] {

     s"""
       $allExprs
-      int $numBytesTerm = $fixedSize $additionalSize;
-      if ($numBytesTerm > $bufferTerm.length) {
-        $bufferTerm = new byte[$numBytesTerm];
+      int $numBytes = $fixedSize $additionalSize;
+      if ($numBytes > $buffer.length) {
+        $buffer = new byte[$numBytes];
       }

       $ret.pointTo(
-        $bufferTerm,
+        $buffer,
         org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET,
         ${expressions.size},
-        $numBytesTerm);
-      int $cursorTerm = $fixedSize;
+        $numBytes);
+      int $cursor = $fixedSize;

       $writers
       boolean ${ev.isNull} = false;
     """
   }

+  /**
+   * Generates the Java code to convert a struct (backed by InternalRow) to UnsafeRow.
+   *
+   * This function also handles nested structs by recursively generating the conversion code.
+   *
+   * @param ctx code generation context
+   * @param input the input struct, identified by a [[GeneratedExpressionCode]]
+   * @param schema schema of the struct field
+   */
+  // TODO: refactor createCode and this function to reduce code duplication.
+  private def createCodeForStruct(
+      ctx: CodeGenContext,
+      input: GeneratedExpressionCode,
+      schema: StructType): GeneratedExpressionCode = {
+
+    val isNull = input.isNull
+    val primitive = ctx.freshName("structConvert")
+    ctx.addMutableState("UnsafeRow", primitive, s"$primitive = new UnsafeRow();")
+    val buffer = ctx.freshName("buffer")
+    ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
+    val cursor = ctx.freshName("cursor")
+
+    val exprs: Seq[GeneratedExpressionCode] = schema.map(_.dataType).zipWithIndex.map {
+      case (dt, i) => dt match {
+        case st: StructType =>
+          val nestedStructEv = GeneratedExpressionCode(
+            code = "",
+            isNull = s"${input.primitive}.isNullAt($i)",
+            primitive = s"${ctx.getColumn(input.primitive, dt, i)}"
+          )
+          createCodeForStruct(ctx, nestedStructEv, st)
+        case _ =>
+          GeneratedExpressionCode(
+            code = "",
+            isNull = s"${input.primitive}.isNullAt($i)",
+            primitive = s"${ctx.getColumn(input.primitive, dt, i)}"
+          )
+      }
+    }
+    val allExprs = exprs.map(_.code).mkString("\n")
+
+    val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)
+    val additionalSize = schema.toSeq.map(_.dataType).zip(exprs).map { case (dt, ev) =>
+      dt match {
+        case StringType =>
+          s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))"
+        case BinaryType =>
+          s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))"
+        case IntervalType =>
+          s" + (${ev.isNull} ? 0 : 16)"
+        case _: StructType =>
+          s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))"
+        case _ => ""
+      }
+    }.mkString("")
+
+    val writers = schema.toSeq.map(_.dataType).zip(exprs).zipWithIndex.map { case ((dt, ev), i) =>
+      val update = dt match {
+        case _ if ctx.isPrimitiveType(dt) =>
+          s"${ctx.setColumn(primitive, dt, i, exprs(i).primitive)}"
+        case StringType =>
+          s"$cursor += $StringWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})"
+        case BinaryType =>
+          s"$cursor += $BinaryWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})"
+        case IntervalType =>
+          s"$cursor += $IntervalWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})"
+        case t: StructType =>
+          s"$cursor += $StructWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})"
+        case NullType => ""
+        case _ =>
+          throw new UnsupportedOperationException(s"Not supported DataType: $dt")
+      }
+      s"""
+        if (${exprs(i).isNull}) {
+          $primitive.setNullAt($i);
+        } else {
+          $update;
+        }
+      """
+    }.mkString("\n ")
+
+    // Note that we add a shortcut here for performance: if the input is already an UnsafeRow,
+    // just copy the bytes directly into our buffer space without running any conversion.
+    // We also had to use a hack to introduce a "tmp" variable, to keep the Java compiler from
+    // complaining that a GenericMutableRow (generated by expressions) cannot be cast to UnsafeRow.
+    val tmp = ctx.freshName("tmp")
+    val numBytes = ctx.freshName("numBytes")
+    val code = s"""
+       |${input.code}
+       |if (!${input.isNull}) {
+       |  Object $tmp = (Object) ${input.primitive};
+       |  if ($tmp instanceof UnsafeRow) {
+       |    $primitive = (UnsafeRow) $tmp;
+       |  } else {
+       |    $allExprs
+       |
+       |    int $numBytes = $fixedSize $additionalSize;
+       |    if ($numBytes > $buffer.length) {
+       |      $buffer = new byte[$numBytes];
+       |    }
+       |
+       |    $primitive.pointTo(
+       |      $buffer,
+       |      org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET,
+       |      ${exprs.size},
+       |      $numBytes);
+       |    int $cursor = $fixedSize;
+       |
+       |    $writers
+       |  }
+       |}
+     """.stripMargin
+
+    GeneratedExpressionCode(code, isNull, primitive)
+  }
+
   protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
     in.map(ExpressionCanonicalizer.execute)

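
A quick sanity check of the conversion path added above: projecting a GenericInternalRow-backed nested struct through an UnsafeProjection should round-trip its values, taking the slow (generated-conversion) path since the input is not an UnsafeRow. A sketch using the Array-of-DataType factory that the test-helper changes below also rely on:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
    import org.apache.spark.sql.types._

    val nested = new StructType().add("x", IntegerType).add("y", LongType)
    val proj = UnsafeProjection.create(Array[DataType](IntegerType, nested))
    val unsafeRow = proj(InternalRow(1, InternalRow(2, 3L)))   // non-Unsafe input
    assert(unsafeRow.getInt(0) == 1)
    assert(unsafeRow.getStruct(1, 2).getLong(1) == 3L)         // nested field survives
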
@@ -159,7 +287,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] {
       }
     """

-    logDebug(s"code for ${expressions.mkString(",")}:\n$code")
+    logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")

     val c = compile(code)
     c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection]
@@ -116,6 +116,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
   override def prettyName: String = "struct"
 }

+
 /**
  * Creates a struct with the given field names and values
  *
@@ -179,3 +180,72 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {

   override def prettyName: String = "named_struct"
 }
+
+/**
+ * Returns a Row containing the evaluation of all children expressions. This is a variant that
+ * returns UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with
+ * this expression automatically at runtime.
+ */
+case class CreateStructUnsafe(children: Seq[Expression]) extends Expression {
+
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  override lazy val resolved: Boolean = childrenResolved
+
+  override lazy val dataType: StructType = {
+    val fields = children.zipWithIndex.map { case (child, idx) =>
+      child match {
+        case ne: NamedExpression =>
+          StructField(ne.name, ne.dataType, ne.nullable, ne.metadata)
+        case _ =>
+          StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty)
+      }
+    }
+    StructType(fields)
+  }
+
+  override def nullable: Boolean = false
+
+  override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    GenerateUnsafeProjection.createCode(ctx, ev, children)
+  }
+
+  override def prettyName: String = "struct_unsafe"
+}
+
+
+/**
+ * Creates a struct with the given field names and values. This is a variant that returns
+ * UnsafeRow directly. The unsafe projection operator replaces [[CreateNamedStruct]] with
+ * this expression automatically at runtime.
+ *
+ * @param children Seq(name1, val1, name2, val2, ...)
+ */
+case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression {
+
+  private lazy val (nameExprs, valExprs) =
+    children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip
+
+  private lazy val names = nameExprs.map(_.eval(EmptyRow).toString)
+
+  override lazy val dataType: StructType = {
+    val fields = names.zip(valExprs).map { case (name, valExpr) =>
+      StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty)
+    }
+    StructType(fields)
+  }
+
+  override def foldable: Boolean = valExprs.forall(_.foldable)
+
+  override def nullable: Boolean = false
+
+  override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    GenerateUnsafeProjection.createCode(ctx, ev, valExprs)
+  }
+
+  override def prettyName: String = "named_struct_unsafe"
+}
@@ -170,6 +170,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(Pmod(-7, 3), 2)
     checkEvaluation(Pmod(7.2D, 4.1D), 3.1000000000000005)
     checkEvaluation(Pmod(Decimal(0.7), Decimal(0.2)), Decimal(0.1))
-    checkEvaluation(Pmod(2L, Long.MaxValue), 2)
+    checkEvaluation(Pmod(2L, Long.MaxValue), 2L)
   }
 }
@@ -30,8 +30,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       checkEvaluation(expr, expected)
     }

-    check(1.toByte, ~1.toByte)
-    check(1000.toShort, ~1000.toShort)
+    // Need the extra toByte even though IntelliJ thinks it's not needed.
+    check(1.toByte, (~1.toByte).toByte)
+    check(1000.toShort, (~1000.toShort).toShort)
     check(1000000, ~1000000)
     check(123456789123L, ~123456789123L)
@@ -45,8 +46,9 @@
       checkEvaluation(expr, expected)
     }

-    check(1.toByte, 2.toByte, 1.toByte & 2.toByte)
-    check(1000.toShort, 2.toShort, 1000.toShort & 2.toShort)
+    // Need the extra toByte even though IntelliJ thinks it's not needed.
+    check(1.toByte, 2.toByte, (1.toByte & 2.toByte).toByte)
+    check(1000.toShort, 2.toShort, (1000.toShort & 2.toShort).toShort)
     check(1000000, 4, 1000000 & 4)
     check(123456789123L, 5L, 123456789123L & 5L)
@@ -63,8 +65,9 @@
       checkEvaluation(expr, expected)
     }

-    check(1.toByte, 2.toByte, 1.toByte | 2.toByte)
-    check(1000.toShort, 2.toShort, 1000.toShort | 2.toShort)
+    // Need the extra toByte even though IntelliJ thinks it's not needed.
+    check(1.toByte, 2.toByte, (1.toByte | 2.toByte).toByte)
+    check(1000.toShort, 2.toShort, (1000.toShort | 2.toShort).toShort)
     check(1000000, 4, 1000000 | 4)
    check(123456789123L, 5L, 123456789123L | 5L)
@@ -81,8 +84,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       checkEvaluation(expr, expected)
     }

-    check(1.toByte, 2.toByte, 1.toByte ^ 2.toByte)
-    check(1000.toShort, 2.toShort, 1000.toShort ^ 2.toShort)
+    // Need the extra toByte even though IntelliJ thinks it's not needed.
+    check(1.toByte, 2.toByte, (1.toByte ^ 2.toByte).toByte)
+    check(1000.toShort, 2.toShort, (1000.toShort ^ 2.toShort).toShort)
     check(1000000, 4, 1000000 ^ 4)
     check(123456789123L, 5L, 123456789123L ^ 5L)
@@ -114,7 +114,7 @@ trait ExpressionEvalHelper {
     val actual = plan(inputRow).get(0, expression.dataType)
     if (!checkResult(actual, expected)) {
       val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
-      fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
+      fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input")
     }
   }

@@ -146,7 +146,8 @@ trait ExpressionEvalHelper {

     if (actual != expectedRow) {
       val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
-      fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expectedRow$input")
+      fail("Incorrect evaluation in codegen mode: " +
+        s"$expression, actual: $actual, expected: $expectedRow$input")
     }
     if (actual.copy() != expectedRow) {
       fail(s"Copy of generated Row is wrong: actual: ${actual.copy()}, expected: $expectedRow")
|
@ -163,12 +164,21 @@ trait ExpressionEvalHelper {
|
||||||
expression)
|
expression)
|
||||||
|
|
||||||
val unsafeRow = plan(inputRow)
|
val unsafeRow = plan(inputRow)
|
||||||
// UnsafeRow cannot be compared with GenericInternalRow directly
|
|
||||||
val actual = FromUnsafeProjection(expression.dataType :: Nil)(unsafeRow)
|
|
||||||
val expectedRow = InternalRow(expected)
|
|
||||||
if (actual != expectedRow) {
|
|
||||||
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
|
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
|
||||||
fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expectedRow$input")
|
|
||||||
|
if (expected == null) {
|
||||||
|
if (!unsafeRow.isNullAt(0)) {
|
||||||
|
val expectedRow = InternalRow(expected)
|
||||||
|
fail("Incorrect evaluation in unsafe mode: " +
|
||||||
|
s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
val lit = InternalRow(expected)
|
||||||
|
val expectedRow = UnsafeProjection.create(Array(expression.dataType)).apply(lit)
|
||||||
|
if (unsafeRow != expectedRow) {
|
||||||
|
fail("Incorrect evaluation in unsafe mode: " +
|
||||||
|
s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@@ -363,7 +363,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.Sort(sortExprs, global, child) =>
        getSortOperator(sortExprs, global, planLater(child)):: Nil
       case logical.Project(projectList, child) =>
-        execution.Project(projectList, planLater(child)) :: Nil
+        // If unsafe mode is enabled and we support these data types in Unsafe, use the
+        // Tungsten project. Otherwise, use the normal project.
+        if (sqlContext.conf.unsafeEnabled &&
+          UnsafeProjection.canSupport(projectList) && UnsafeProjection.canSupport(child.schema)) {
+          execution.TungstenProject(projectList, planLater(child)) :: Nil
+        } else {
+          execution.Project(projectList, planLater(child)) :: Nil
+        }
       case logical.Filter(condition, child) =>
         execution.Filter(condition, planLater(child)) :: Nil
       case e @ logical.Expand(_, _, _, child) =>
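
To see the planner rule take effect end to end, here is a sketch assuming "spark.sql.unsafe.enabled" is the conf key behind sqlContext.conf.unsafeEnabled (the new test suite below toggles the same setting via SQLConf.UNSAFE_ENABLED):

    import org.apache.spark.sql.functions._

    sqlContext.setConf("spark.sql.unsafe.enabled", "true")
    val df = sqlContext.range(0, 10).select(struct(col("id")).as("s"))
    // With unsafe mode on and all types supported, the physical plan should
    // contain TungstenProject rather than Project.
    assert(df.queryExecution.executedPlan.toString.contains("TungstenProject"))
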
@@ -49,6 +49,31 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 }

+
+/**
+ * A variant of [[Project]] that returns [[UnsafeRow]]s.
+ */
+case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
+
+  override def outputsUnsafeRows: Boolean = true
+  override def canProcessUnsafeRows: Boolean = true
+  override def canProcessSafeRows: Boolean = true
+
+  override def output: Seq[Attribute] = projectList.map(_.toAttribute)
+
+  protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
+    this.transformAllExpressions {
+      case CreateStruct(children) => CreateStructUnsafe(children)
+      case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
+    }
+    val project = UnsafeProjection.create(projectList, child.output)
+    iter.map(project)
+  }
+
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+}
+
+
 /**
  * :: DeveloperApi ::
  */
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types._
+
+/**
+ * An end-to-end test suite specifically for testing Tungsten (Unsafe/CodeGen) mode.
+ *
+ * This is here for now so I can make sure Tungsten project is tested without refactoring existing
+ * end-to-end test infra. In the long run this should just go away.
+ */
+class DataFrameTungstenSuite extends QueryTest with SQLTestUtils {
+
+  override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
+  import sqlContext.implicits._
+
+  test("test simple types") {
+    withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") {
+      val df = sqlContext.sparkContext.parallelize(Seq((1, 2))).toDF("a", "b")
+      assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2))
+    }
+  }
+
+  test("test struct type") {
+    withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") {
+      val struct = Row(1, 2L, 3.0F, 3.0)
+      val data = sqlContext.sparkContext.parallelize(Seq(Row(1, struct)))
+
+      val schema = new StructType()
+        .add("a", IntegerType)
+        .add("b",
+          new StructType()
+            .add("b1", IntegerType)
+            .add("b2", LongType)
+            .add("b3", FloatType)
+            .add("b4", DoubleType))
+
+      val df = sqlContext.createDataFrame(data, schema)
+      assert(df.select("b").first() === Row(struct))
+    }
+  }
+
+  test("test nested struct type") {
+    withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") {
+      val innerStruct = Row(1, "abcd")
+      val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg")
+      val data = sqlContext.sparkContext.parallelize(Seq(Row(1, outerStruct)))
+
+      val schema = new StructType()
+        .add("a", IntegerType)
+        .add("b",
+          new StructType()
+            .add("b1", IntegerType)
+            .add("b2", LongType)
+            .add("b3", FloatType)
+            .add("b4", DoubleType)
+            .add("b5", new StructType()
+              .add("b5a", IntegerType)
+              .add("b5b", StringType))
+            .add("b6", StringType))
+
+      val df = sqlContext.createDataFrame(data, schema)
+      assert(df.select("b").first() === Row(outerStruct))
+    }
+  }
+}
@@ -23,7 +23,7 @@ import org.apache.spark.sql.execution.expressions.{SparkPartitionID, MonotonicallyIncreasingID}

 class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper {
   test("MonotonicallyIncreasingID") {
-    checkEvaluation(MonotonicallyIncreasingID(), 0)
+    checkEvaluation(MonotonicallyIncreasingID(), 0L)
   }

   test("SparkPartitionID") {