diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 9e9a856286..b7884f9b60 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf}
-import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
+import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, InMemoryCatalog, SessionCatalog}
 import org.apache.spark.sql.catalyst.encoders.OuterScopes
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -444,8 +444,43 @@ class Analyzer(
     }
 
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-      case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _, _) =>
-        i.copy(table = EliminateSubqueryAliases(lookupTableFromCatalog(u)))
+      case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
+        val table = lookupTableFromCatalog(u)
+        // adding the table's partitions or validate the query's partition info
+        table match {
+          case relation: CatalogRelation if relation.catalogTable.partitionColumns.nonEmpty =>
+            val tablePartitionNames = relation.catalogTable.partitionColumns.map(_.name)
+            if (parts.keys.nonEmpty) {
+              // the query's partitioning must match the table's partitioning
+              // this is set for queries like: insert into ... partition (one = "a", two = <expr>)
+              // TODO: add better checking to pre-inserts to avoid needing this here
+              if (tablePartitionNames.size != parts.keySet.size) {
+                throw new AnalysisException(
+                  s"""Requested partitioning does not match the ${u.tableIdentifier} table:
+                     |Requested partitions: ${parts.keys.mkString(",")}
+                     |Table partitions: ${tablePartitionNames.mkString(",")}""".stripMargin)
+              }
+              // Assume partition columns are correctly placed at the end of the child's output
+              i.copy(table = EliminateSubqueryAliases(table))
+            } else {
+              // Set up the table's partition scheme with all dynamic partitions by moving partition
+              // columns to the end of the column list, in partition order.
+              val (inputPartCols, columns) = child.output.partition { attr =>
+                tablePartitionNames.contains(attr.name)
+              }
+              // All partition columns are dynamic because this InsertIntoTable had no partitioning
+              val partColumns = tablePartitionNames.map { name =>
+                inputPartCols.find(_.name == name).getOrElse(
+                  throw new AnalysisException(s"Cannot find partition column $name"))
+              }
+              i.copy(
+                table = EliminateSubqueryAliases(table),
+                partition = tablePartitionNames.map(_ -> None).toMap,
+                child = Project(columns ++ partColumns, child))
+            }
+          case _ =>
+            i.copy(table = EliminateSubqueryAliases(table))
+        }
       case u: UnresolvedRelation =>
         val table = u.tableIdentifier
         if (table.database.isDefined && conf.runSQLonFile &&
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 8b438e40e6..732b0d7919 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -354,10 +354,23 @@ case class InsertIntoTable(
   override def children: Seq[LogicalPlan] = child :: Nil
   override def output: Seq[Attribute] = Seq.empty
 
+  private[spark] lazy val expectedColumns = {
+    if (table.output.isEmpty) {
+      None
+    } else {
+      val numDynamicPartitions = partition.values.count(_.isEmpty)
+      val (partitionColumns, dataColumns) = table.output
+        .partition(a => partition.keySet.contains(a.name))
+      Some(dataColumns ++ partitionColumns.takeRight(numDynamicPartitions))
+    }
+  }
+
   assert(overwrite || !ifNotExists)
-  override lazy val resolved: Boolean = childrenResolved && child.output.zip(table.output).forall {
-    case (childAttr, tableAttr) =>
-      DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType)
+  override lazy val resolved: Boolean = childrenResolved && expectedColumns.forall { expected =>
+    child.output.size == expected.size && child.output.zip(expected).forall {
+      case (childAttr, tableAttr) =>
+        DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType)
+    }
   }
 }
 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 73ccec2ee0..3805674d39 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -168,7 +168,15 @@ case class InsertIntoHiveTable(
 
     // All partition column names in the format of "<column name 1>/<column name 2>/..."
     val partitionColumns = fileSinkConf.getTableInfo.getProperties.getProperty("partition_columns")
-    val partitionColumnNames = Option(partitionColumns).map(_.split("/")).orNull
+    val partitionColumnNames = Option(partitionColumns).map(_.split("/")).getOrElse(Array.empty)
+
+    // By this time, the partition map must match the table's partition columns
+    if (partitionColumnNames.toSet != partition.keySet) {
+      throw new SparkException(
+        s"""Requested partitioning does not match the ${table.tableName} table:
+           |Requested partitions: ${partition.keys.mkString(",")}
+           |Table partitions: ${table.partitionKeys.map(_.name).mkString(",")}""".stripMargin)
+    }
 
     // Validate partition spec if there exist any dynamic partitions
     if (numDynamicPartitions > 0) {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index baf34d1cf0..52aba328de 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -22,9 +22,11 @@ import java.io.File
 import org.apache.hadoop.hive.conf.HiveConf
 import org.scalatest.BeforeAndAfter
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.{QueryTest, _}
-import org.apache.spark.sql.execution.QueryExecutionException
+import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable
 import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -32,11 +34,11 @@ case class TestData(key: Int, value: String)
 
 case class ThreeCloumntable(key: Int, value: String, key1: String)
 
-class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter {
+class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter
+    with SQLTestUtils {
   import hiveContext.implicits._
-  import hiveContext.sql
 
-  val testData = hiveContext.sparkContext.parallelize(
+  override lazy val testData = hiveContext.sparkContext.parallelize(
     (1 to 100).map(i => TestData(i, i.toString))).toDF()
 
   before {
@@ -213,4 +215,77 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
 
     sql("DROP TABLE hiveTableWithStructValue")
   }
+
+  test("Reject partitioning that does not match table") {
+    withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
+      sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)")
+      val data = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd"))
+        .toDF("id", "data", "part")
+
+      intercept[AnalysisException] {
+        // cannot partition by 2 fields when there is only one in the table definition
+        data.write.partitionBy("part", "data").insertInto("partitioned")
+      }
+    }
+  }
+
+  test("Test partition mode = strict") {
+    withSQLConf(("hive.exec.dynamic.partition.mode", "strict")) {
+      sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)")
+      val data = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd"))
+        .toDF("id", "data", "part")
+
+      intercept[SparkException] {
+        data.write.insertInto("partitioned")
+      }
+    }
+  }
+
+  test("Detect table partitioning") {
+    withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
+      sql("CREATE TABLE source (id bigint, data string, part string)")
+      val data = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd")).toDF()
+
+      data.write.insertInto("source")
+      checkAnswer(sql("SELECT * FROM source"), data.collect().toSeq)
+
+      sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)")
+      // this will pick up the output partitioning from the table definition
+      sqlContext.table("source").write.insertInto("partitioned")
+
+      checkAnswer(sql("SELECT * FROM partitioned"), data.collect().toSeq)
+    }
+  }
+
+  test("Detect table partitioning with correct partition order") {
+    withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
+      sql("CREATE TABLE source (id bigint, part2 string, part1 string, data string)")
+      val data = (1 to 10).map(i => (i, if ((i % 2) == 0) "even" else "odd", "p", s"data-$i"))
+        .toDF("id", "part2", "part1", "data")
+
+      data.write.insertInto("source")
+      checkAnswer(sql("SELECT * FROM source"), data.collect().toSeq)
+
+      // the original data with part1 and part2 at the end
+      val expected = data.select("id", "data", "part1", "part2")
+
+      sql(
+        """CREATE TABLE partitioned (id bigint, data string)
+          |PARTITIONED BY (part1 string, part2 string)""".stripMargin)
+      sqlContext.table("source").write.insertInto("partitioned")
+
+      checkAnswer(sql("SELECT * FROM partitioned"), expected.collect().toSeq)
+    }
+  }
+
+  test("InsertIntoTable#resolved should include dynamic partitions") {
+    withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
+      sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)")
+      val data = (1 to 10).map(i => (i.toLong, s"data-$i")).toDF("id", "data")
+
+      val logical = InsertIntoTable(sqlContext.table("partitioned").logicalPlan,
+        Map("part" -> None), data.logicalPlan, overwrite = false, ifNotExists = false)
+      assert(!logical.resolved, "Should not resolve: missing partition data")
+    }
+  }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 3bf0e84267..bbb775ef77 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -978,7 +978,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
     sql("SET hive.exec.dynamic.partition.mode=strict")
 
     // Should throw when using strict dynamic partition mode without any static partition
-    intercept[SparkException] {
+    intercept[AnalysisException] {
       sql(
         """INSERT INTO TABLE dp_test PARTITION(dp)
           |SELECT key, value, key % 5 FROM src