diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala index 717709e4f9..deed45d273 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule /** * :: Experimental :: @@ -42,4 +44,7 @@ class ExperimentalMethods protected[sql](sqlContext: SQLContext) { @Experimental var extraStrategies: Seq[Strategy] = Nil + @Experimental + var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 61c74f8340..6721d9c407 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} +import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.execution._ @@ -202,7 +202,7 @@ class SQLContext private[sql]( } @transient - protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer + protected[sql] lazy val optimizer: Optimizer = new SparkOptimizer(this) @transient protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala new file mode 100644 index 0000000000..edaf3b36aa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.optimizer._ + +class SparkOptimizer(val sqlContext: SQLContext) + extends Optimizer { + override def batches: Seq[Batch] = super.batches :+ Batch( + "User Provided Optimizers", FixedPoint(100), sqlContext.experimental.extraOptimizations: _*) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index 1994dacfc4..14b9448d26 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -18,9 +18,15 @@ package org.apache.spark.sql import org.apache.spark.{SharedSparkContext, SparkFunSuite} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule class SQLContextSuite extends SparkFunSuite with SharedSparkContext{ + object DummyRule extends Rule[LogicalPlan] { + def apply(p: LogicalPlan): LogicalPlan = p + } + test("getOrCreate instantiates SQLContext") { val sqlContext = SQLContext.getOrCreate(sc) assert(sqlContext != null, "SQLContext.getOrCreate returned null") @@ -65,4 +71,10 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext{ session2.sql("select myadd(1, 2)").explain() } } + + test("Catalyst optimization passes are modifiable at runtime") { + val sqlContext = SQLContext.getOrCreate(sc) + sqlContext.experimental.extraOptimizations = Seq(DummyRule) + assert(sqlContext.optimizer.batches.flatMap(_.rules).contains(DummyRule)) + } }