diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py index 6ea2b22d7b..9c77025514 100644 --- a/python/pyspark/pandas/tests/test_dataframe.py +++ b/python/pyspark/pandas/tests/test_dataframe.py @@ -5160,26 +5160,25 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils): sys.stdout = prev def test_explain_hint(self): - with ps.option_context("compute.default_index_type", "sequence"): - psdf1 = ps.DataFrame( - {"lkey": ["foo", "bar", "baz", "foo"], "value": [1, 2, 3, 5]}, - columns=["lkey", "value"], - ) - psdf2 = ps.DataFrame( - {"rkey": ["foo", "bar", "baz", "foo"], "value": [5, 6, 7, 8]}, - columns=["rkey", "value"], - ) - merged = psdf1.merge(psdf2.spark.hint("broadcast"), left_on="lkey", right_on="rkey") - prev = sys.stdout - try: - out = StringIO() - sys.stdout = out - merged.spark.explain() - actual = out.getvalue().strip() + psdf1 = ps.DataFrame( + {"lkey": ["foo", "bar", "baz", "foo"], "value": [1, 2, 3, 5]}, + columns=["lkey", "value"], + ) + psdf2 = ps.DataFrame( + {"rkey": ["foo", "bar", "baz", "foo"], "value": [5, 6, 7, 8]}, + columns=["rkey", "value"], + ) + merged = psdf1.merge(psdf2.spark.hint("broadcast"), left_on="lkey", right_on="rkey") + prev = sys.stdout + try: + out = StringIO() + sys.stdout = out + merged.spark.explain() + actual = out.getvalue().strip() - self.assertTrue("Broadcast" in actual, actual) - finally: - sys.stdout = prev + self.assertTrue("Broadcast" in actual, actual) + finally: + sys.stdout = prev def test_mad(self): pdf = pd.DataFrame( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala index 63824af072..7b37891de2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala @@ -225,6 +225,10 @@ object DeduplicateRelations extends Rule[LogicalPlan] { if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance())))) + case oldVersion @ AttachDistributedSequence(sequenceAttr, _) + if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty => + Seq((oldVersion, oldVersion.copy(sequenceAttr = sequenceAttr.newInstance()))) + case oldVersion: Generate if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty => val newOutput = oldVersion.generatorOutput.map(_.newInstance()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f4d4470059..40b9d6554d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -794,6 +794,11 @@ object ColumnPruning extends Rule[LogicalPlan] { } a.copy(child = Expand(newProjects, newOutput, grandChild)) + // Prune and drop AttachDistributedSequence if the produced attribute is not referred. + case p @ Project(_, a @ AttachDistributedSequence(_, grandChild)) + if !p.references.contains(a.sequenceAttr) => + p.copy(child = prunedChild(grandChild, p.references)) + // Prunes the unused columns from child of `DeserializeToObject` case d @ DeserializeToObject(_, _, child) if !child.outputSet.subsetOf(d.references) => d.copy(child = prunedChild(child, d.references)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index ba8352cf6a..af18540c56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -115,3 +115,20 @@ case class ArrowEvalPython( override protected def withNewChildInternal(newChild: LogicalPlan): ArrowEvalPython = copy(child = newChild) } + +/** + * A logical plan that adds a new long column with the name `name` that + * increases one by one. This is for 'distributed-sequence' default index + * in pandas API on Spark. + */ +case class AttachDistributedSequence( + sequenceAttr: Attribute, + child: LogicalPlan) extends UnaryNode { + + override val producedAttributes: AttributeSet = AttributeSet(sequenceAttr) + + override val output: Seq[Attribute] = sequenceAttr +: child.output + + override protected def withNewChildInternal(newChild: LogicalPlan): AttachDistributedSequence = + copy(child = newChild) +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 4db58298e1..0655acbcb1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -452,5 +452,11 @@ class ColumnPruningSuite extends PlanTest { val expected = input.where(rand(0L) > 0.5).where('key < 10).select('key).analyze comparePlans(optimized, expected) } - // todo: add more tests for column pruning + + test("SPARK-36559 Prune and drop distributed-sequence if the produced column is not referred") { + val input = LocalRelation('a.int, 'b.int, 'c.int) + val plan1 = AttachDistributedSequence('d.int, input).select('a) + val correctAnswer1 = Project(Seq('a), input).analyze + comparePlans(Optimize.execute(plan1.analyze), correctAnswer1) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index fb8620c7a6..fe84cc09e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3514,24 +3514,11 @@ class Dataset[T] private[sql]( * This is for 'distributed-sequence' default index in pandas API on Spark. */ private[sql] def withSequenceColumn(name: String) = { - val rdd: RDD[InternalRow] = - // Checkpoint the DataFrame to fix the partition ID. - localCheckpoint(false) - .queryExecution.toRdd.zipWithIndex().mapPartitions { iter => - val joinedRow = new JoinedRow - val unsafeRowWriter = - new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1) - - iter.map { case (row, id) => - // Writes to an UnsafeRow directly - unsafeRowWriter.reset() - unsafeRowWriter.write(0, id) - joinedRow(unsafeRowWriter.getRow, row) - } - } - - sparkSession.internalCreateDataFrame( - rdd, StructType(StructField(name, LongType, nullable = false) +: schema), isStreaming) + Dataset.ofRows( + sparkSession, + AttachDistributedSequence( + AttributeReference(name, LongType, nullable = false)(), + logicalPlan)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 1e64582942..fc2898bf24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -712,6 +712,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { func, output, planLater(left), planLater(right)) :: Nil case logical.MapInPandas(func, output, child) => execution.python.MapInPandasExec(func, output, planLater(child)) :: Nil + case logical.AttachDistributedSequence(attr, child) => + execution.python.AttachDistributedSequenceExec(attr, planLater(child)) :: Nil case logical.MapElements(f, _, _, objAttr, child) => execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil case logical.AppendColumns(f, _, _, in, out, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AttachDistributedSequenceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AttachDistributedSequenceExec.scala new file mode 100644 index 0000000000..27bfb7f682 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AttachDistributedSequenceExec.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} + +/** + * A physical plan that adds a new long column with `sequenceAttr` that + * increases one by one. This is for 'distributed-sequence' default index + * in pandas API on Spark. + */ +case class AttachDistributedSequenceExec( + sequenceAttr: Attribute, + child: SparkPlan) + extends UnaryExecNode { + + override def producedAttributes: AttributeSet = AttributeSet(sequenceAttr) + + override val output: Seq[Attribute] = sequenceAttr +: child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().map(_.copy()) + .localCheckpoint() // to avoid execute multiple jobs. zipWithIndex launches a Spark job. + .zipWithIndex().mapPartitions { iter => + val unsafeProj = UnsafeProjection.create(output, output) + val joinedRow = new JoinedRow + val unsafeRowWriter = + new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1) + + iter.map { case (row, id) => + // Writes to an UnsafeRow directly + unsafeRowWriter.reset() + unsafeRowWriter.write(0, id) + joinedRow(unsafeRowWriter.getRow, row) + }.map(unsafeProj) + } + } + + override protected def withNewChildInternal(newChild: SparkPlan): AttachDistributedSequenceExec = + copy(child = newChild) +}