diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index fb084dd13b..955fb4226f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -265,6 +265,8 @@ public final class UnsafeRow extends MutableRow {
       return getBinary(ordinal);
     } else if (dataType instanceof StringType) {
       return getUTF8String(ordinal);
+    } else if (dataType instanceof IntervalType) {
+      return getInterval(ordinal);
     } else if (dataType instanceof StructType) {
       return getStruct(ordinal, ((StructType) dataType).size());
     } else {
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
index 0ba31d3b9b..8fdd739960 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions;
 
+import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.types.ByteArray;
@@ -81,6 +82,52 @@ public class UnsafeRowWriters {
     }
   }
 
+  /**
+   * Writer for struct type where the struct field is backed by an {@link UnsafeRow}.
+   *
+   * We throw UnsupportedOperationException for inputs that are not backed by {@link UnsafeRow}.
+   * Non-UnsafeRow struct fields are handled directly in
+   * {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}
+   * by generating the Java code needed to convert them into UnsafeRow.
+   */
+  public static class StructWriter {
+    public static int getSize(InternalRow input) {
+      int numBytes = 0;
+      if (input instanceof UnsafeRow) {
+        numBytes = ((UnsafeRow) input).getSizeInBytes();
+      } else {
+        // This is handled directly in GenerateUnsafeProjection.
+        throw new UnsupportedOperationException();
+      }
+      return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+    }
+
+    public static int write(UnsafeRow target, int ordinal, int cursor, InternalRow input) {
+      int numBytes = 0;
+      final long offset = target.getBaseOffset() + cursor;
+      if (input instanceof UnsafeRow) {
+        final UnsafeRow row = (UnsafeRow) input;
+        numBytes = row.getSizeInBytes();
+
+        // zero-out the padding bytes
+        if ((numBytes & 0x07) > 0) {
+          PlatformDependent.UNSAFE.putLong(
+            target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L);
+        }
+
+        // Write the row to the variable length portion.
+        row.writeToMemory(target.getBaseObject(), offset);
+
+        // Set the fixed length portion.
+        target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
+      } else {
+        // This is handled directly in GenerateUnsafeProjection.
+        throw new UnsupportedOperationException();
+      }
+      return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+    }
+  }
+
   /** Writer for interval type. */
   public static class IntervalWriter {
@@ -96,5 +143,4 @@ public class UnsafeRowWriters {
       return 16;
     }
   }
-
 }
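For reference, the fixed-length slot that StructWriter.write fills packs the cursor (the payload's offset within the row) into the upper 32 bits and the byte count into the lower 32 bits, so a reader can locate a variable-length field with a single 8-byte load. A minimal plain-Scala sketch of that packing (names are illustrative, not Spark's):

object OffsetAndSize {
  // Pack: payload offset in the high word, payload size in the low word.
  def encode(cursor: Int, numBytes: Int): Long =
    (cursor.toLong << 32) | numBytes.toLong

  // Unpack: recover both values from the single fixed-width slot.
  def decode(slot: Long): (Int, Int) = ((slot >>> 32).toInt, slot.toInt)
}
// OffsetAndSize.decode(OffsetAndSize.encode(24, 13)) == (24, 13)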
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 41a877f214..8304d4ccd4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -50,7 +50,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
         case BinaryType => input.getBinary(ordinal)
         case IntervalType => input.getInterval(ordinal)
         case t: StructType => input.getStruct(ordinal, t.size)
-        case dataType => input.get(ordinal, dataType)
+        case _ => input.get(ordinal, dataType)
       }
     }
   }
@@ -64,10 +64,11 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
   override def exprId: ExprId = throw new UnsupportedOperationException
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val javaType = ctx.javaType(dataType)
+    val value = ctx.getColumn("i", dataType, ordinal)
     s"""
-      boolean ${ev.isNull} = i.isNullAt($ordinal);
-      ${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ?
-        ${ctx.defaultValue(dataType)} : (${ctx.getColumn("i", dataType, ordinal)});
+      boolean ${ev.isNull} = i.isNullAt($ordinal);
+      $javaType ${ev.primitive} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
     """
   }
 }
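The genCode refactoring above only hoists the interpolated pieces into local vals; the emitted Java is unchanged. As an illustration (a hypothetical expansion, assuming an IntegerType field at ordinal 0 and made-up variable names, since codegen uses fresh names), the template yields code of roughly this shape:

val expanded =
  """
    boolean isNull0 = i.isNullAt(0);
    int primitive0 = isNull0 ? -1 : (i.getInt(0));
  """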
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 9d2161947b..3e87f72858 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -34,11 +34,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
   private val StringWriter = classOf[UnsafeRowWriters.UTF8StringWriter].getName
   private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName
   private val IntervalWriter = classOf[UnsafeRowWriters.IntervalWriter].getName
+  private val StructWriter = classOf[UnsafeRowWriters.StructWriter].getName
 
   /** Returns true iff we support this data type. */
   def canSupport(dataType: DataType): Boolean = dataType match {
     case t: AtomicType if !t.isInstanceOf[DecimalType] => true
     case _: IntervalType => true
+    case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
     case NullType => true
     case _ => false
   }
@@ -55,15 +57,22 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     val ret = ev.primitive
     ctx.addMutableState("UnsafeRow", ret, s"$ret = new UnsafeRow();")
-    val bufferTerm = ctx.freshName("buffer")
-    ctx.addMutableState("byte[]", bufferTerm, s"$bufferTerm = new byte[64];")
-    val cursorTerm = ctx.freshName("cursor")
-    val numBytesTerm = ctx.freshName("numBytes")
+    val buffer = ctx.freshName("buffer")
+    ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
+    val cursor = ctx.freshName("cursor")
+    val numBytes = ctx.freshName("numBytes")
 
-    val exprs = expressions.map(_.gen(ctx))
+    val exprs = expressions.zipWithIndex.map { case (e, i) =>
+      e.dataType match {
+        case st: StructType =>
+          createCodeForStruct(ctx, e.gen(ctx), st)
+        case _ =>
+          e.gen(ctx)
+      }
+    }
     val allExprs = exprs.map(_.code).mkString("\n")
 
-    val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)
+    val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)
     val additionalSize = expressions.zipWithIndex.map { case (e, i) =>
       e.dataType match {
         case StringType =>
@@ -72,6 +81,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
           s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))"
         case IntervalType =>
           s" + (${exprs(i).isNull} ? 0 : 16)"
+        case _: StructType =>
+          s" + (${exprs(i).isNull} ? 0 : $StructWriter.getSize(${exprs(i).primitive}))"
         case _ => ""
       }
     }.mkString("")
@@ -81,11 +92,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
         case dt if ctx.isPrimitiveType(dt) =>
           s"${ctx.setColumn(ret, dt, i, exprs(i).primitive)}"
         case StringType =>
-          s"$cursorTerm += $StringWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
+          s"$cursor += $StringWriter.write($ret, $i, $cursor, ${exprs(i).primitive})"
         case BinaryType =>
-          s"$cursorTerm += $BinaryWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
+          s"$cursor += $BinaryWriter.write($ret, $i, $cursor, ${exprs(i).primitive})"
         case IntervalType =>
-          s"$cursorTerm += $IntervalWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
+          s"$cursor += $IntervalWriter.write($ret, $i, $cursor, ${exprs(i).primitive})"
+        case t: StructType =>
+          s"$cursor += $StructWriter.write($ret, $i, $cursor, ${exprs(i).primitive})"
         case NullType => ""
         case _ =>
           throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}")
@@ -99,24 +112,139 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     s"""
       $allExprs
-      int $numBytesTerm = $fixedSize $additionalSize;
-      if ($numBytesTerm > $bufferTerm.length) {
-        $bufferTerm = new byte[$numBytesTerm];
+      int $numBytes = $fixedSize $additionalSize;
+      if ($numBytes > $buffer.length) {
+        $buffer = new byte[$numBytes];
       }
 
       $ret.pointTo(
-        $bufferTerm,
+        $buffer,
         org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET,
         ${expressions.size},
-        $numBytesTerm);
-      int $cursorTerm = $fixedSize;
-
+        $numBytes);
+      int $cursor = $fixedSize;
       $writers
       boolean ${ev.isNull} = false;
     """
   }
 
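A worked instance of the fixed-size arithmetic used in createCode, assuming calculateBitSetWidthInBytes rounds the null bitset up to whole 8-byte words (a plain-Scala sketch, not Spark's code):

def bitSetWidthInBytes(numFields: Int): Int = ((numFields + 63) / 64) * 8

// A 3-column row: one 8-byte slot per field plus one bitset word.
val fixedSize = 8 * 3 + bitSetWidthInBytes(3)  // 24 + 8 = 32 bytes
// Variable-length values (strings, binary, structs, ...) are appended after
// this fixed region, which is why their sizes are summed into additionalSize.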
+  /**
+   * Generates the Java code to convert a struct (backed by InternalRow) to UnsafeRow.
+   *
+   * This function also handles nested structs by recursively generating the conversion code.
+   *
+   * @param ctx code generation context
+   * @param input the input struct, identified by a [[GeneratedExpressionCode]]
+   * @param schema schema of the struct field
+   */
+  // TODO: refactor createCode and this function to reduce code duplication.
+  private def createCodeForStruct(
+      ctx: CodeGenContext,
+      input: GeneratedExpressionCode,
+      schema: StructType): GeneratedExpressionCode = {
+
+    val isNull = input.isNull
+    val primitive = ctx.freshName("structConvert")
+    ctx.addMutableState("UnsafeRow", primitive, s"$primitive = new UnsafeRow();")
+    val buffer = ctx.freshName("buffer")
+    ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
+    val cursor = ctx.freshName("cursor")
+
+    val exprs: Seq[GeneratedExpressionCode] = schema.map(_.dataType).zipWithIndex.map {
+      case (dt, i) => dt match {
+        case st: StructType =>
+          val nestedStructEv = GeneratedExpressionCode(
+            code = "",
+            isNull = s"${input.primitive}.isNullAt($i)",
+            primitive = s"${ctx.getColumn(input.primitive, dt, i)}"
+          )
+          createCodeForStruct(ctx, nestedStructEv, st)
+        case _ =>
+          GeneratedExpressionCode(
+            code = "",
+            isNull = s"${input.primitive}.isNullAt($i)",
+            primitive = s"${ctx.getColumn(input.primitive, dt, i)}"
+          )
+      }
+    }
+    val allExprs = exprs.map(_.code).mkString("\n")
+
+    val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)
+    val additionalSize = schema.toSeq.map(_.dataType).zip(exprs).map { case (dt, ev) =>
+      dt match {
+        case StringType =>
+          s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))"
+        case BinaryType =>
+          s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))"
+        case IntervalType =>
+          s" + (${ev.isNull} ? 0 : 16)"
+        case _: StructType =>
+          s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))"
+        case _ => ""
+      }
+    }.mkString("")
+
+    val writers = schema.toSeq.map(_.dataType).zip(exprs).zipWithIndex.map { case ((dt, ev), i) =>
+      val update = dt match {
+        case _ if ctx.isPrimitiveType(dt) =>
+          s"${ctx.setColumn(primitive, dt, i, exprs(i).primitive)}"
+        case StringType =>
+          s"$cursor += $StringWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})"
+        case BinaryType =>
+          s"$cursor += $BinaryWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})"
+        case IntervalType =>
+          s"$cursor += $IntervalWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})"
+        case t: StructType =>
+          s"$cursor += $StructWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})"
+        case NullType => ""
+        case _ =>
+          throw new UnsupportedOperationException(s"Not supported DataType: $dt")
+      }
+      s"""
+        if (${exprs(i).isNull}) {
+          $primitive.setNullAt($i);
+        } else {
+          $update;
+        }
+      """
+    }.mkString("\n")
+
+    // Note that we add a shortcut here for performance: if the input is already an UnsafeRow,
+    // just copy the bytes directly into our buffer space without running any conversion.
+    // We also introduce a "tmp" variable as a workaround to keep the Java compiler from
+    // complaining that a GenericMutableRow (generated by expressions) cannot be cast to
+    // UnsafeRow.
+    val tmp = ctx.freshName("tmp")
+    val numBytes = ctx.freshName("numBytes")
+    val code = s"""
+      |${input.code}
+      |if (!${input.isNull}) {
+      |  Object $tmp = (Object) ${input.primitive};
+      |  if ($tmp instanceof UnsafeRow) {
+      |    $primitive = (UnsafeRow) $tmp;
+      |  } else {
+      |    $allExprs
+      |
+      |    int $numBytes = $fixedSize $additionalSize;
+      |    if ($numBytes > $buffer.length) {
+      |      $buffer = new byte[$numBytes];
+      |    }
+      |
+      |    $primitive.pointTo(
+      |      $buffer,
+      |      org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET,
+      |      ${exprs.size},
+      |      $numBytes);
+      |    int $cursor = $fixedSize;
+      |
+      |    $writers
+      |  }
+      |}
+    """.stripMargin
+
+    GeneratedExpressionCode(code, isNull, primitive)
+  }
+
   protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
     in.map(ExpressionCanonicalizer.execute)
@@ -159,7 +287,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     }
     """
 
-    logDebug(s"code for ${expressions.mkString(",")}:\n$code")
+    logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")
 
     val c = compile(code)
     c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection]
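The control flow that createCodeForStruct emits can be summarized independently of codegen: alias the existing bytes when the field is already an UnsafeRow, otherwise run the generated per-field writers. A self-contained Scala sketch of that decision (the types here are stand-ins, not Spark's):

sealed trait RowLike
final case class UnsafeLike(bytes: Array[Byte]) extends RowLike
final case class GenericLike(values: Seq[Any]) extends RowLike

def toUnsafe(input: RowLike)(convert: GenericLike => UnsafeLike): UnsafeLike =
  input match {
    case u: UnsafeLike => u            // shortcut: reuse the bytes, no conversion
    case g: GenericLike => convert(g)  // generated field-by-field writer path
  }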
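To see the struct support end to end, here is a hedged usage sketch built only from APIs that appear in this patch (UnsafeProjection.create and InternalRow), assuming the create(StructType) overload behaves as the tests below use it; treat it as illustrative rather than a verified snippet:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types._

val schema = new StructType()
  .add("a", IntegerType)
  .add("s", new StructType().add("x", IntegerType).add("y", LongType))

// Converts a generic InternalRow with a nested struct into a single UnsafeRow.
val proj = UnsafeProjection.create(schema)
val unsafe = proj(InternalRow(1, InternalRow(2, 3L)))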
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index e7e5231d32..7773e098e0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -170,6 +170,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(Pmod(-7, 3), 2)
     checkEvaluation(Pmod(7.2D, 4.1D), 3.1000000000000005)
     checkEvaluation(Pmod(Decimal(0.7), Decimal(0.2)), Decimal(0.1))
-    checkEvaluation(Pmod(2L, Long.MaxValue), 2)
+    checkEvaluation(Pmod(2L, Long.MaxValue), 2L)
   }
 }
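The 2 to 2L change matters because the unsafe evaluation path (see ExpressionEvalHelper below) now writes the expected value into an UnsafeRow using the expression's data type, and an Int and a Long fill their 8-byte slots differently even when numerically equal. A plain-Scala illustration of that distinction:

import java.nio.ByteBuffer

val asInt  = ByteBuffer.allocate(8).putInt(0, 2).array.toSeq   // 00 00 00 02 00 00 00 00
val asLong = ByteBuffer.allocate(8).putLong(0, 2L).array.toSeq // 00 00 00 00 00 00 00 02
assert(asInt != asLong)  // same number, different bytes: typed expected values matter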
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala
index 648fbf5a4c..fa30fbe528 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala
@@ -30,8 +30,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       checkEvaluation(expr, expected)
     }
 
-    check(1.toByte, ~1.toByte)
-    check(1000.toShort, ~1000.toShort)
+    // Need the extra toByte even though IntelliJ thinks it is redundant.
+    check(1.toByte, (~1.toByte).toByte)
+    check(1000.toShort, (~1000.toShort).toShort)
     check(1000000, ~1000000)
     check(123456789123L, ~123456789123L)
 
@@ -45,8 +46,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       checkEvaluation(expr, expected)
     }
 
-    check(1.toByte, 2.toByte, 1.toByte & 2.toByte)
-    check(1000.toShort, 2.toShort, 1000.toShort & 2.toShort)
+    // Need the extra toByte even though IntelliJ thinks it is redundant.
+    check(1.toByte, 2.toByte, (1.toByte & 2.toByte).toByte)
+    check(1000.toShort, 2.toShort, (1000.toShort & 2.toShort).toShort)
     check(1000000, 4, 1000000 & 4)
     check(123456789123L, 5L, 123456789123L & 5L)
 
@@ -63,8 +65,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       checkEvaluation(expr, expected)
     }
 
-    check(1.toByte, 2.toByte, 1.toByte | 2.toByte)
-    check(1000.toShort, 2.toShort, 1000.toShort | 2.toShort)
+    // Need the extra toByte even though IntelliJ thinks it is redundant.
+    check(1.toByte, 2.toByte, (1.toByte | 2.toByte).toByte)
+    check(1000.toShort, 2.toShort, (1000.toShort | 2.toShort).toShort)
     check(1000000, 4, 1000000 | 4)
     check(123456789123L, 5L, 123456789123L | 5L)
 
@@ -81,8 +84,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       checkEvaluation(expr, expected)
     }
 
-    check(1.toByte, 2.toByte, 1.toByte ^ 2.toByte)
-    check(1000.toShort, 2.toShort, 1000.toShort ^ 2.toShort)
+    // Need the extra toByte even though IntelliJ thinks it is redundant.
+    check(1.toByte, 2.toByte, (1.toByte ^ 2.toByte).toByte)
+    check(1000.toShort, 2.toShort, (1000.toShort ^ 2.toShort).toShort)
     check(1000000, 4, 1000000 ^ 4)
     check(123456789123L, 5L, 123456789123L ^ 5L)
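The casts look redundant because the numeric value is unchanged, but the static type is not: Scala, like Java, promotes Byte and Short operands to Int before applying bitwise operators. Since the evaluation helpers now compare with exact types, the expected value has to be narrowed back:

val promoted = ~1.toByte                 // static type Int (value -2)
val narrowed: Byte = (~1.toByte).toByte  // static type Byte (value -2)
// Both are -2 numerically, but only `narrowed` matches a ByteType column.
assert(promoted == -2 && narrowed == -2)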
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index ab0cdc857c..136368bf5b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -114,7 +114,7 @@ trait ExpressionEvalHelper {
     val actual = plan(inputRow).get(0, expression.dataType)
     if (!checkResult(actual, expected)) {
       val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
-      fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
+      fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input")
     }
   }
 
@@ -146,7 +146,8 @@ trait ExpressionEvalHelper {
 
     if (actual != expectedRow) {
       val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
-      fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expectedRow$input")
+      fail("Incorrect evaluation in codegen mode: " +
+        s"$expression, actual: $actual, expected: $expectedRow$input")
     }
     if (actual.copy() != expectedRow) {
       fail(s"Copy of generated Row is wrong: actual: ${actual.copy()}, expected: $expectedRow")
@@ -163,12 +164,21 @@ trait ExpressionEvalHelper {
       expression)
 
     val unsafeRow = plan(inputRow)
-    // UnsafeRow cannot be compared with GenericInternalRow directly
-    val actual = FromUnsafeProjection(expression.dataType :: Nil)(unsafeRow)
-    val expectedRow = InternalRow(expected)
-    if (actual != expectedRow) {
-      val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
-      fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expectedRow$input")
-    }
+    val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
+
+    if (expected == null) {
+      if (!unsafeRow.isNullAt(0)) {
+        val expectedRow = InternalRow(expected)
+        fail("Incorrect evaluation in unsafe mode: " +
+          s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
+      }
+    } else {
+      val lit = InternalRow(expected)
+      val expectedRow = UnsafeProjection.create(Array(expression.dataType)).apply(lit)
+      if (unsafeRow != expectedRow) {
+        fail("Incorrect evaluation in unsafe mode: " +
+          s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
+      }
+    }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index d88a02298c..314b85f126 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -363,7 +363,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.Sort(sortExprs, global, child) =>
         getSortOperator(sortExprs, global, planLater(child)):: Nil
       case logical.Project(projectList, child) =>
-        execution.Project(projectList, planLater(child)) :: Nil
+        // If unsafe mode is enabled and the projection and the child's schema contain only
+        // data types supported by the unsafe row format, use TungstenProject; otherwise fall
+        // back to the normal Project.
+        if (sqlContext.conf.unsafeEnabled &&
+          UnsafeProjection.canSupport(projectList) && UnsafeProjection.canSupport(child.schema)) {
+          execution.TungstenProject(projectList, planLater(child)) :: Nil
+        } else {
+          execution.Project(projectList, planLater(child)) :: Nil
+        }
       case logical.Filter(condition, child) =>
         execution.Filter(condition, planLater(child)) :: Nil
       case e @ logical.Expand(_, _, _, child) =>
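Restating the planner gate above as a standalone predicate may help; this mirrors the patch's own logic, but `useTungsten` itself is a hypothetical helper, not part of the change:

def useTungsten(
    unsafeEnabled: Boolean,
    projectList: Seq[NamedExpression],
    childSchema: StructType): Boolean = {
  // All three conditions must hold; any unsupported type anywhere forces the safe path.
  unsafeEnabled &&
    UnsafeProjection.canSupport(projectList) &&
    UnsafeProjection.canSupport(childSchema)
}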
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index fe429d862a..b02e60dc85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -49,6 +49,31 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 }
 
+
+/**
+ * A variant of [[Project]] that returns [[UnsafeRow]]s.
+ */
+case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
+
+  override def outputsUnsafeRows: Boolean = true
+  override def canProcessUnsafeRows: Boolean = true
+  override def canProcessSafeRows: Boolean = true
+
+  override def output: Seq[Attribute] = projectList.map(_.toAttribute)
+
+  protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
+    // Swap in the UnsafeRow-producing variants of the struct creators. Note that transform
+    // returns a new expression tree rather than mutating in place, so its result must be used.
+    val unsafeProjectList = projectList.map(_.transform {
+      case CreateStruct(children) => CreateStructUnsafe(children)
+      case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
+    }.asInstanceOf[NamedExpression])
+    val project = UnsafeProjection.create(unsafeProjectList, child.output)
+    iter.map(project)
+  }
+
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+}
+
+
 /**
  * :: DeveloperApi ::
  */
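One subtlety worth spelling out: Catalyst trees are immutable, so transform-style rewrites return a new tree rather than mutating the receiver, and a discarded result means no rewrite at all. A minimal model of the substitution TungstenProject performs (a toy ADT, not Catalyst):

sealed trait Expr
final case class Lit(v: Any) extends Expr
final case class MkStruct(children: Seq[Expr]) extends Expr
final case class MkStructUnsafe(children: Seq[Expr]) extends Expr

def rewrite(e: Expr): Expr = e match {
  case MkStruct(cs) => MkStructUnsafe(cs.map(rewrite))  // swap in the unsafe variant
  case other => other                                   // leave everything else alone
}

assert(rewrite(MkStruct(Seq(Lit(1)))) == MkStructUnsafe(Seq(Lit(1))))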
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
new file mode 100644
index 0000000000..bf8ef9a97b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types._
+
+/**
+ * An end-to-end test suite specifically for testing Tungsten (Unsafe/CodeGen) mode.
+ *
+ * This is here for now so we can make sure TungstenProject is tested without refactoring the
+ * existing end-to-end test infrastructure. In the long run this suite should just go away.
+ */
+class DataFrameTungstenSuite extends QueryTest with SQLTestUtils {
+
+  override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
+  import sqlContext.implicits._
+
+  test("test simple types") {
+    withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") {
+      val df = sqlContext.sparkContext.parallelize(Seq((1, 2))).toDF("a", "b")
+      assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2))
+    }
+  }
+
+  test("test struct type") {
+    withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") {
+      val struct = Row(1, 2L, 3.0F, 3.0)
+      val data = sqlContext.sparkContext.parallelize(Seq(Row(1, struct)))
+
+      val schema = new StructType()
+        .add("a", IntegerType)
+        .add("b",
+          new StructType()
+            .add("b1", IntegerType)
+            .add("b2", LongType)
+            .add("b3", FloatType)
+            .add("b4", DoubleType))
+
+      val df = sqlContext.createDataFrame(data, schema)
+      assert(df.select("b").first() === Row(struct))
+    }
+  }
+
+  test("test nested struct type") {
+    withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") {
+      val innerStruct = Row(1, "abcd")
+      val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg")
+      val data = sqlContext.sparkContext.parallelize(Seq(Row(1, outerStruct)))
+
+      val schema = new StructType()
+        .add("a", IntegerType)
+        .add("b",
+          new StructType()
+            .add("b1", IntegerType)
+            .add("b2", LongType)
+            .add("b3", FloatType)
+            .add("b4", DoubleType)
+            .add("b5", new StructType()
+              .add("b5a", IntegerType)
+              .add("b5b", StringType))
+            .add("b6", StringType))
+
+      val df = sqlContext.createDataFrame(data, schema)
+      assert(df.select("b").first() === Row(outerStruct))
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala
index 99e11fd64b..1c5a2ed2c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.execution.expressions.{SparkPartitionID, Monotonical
 
 class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper {
   test("MonotonicallyIncreasingID") {
-    checkEvaluation(MonotonicallyIncreasingID(), 0)
+    checkEvaluation(MonotonicallyIncreasingID(), 0L)
   }
 
   test("SparkPartitionID") {
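An additional case in the same style as DataFrameTungstenSuite, hypothetical and not part of this patch, showing that a struct containing a string field (a variable-length type) also round-trips through TungstenProject:

  test("struct with string field") {
    withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") {
      val df = sqlContext.sparkContext.parallelize(Seq((1, "x"))).toDF("a", "b")
      assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, "x"))
    }
  }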