diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index e8266dd401..6212e8f48c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -101,16 +101,16 @@ object TypeUtils {
   }
 
   def failWithIntervalType(dataType: DataType): Unit = {
-    dataType match {
-      case CalendarIntervalType =>
-        throw new AnalysisException("Cannot use interval type in the table schema.")
-      case ArrayType(et, _) => failWithIntervalType(et)
-      case MapType(kt, vt, _) =>
-        failWithIntervalType(kt)
-        failWithIntervalType(vt)
-      case s: StructType => s.foreach(f => failWithIntervalType(f.dataType))
-      case u: UserDefinedType[_] => failWithIntervalType(u.sqlType)
-      case _ =>
+    invokeOnceForInterval(dataType) {
+      throw new AnalysisException("Cannot use interval type in the table schema.")
     }
   }
+
+  def invokeOnceForInterval(dataType: DataType)(f: => Unit): Unit = {
+    def isInterval(dataType: DataType): Boolean = dataType match {
+      case CalendarIntervalType | DayTimeIntervalType | YearMonthIntervalType => true
+      case _ => false
+    }
+    if (dataType.existsRecursively(isInterval)) f
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 639c1e30b0..b3de98dd0d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogUtils}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, TypeUtils}
 import org.apache.spark.sql.connector.catalog.TableProvider
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.execution.SparkPlan
@@ -50,7 +50,7 @@ import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, Tex
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.streaming.OutputMode
-import org.apache.spark.sql.types.{CalendarIntervalType, StructField, StructType}
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
 import org.apache.spark.sql.util.SchemaUtils
 import org.apache.spark.util.{HadoopFSUtils, ThreadUtils, Utils}
 
@@ -510,10 +510,7 @@ case class DataSource(
       physicalPlan: SparkPlan,
       metrics: Map[String, SQLMetric]): BaseRelation = {
     val outputColumns = DataWritingCommand.logicalPlanOutputWithNames(data, outputColumnNames)
-    if (outputColumns.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) {
-      throw QueryCompilationErrors.cannotSaveIntervalIntoExternalStorageError()
-    }
-
+    disallowWritingIntervals(outputColumns.map(_.dataType))
     providingInstance() match {
       case dataSource: CreatableRelationProvider =>
         dataSource.createRelation(
@@ -547,10 +544,7 @@ case class DataSource(
    * Returns a logical plan to write the given [[LogicalPlan]] out to this [[DataSource]].
    */
   def planForWriting(mode: SaveMode, data: LogicalPlan): LogicalPlan = {
-    if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) {
-      throw QueryCompilationErrors.cannotSaveIntervalIntoExternalStorageError()
-    }
-
+    disallowWritingIntervals(data.schema.map(_.dataType))
     providingInstance() match {
       case dataSource: CreatableRelationProvider =>
         SaveIntoDataSourceCommand(data, dataSource, caseInsensitiveOptions, mode)
@@ -579,6 +573,12 @@ case class DataSource(
     DataSource.checkAndGlobPathIfNecessary(allPaths.toSeq, newHadoopConfiguration(),
       checkEmptyGlobPath, checkFilesExist, enableGlobbing = globPaths)
   }
+
+  private def disallowWritingIntervals(dataTypes: Seq[DataType]): Unit = {
+    dataTypes.foreach(TypeUtils.invokeOnceForInterval(_) {
+      throw QueryCompilationErrors.cannotSaveIntervalIntoExternalStorageError()
+    })
+  }
 }
 
 object DataSource extends Logging {
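For reference (not part of the patch above), a minimal sketch of how the new TypeUtils.invokeOnceForInterval helper is intended to be used: the supplied block runs at most once per data type, and only when an interval type is nested anywhere inside it. The assertNoIntervalColumns name and the IllegalArgumentException below are illustrative, not from the patch.

import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

// Hypothetical caller: reject a schema that nests an interval anywhere,
// e.g. as an array element or a map value, not just as a top-level column type.
def assertNoIntervalColumns(schema: StructType): Unit = {
  schema.foreach { field =>
    // The block fires only when the column's type (recursively) contains
    // CalendarIntervalType, DayTimeIntervalType or YearMonthIntervalType.
    TypeUtils.invokeOnceForInterval(field.dataType) {
      throw new IllegalArgumentException(
        s"Column '${field.name}' has an interval type and cannot be written out")
    }
  }
}

// Example: a map value of interval type is caught even though the top-level
// column type is MapType rather than an interval.
// assertNoIntervalColumns(
//   StructType(Seq(StructField("m", MapType(StringType, CalendarIntervalType)))))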