From 4deeed17c4847f212a4fa1a8685cfe8a12179263 Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Mon, 7 Jul 2014 17:04:02 -0700
Subject: [PATCH] [SPARK-2386] [SQL] RowWriteSupport should use the exact
 types to cast.

When executing `saveAsParquetFile` with a non-primitive type,
`RowWriteSupport` uses the wrong type `Int` for `ByteType` and `ShortType`:
the value is boxed as `java.lang.Byte` or `java.lang.Short`, so
`value.asInstanceOf[Int]` throws a `ClassCastException` at runtime (see the
sketch after the patch).

Author: Takuya UESHIN

Closes #1315 from ueshin/issues/SPARK-2386 and squashes the following commits:

20d89ec [Takuya UESHIN] Use None instead of null.
bd88741 [Takuya UESHIN] Add a test.
323d1d2 [Takuya UESHIN] Modify RowWriteSupport to use the exact types to cast.
---
 .../sql/parquet/ParquetTableSupport.scala     |  4 +-
 .../spark/sql/parquet/ParquetQuerySuite.scala | 40 ++++++++++++++++++-
 2 files changed, 41 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index bfcbdeb34a..9cd5dc5bbd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -192,9 +192,9 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
         )
       )
     case IntegerType => writer.addInteger(value.asInstanceOf[Int])
-    case ShortType => writer.addInteger(value.asInstanceOf[Int])
+    case ShortType => writer.addInteger(value.asInstanceOf[Short])
     case LongType => writer.addLong(value.asInstanceOf[Long])
-    case ByteType => writer.addInteger(value.asInstanceOf[Int])
+    case ByteType => writer.addInteger(value.asInstanceOf[Byte])
     case DoubleType => writer.addDouble(value.asInstanceOf[Double])
     case FloatType => writer.addFloat(value.asInstanceOf[Float])
     case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 2ca0c1cdcb..dbf315947f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -67,6 +67,19 @@ case class AllDataTypes(
     byteField: Byte,
     booleanField: Boolean)
 
+case class AllDataTypesWithNonPrimitiveType(
+    stringField: String,
+    intField: Int,
+    longField: Long,
+    floatField: Float,
+    doubleField: Double,
+    shortField: Short,
+    byteField: Byte,
+    booleanField: Boolean,
+    array: Seq[Int],
+    map: Map[Int, String],
+    nested: Nested)
+
 class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
   TestData // Load test data tables.
 
@@ -119,6 +132,31 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
     }
   }
 
+  test("Read/Write All Types with non-primitive type") {
+    val tempDir = getTempFilePath("parquetTest").getCanonicalPath
+    val range = (0 to 255)
+    TestSQLContext.sparkContext.parallelize(range)
+      .map(x => AllDataTypesWithNonPrimitiveType(
+        s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0,
+        Seq(x), Map(x -> s"$x"), Nested(x, s"$x")))
+      .saveAsParquetFile(tempDir)
+    val result = parquetFile(tempDir).collect()
+    range.foreach {
+      i =>
+        assert(result(i).getString(0) == s"$i", s"row $i String field did not match, got ${result(i).getString(0)}")
+        assert(result(i).getInt(1) === i)
+        assert(result(i).getLong(2) === i.toLong)
+        assert(result(i).getFloat(3) === i.toFloat)
+        assert(result(i).getDouble(4) === i.toDouble)
+        assert(result(i).getShort(5) === i.toShort)
+        assert(result(i).getByte(6) === i.toByte)
+        assert(result(i).getBoolean(7) === (i % 2 == 0))
+        assert(result(i)(8) === Seq(i))
+        assert(result(i)(9) === Map(i -> s"$i"))
+        assert(result(i)(10) === new GenericRow(Array[Any](i, s"$i")))
+    }
+  }
+
   test("self-join parquet files") {
     val x = ParquetTestData.testData.as('x)
     val y = ParquetTestData.testData.as('y)
@@ -298,7 +336,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
   }
 
   test("save and load case class RDD with Nones as parquet") {
-    val data = OptionalReflectData(null, null, null, null, null)
+    val data = OptionalReflectData(None, None, None, None, None)
     val rdd = sparkContext.parallelize(data :: Nil)
     val file = getTempFilePath("parquet")
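
The sketch referenced in the commit message: Scala's `asInstanceOf` on a value statically typed as `Any` unboxes by casting to the exact wrapper class, so a boxed `java.lang.Byte` cannot be cast to `Int`. This is a minimal standalone illustration, not part of the patch; `CastSketch` is a name invented here.

```scala
// Minimal sketch (not part of the patch): why `value.asInstanceOf[Int]`
// fails when `value` holds a boxed Byte. Unboxing a value typed as Any
// casts to the exact wrapper class, so java.lang.Byte is not an Integer.
object CastSketch {
  def main(args: Array[String]): Unit = {
    val value: Any = 1.toByte                // stored boxed as java.lang.Byte

    val exact: Byte = value.asInstanceOf[Byte] // exact type: unboxes fine
    println(exact.toInt)                       // widening after unboxing is safe: prints 1

    try {
      value.asInstanceOf[Int]                // unboxing cast to java.lang.Integer
    } catch {
      case e: ClassCastException =>
        // e.g. "java.lang.Byte cannot be cast to java.lang.Integer"
        println(e.getMessage)
    }
  }
}
```

This is also why the fixed cases still compile: after the exact unboxing cast, the resulting `Short` or `Byte` is implicitly widened to `Int` when passed to Parquet's `RecordConsumer.addInteger(int)`.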