diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index aec88c9241..c4b7f8490a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -103,7 +103,9 @@ public final class UnsafeRow extends BaseMutableRow { IntegerType, LongType, FloatType, - DoubleType + DoubleType, + DateType, + TimestampType }))); // We support get() on a superset of the types for which we support set(): @@ -331,8 +333,6 @@ public final class UnsafeRow extends BaseMutableRow { return getUTF8String(i).toString(); } - - @Override public InternalRow copy() { throw new UnsupportedOperationException(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index 5c92f41c63..72f740ecae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods @@ -120,6 +122,8 @@ private object UnsafeColumnWriter { case FloatType => FloatUnsafeColumnWriter case DoubleType => DoubleUnsafeColumnWriter case StringType => StringUnsafeColumnWriter + case DateType => IntUnsafeColumnWriter + case TimestampType => LongUnsafeColumnWriter case t => throw new UnsupportedOperationException(s"Do not know how to write columns of type $t") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 534dac1f92..1098962ddc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -197,9 +197,10 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value } override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value } override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value } - override def setString(ordinal: Int, value: String) { + override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = UTF8String.fromString(value) } + override def setNullAt(i: Int): Unit = { values(i) = null } override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 577c7a0de0..721ef8a226 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.{Date, Timestamp} import java.util.Arrays import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods @@ -74,6 +76,34 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { unsafeRow.getString(2) should be ("World") } + test("basic conversion with primitive, string, date and timestamp types") { + val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType) + val converter = new UnsafeRowConverter(fieldTypes) + + val row = new SpecificMutableRow(fieldTypes) + row.setLong(0, 0) + row.setString(1, "Hello") + row.update(2, DateUtils.fromJavaDate(Date.valueOf("1970-01-01"))) + row.update(3, DateUtils.fromJavaTimestamp(Timestamp.valueOf("2015-05-08 08:10:25"))) + + val sizeRequired: Int = converter.getSizeRequirement(row) + sizeRequired should be (8 + (8 * 4) + + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8)) + val buffer: Array[Long] = new Array[Long](sizeRequired / 8) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) + numBytesWritten should be (sizeRequired) + + val unsafeRow = new UnsafeRow() + unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + unsafeRow.getLong(0) should be (0) + unsafeRow.getString(1) should be ("Hello") + // Date is represented as Int in unsafeRow + DateUtils.toJavaDate(unsafeRow.getInt(2)) should be (Date.valueOf("1970-01-01")) + // Timestamp is represented as Long in unsafeRow + DateUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be + (Timestamp.valueOf("2015-05-08 08:10:25")) + } + test("null handling") { val fieldTypes: Array[DataType] = Array( NullType,