diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala
index 148a66cc52..ffd77c5ff6 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala
@@ -17,11 +17,19 @@
 
 package org.apache.spark.sql.avro
 
+import java.io.ByteArrayOutputStream
+
 import scala.collection.JavaConverters._
 
+import org.apache.avro.Schema
+import org.apache.avro.generic.{GenericDatumWriter, GenericRecord, GenericRecordBuilder}
+import org.apache.avro.io.EncoderFactory
+
 import org.apache.spark.SparkException
 import org.apache.spark.sql.{QueryTest, Row}
-import org.apache.spark.sql.functions.struct
+import org.apache.spark.sql.execution.LocalTableScanExec
+import org.apache.spark.sql.functions.{col, struct}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
 
 class AvroFunctionsSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
@@ -111,4 +119,38 @@ class AvroFunctionsSuite extends QueryTest with SharedSQLContext with SQLTestUti
       .select(from_avro($"avro", avroTypeArrStruct).as("array"))
     checkAnswer(dfOne, readBackOne)
   }
+
+  test("SPARK-27798: from_avro produces same value when converted to local relation") {
+    val simpleSchema =
+      """
+        |{
+        |  "type": "record",
+        |  "name" : "Payload",
+        |  "fields" : [ {"name" : "message", "type" : "string" } ]
+        |}
+      """.stripMargin
+
+    def generateBinary(message: String, avroSchema: String): Array[Byte] = {
+      val schema = new Schema.Parser().parse(avroSchema)
+      val out = new ByteArrayOutputStream()
+      val writer = new GenericDatumWriter[GenericRecord](schema)
+      val encoder = EncoderFactory.get().binaryEncoder(out, null)
+      val rootRecord = new GenericRecordBuilder(schema).set("message", message).build()
+      writer.write(rootRecord, encoder)
+      encoder.flush()
+      out.toByteArray
+    }
+
+    // This bug is hit when the rule `ConvertToLocalRelation` is run. But the rule was excluded
+    // in `SharedSparkSession`.
+    withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> "") {
+      val df = Seq("one", "two", "three", "four").map(generateBinary(_, simpleSchema))
+        .toDF()
+        .withColumn("value",
+          functions.from_avro(col("value"), simpleSchema))
+
+      assert(df.queryExecution.executedPlan.isInstanceOf[LocalTableScanExec])
+      assert(df.collect().map(_.get(0)) === Seq(Row("one"), Row("two"), Row("three"), Row("four")))
+    }
+  }
 }
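Why the new test flips `spark.sql.optimizer.excludedRules`: the shared test session excludes `ConvertToLocalRelation`, so without clearing that conf the buggy path never runs and the `LocalTableScanExec` assertion could not hold. Below is a minimal sketch of the same toggle outside the test harness; the standalone app is illustrative and not part of this patch (in a stock session the conf is already unset, which leaves every optimizer rule enabled):

import org.apache.spark.sql.SparkSession

object LocalRelationDemo {
  def main(args: Array[String]): Unit = {
    // Setting the exclusion list to "" mirrors what withSQLConf restores in
    // the test; a fresh session leaves it unset with the same effect.
    val spark = SparkSession.builder()
      .master("local[1]")
      .config("spark.sql.optimizer.excludedRules", "")
      .getOrCreate()
    import spark.implicits._

    // With ConvertToLocalRelation enabled, the projection over local data is
    // evaluated during optimization, so the physical plan is a single
    // LocalTableScanExec rather than a ProjectExec over a scan.
    val df = Seq(1, 2, 3).toDF("a").selectExpr("a + 1 AS b")
    println(df.queryExecution.executedPlan)
    spark.stop()
  }
}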
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 5b59ac7d2a..8c52ff9e9b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1420,9 +1420,9 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case Project(projectList, LocalRelation(output, data, isStreaming))
         if !projectList.exists(hasUnevaluableExpr) =>
-      val projection = new InterpretedProjection(projectList, output)
+      val projection = new InterpretedMutableProjection(projectList, output)
       projection.initialize(0)
-      LocalRelation(projectList.map(_.toAttribute), data.map(projection), isStreaming)
+      LocalRelation(projectList.map(_.toAttribute), data.map(projection(_).copy()), isStreaming)
 
     case Limit(IntegerLiteral(limit), LocalRelation(output, data, isStreaming)) =>
       LocalRelation(output, data.take(limit), isStreaming)
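The heart of the fix is the per-row `copy()`: a mutable projection writes every result into one reused target row, so each projected row must be snapshotted before the next input is processed. Without the copy, `data` would hold N references to a single row that all show the last value, which is exactly how `from_avro`, whose deserializer also reuses its result row, corrupted the pre-computed local relation. A standalone sketch of the copy semantics involved (assumes `spark-catalyst` on the classpath; `CopySnapshotDemo` is illustrative, not the rule's actual code path):

import org.apache.spark.sql.catalyst.expressions.GenericInternalRow

object CopySnapshotDemo {
  def main(args: Array[String]): Unit = {
    val reused = new GenericInternalRow(1)  // one reused target row

    // Pre-fix shape: collect the reused row itself. All elements alias one
    // object, so each of them reads whatever was written last.
    val aliased = (1 to 3).map { i => reused.update(0, i); reused }
    println(aliased.map(_.getInt(0)))  // Vector(3, 3, 3)

    // Post-fix shape: copy() right after each projection snapshots the values.
    val copied = (1 to 3).map { i => reused.update(0, i); reused.copy() }
    println(copied.map(_.getInt(0)))   // Vector(1, 2, 3)
  }
}

Note the copy would be needed regardless of which interpreted projection is used: even a projection that allocates a fresh top-level row per input can return rows whose fields still reference a nested row that the expression reuses, as the regression test below demonstrates.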
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala
index 0c015f88e1..43579d4c90 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala
@@ -21,10 +21,12 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{LessThan, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow, LessThan, Literal, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.{DataType, StructType}
 
 class ConvertToLocalRelationSuite extends PlanTest {
 
@@ -70,4 +72,36 @@ class ConvertToLocalRelationSuite extends PlanTest {
 
     comparePlans(optimized, correctAnswer)
   }
+
+  test("SPARK-27798: Expression reusing output shouldn't override values in local relation") {
+    val testRelation = LocalRelation(
+      LocalRelation('a.int).output,
+      InternalRow(1) :: InternalRow(2) :: Nil)
+
+    val correctAnswer = LocalRelation(
+      LocalRelation('a.struct('a1.int)).output,
+      InternalRow(InternalRow(1)) :: InternalRow(InternalRow(2)) :: Nil)
+
+    val projected = testRelation.select(ExprReuseOutput(UnresolvedAttribute("a")).as("a"))
+    val optimized = Optimize.execute(projected.analyze)
+
+    comparePlans(optimized, correctAnswer)
+  }
+}
+
+
+// Dummy expression used for testing. It reuses output row. Assumes child expr outputs an integer.
+case class ExprReuseOutput(child: Expression) extends UnaryExpression {
+  override def dataType: DataType = StructType.fromDDL("a1 int")
+  override def nullable: Boolean = true
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+    throw new UnsupportedOperationException("Should not trigger codegen")
+
+  private val row: InternalRow = new GenericInternalRow(1)
+
+  override def eval(input: InternalRow): Any = {
+    row.update(0, child.eval(input))
+    row
+  }
 }
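`ExprReuseOutput` reproduces in miniature what makes expressions like `from_avro` hazardous for this rule: every `eval` writes into one private row and returns it. A quick standalone check of that property (assumes these test classes are compiled onto the classpath; `ExprReuseCheck`, and the use of `BoundReference` in place of a resolved attribute, are illustrative):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.optimizer.ExprReuseOutput
import org.apache.spark.sql.types.IntegerType

object ExprReuseCheck {
  def main(args: Array[String]): Unit = {
    val expr = ExprReuseOutput(BoundReference(0, IntegerType, nullable = true))
    val first = expr.eval(InternalRow(1))
    val second = expr.eval(InternalRow(2))
    // Both results are the same object: without the per-row copy() in
    // ConvertToLocalRelation, the optimized LocalRelation would hold two
    // references to one row, each reading the value written last.
    println(first.asInstanceOf[AnyRef] eq second.asInstanceOf[AnyRef])  // true
  }
}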