diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 1ea596dddf..3935f7b321 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -392,7 +392,7 @@ class SQLContext(@transient val sparkContext: SparkContext) SparkPlan.currentContext.set(self) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes - val rowRDD = RDDConversions.productToRowRdd(rdd, schema) + val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType)) DataFrame(self, LogicalRDD(attributeSeq, rowRDD)(self)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index a500269f3c..f931dc95ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -21,9 +21,9 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.DataType import org.apache.spark.sql.{Row, SQLContext} /** @@ -31,26 +31,19 @@ import org.apache.spark.sql.{Row, SQLContext} */ @DeveloperApi object RDDConversions { - def productToRowRdd[A <: Product](data: RDD[A], schema: StructType): RDD[Row] = { + def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[Row] = { data.mapPartitions { iterator => - if (iterator.isEmpty) { - Iterator.empty - } else { - val bufferedIterator = iterator.buffered - val mutableRow = new SpecificMutableRow(schema.fields.map(_.dataType)) - val schemaFields = schema.fields.toArray - val converters = schemaFields.map { - f => CatalystTypeConverters.createToCatalystConverter(f.dataType) + val numColumns = outputTypes.length + val mutableRow = new GenericMutableRow(numColumns) + val converters = outputTypes.map(CatalystTypeConverters.createToCatalystConverter) + iterator.map { r => + var i = 0 + while (i < numColumns) { + mutableRow(i) = converters(i)(r.productElement(i)) + i += 1 } - bufferedIterator.map { r => - var i = 0 - while (i < mutableRow.length) { - mutableRow(i) = converters(i)(r.productElement(i)) - i += 1 - } - mutableRow - } + mutableRow } } } @@ -58,26 +51,19 @@ object RDDConversions { /** * Convert the objects inside Row into the types Catalyst expected. */ - def rowToRowRdd(data: RDD[Row], schema: StructType): RDD[Row] = { + def rowToRowRdd(data: RDD[Row], outputTypes: Seq[DataType]): RDD[Row] = { data.mapPartitions { iterator => - if (iterator.isEmpty) { - Iterator.empty - } else { - val bufferedIterator = iterator.buffered - val mutableRow = new GenericMutableRow(bufferedIterator.head.toSeq.toArray) - val schemaFields = schema.fields.toArray - val converters = schemaFields.map { - f => CatalystTypeConverters.createToCatalystConverter(f.dataType) + val numColumns = outputTypes.length + val mutableRow = new GenericMutableRow(numColumns) + val converters = outputTypes.map(CatalystTypeConverters.createToCatalystConverter) + iterator.map { r => + var i = 0 + while (i < numColumns) { + mutableRow(i) = converters(i)(r(i)) + i += 1 } - bufferedIterator.map { r => - var i = 0 - while (i < mutableRow.length) { - mutableRow(i) = converters(i)(r(i)) - i += 1 - } - mutableRow - } + mutableRow } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index dacd967cff..c6a4dabbab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -309,7 +309,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { output: Seq[Attribute], rdd: RDD[Row]): SparkPlan = { val converted = if (relation.needConversion) { - execution.RDDConversions.rowToRowRdd(rdd, relation.schema) + execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType)) } else { rdd } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index de907846b9..0f959b3d0b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputForma import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.{Row, SQLContext} @@ -108,7 +109,10 @@ class SimpleTextRelation( sparkContext.textFile(inputStatuses.map(_.getPath).mkString(",")).map { record => Row(record.split(",").zip(fields).map { case (value, dataType) => - Cast(Literal(value), dataType).eval() + // `Cast`ed values are always of Catalyst types (i.e. UTF8String instead of String, etc.) + val catalystValue = Cast(Literal(value), dataType).eval() + // Here we're converting Catalyst values to Scala values to test `needsConversion` + CatalystTypeConverters.convertToScala(catalystValue, dataType) }: _*) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 70328e1ef8..7c02d563f8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -76,6 +76,12 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { df.filter('a > 1 && 'p1 < 2).select('b, 'p1), for (i <- 2 to 3; _ <- Seq("foo", "bar")) yield Row(s"val_$i", 1)) + // Project many copies of columns with different types (reproduction for SPARK-7858) + checkAnswer( + df.filter('a > 1 && 'p1 < 2).select('b, 'b, 'b, 'b, 'p1, 'p1, 'p1, 'p1), + for (i <- 2 to 3; _ <- Seq("foo", "bar")) + yield Row(s"val_$i", s"val_$i", s"val_$i", s"val_$i", 1, 1, 1, 1)) + // Self-join df.registerTempTable("t") withTempTable("t") {