[SPARK-35380][SQL] Loading SparkSessionExtensions from ServiceLoader
### What changes were proposed in this pull request? In https://github.com/yaooqinn/itachi/issues/8, we had a discussion about the current extension injection for the spark session. We've agreed that the current way is not that convenient for both third-party developers and end-users. It's much simpler if third-party developers can provide a resource file that contains default extensions for Spark to load ahead of time. ### Why are the changes needed? better user experience ### Does this PR introduce _any_ user-facing change? no, dev change ### How was this patch tested? new tests Closes #32515 from yaooqinn/SPARK-35380. Authored-by: Kent Yao <yao@apache.org> Signed-off-by: Kent Yao <yao@apache.org>
This commit is contained in:
parent
dd5464976f
commit
51815430b2
|
@ -0,0 +1 @@
|
|||
org.apache.spark.examples.extensions.SessionExtensionsWithLoader
|
|
@ -0,0 +1,31 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.examples.extensions
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.{CurrentDate, Expression, RuntimeReplaceable, SubtractDates}
|
||||
|
||||
/**
|
||||
* How old are you in days?
|
||||
*/
|
||||
/**
 * How old are you in days?
 *
 * `AgeExample(birthday)` is rewritten at analysis time into its `child`
 * replacement, `SubtractDates(CurrentDate(), birthday)`. Only `birthday` is
 * user-facing, hence `exprsReplaced` exposes just that argument.
 *
 * @param birthday the date expression supplied by the user
 * @param child    the runtime replacement expression actually evaluated
 */
case class AgeExample(birthday: Expression, child: Expression) extends RuntimeReplaceable {

  def this(birthday: Expression) = this(birthday, SubtractDates(CurrentDate(), birthday))

  override def exprsReplaced: Seq[Expression] = Seq(birthday)

  // A RuntimeReplaceable is a UnaryExpression whose single child is the
  // replacement expression, so `newChild` must be assigned to `child`.
  // The previous positional `copy(newChild)` silently overwrote `birthday`
  // instead, leaving a stale replacement after tree transformations.
  override protected def withNewChildInternal(newChild: Expression): Expression =
    copy(child = newChild)
}
|
|
@ -0,0 +1,31 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.examples.extensions
|
||||
|
||||
import org.apache.spark.sql.{SparkSessionExtensions, SparkSessionExtensionsProvider}
|
||||
import org.apache.spark.sql.catalyst.FunctionIdentifier
|
||||
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
|
||||
|
||||
/**
 * A [[SparkSessionExtensionsProvider]] intended to be discovered through the
 * Java ServiceLoader mechanism. It registers the `age_two` function backed by
 * [[AgeExample]].
 */
class SessionExtensionsWithLoader extends SparkSessionExtensionsProvider {
  override def apply(v1: SparkSessionExtensions): Unit = {
    // Register (identifier, info, builder) for the `age_two` function.
    val identifier = new FunctionIdentifier("age_two")
    val info = new ExpressionInfo(classOf[AgeExample].getName, "age_two")
    val funcBuilder = (children: Seq[Expression]) => new AgeExample(children.head)
    v1.injectFunction((identifier, info, funcBuilder))
  }
}
|
|
@ -0,0 +1,31 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.examples.extensions
|
||||
|
||||
import org.apache.spark.sql.{SparkSessionExtensions, SparkSessionExtensionsProvider}
|
||||
import org.apache.spark.sql.catalyst.FunctionIdentifier
|
||||
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
|
||||
|
||||
/**
 * A [[SparkSessionExtensionsProvider]] meant to be wired up explicitly (e.g.
 * via the `spark.sql.extensions` config) rather than via ServiceLoader.
 * It registers the `age_one` function backed by [[AgeExample]].
 */
class SessionExtensionsWithoutLoader extends SparkSessionExtensionsProvider {
  override def apply(v1: SparkSessionExtensions): Unit = {
    // Register (identifier, info, builder) for the `age_one` function.
    val identifier = new FunctionIdentifier("age_one")
    val info = new ExpressionInfo(classOf[AgeExample].getName, "age_one")
    val funcBuilder = (children: Seq[Expression]) => new AgeExample(children.head)
    v1.injectFunction((identifier, info, funcBuilder))
  }
}
|
|
@ -0,0 +1,38 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.examples.extensions
|
||||
|
||||
import org.apache.spark.sql.SparkSession
|
||||
|
||||
/**
|
||||
* [[SessionExtensionsWithLoader]] is registered in
|
||||
* src/main/resources/META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider
|
||||
*
|
||||
* [[SessionExtensionsWithoutLoader]] is registered via spark.sql.extensions
|
||||
*/
|
||||
/**
 * [[SessionExtensionsWithLoader]] is registered in
 * src/main/resources/META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider
 *
 * [[SessionExtensionsWithoutLoader]] is registered via spark.sql.extensions
 */
object SparkSessionExtensionsTest {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("SparkSessionExtensionsTest")
      .config("spark.sql.extensions", classOf[SessionExtensionsWithoutLoader].getName)
      .getOrCreate()
    // Both functions should resolve: `age_one` comes from the config above,
    // `age_two` is injected automatically through the ServiceLoader registration.
    spark.sql("SELECT age_one('2018-11-17'), age_two('2018-11-17')").show()
    // Shut the session down so the SparkContext and its resources are released
    // (the original example leaked the running context on exit).
    spark.stop()
  }
}
|
|
@ -18,7 +18,7 @@
|
|||
package org.apache.spark.sql
|
||||
|
||||
import java.io.Closeable
|
||||
import java.util.UUID
|
||||
import java.util.{ServiceLoader, UUID}
|
||||
import java.util.concurrent.TimeUnit._
|
||||
import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}
|
||||
|
||||
|
@ -949,6 +949,7 @@ object SparkSession extends Logging {
|
|||
// Do not update `SparkConf` for existing `SparkContext`, as it's shared by all sessions.
|
||||
}
|
||||
|
||||
loadExtensions(extensions)
|
||||
applyExtensions(
|
||||
sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS).getOrElse(Seq.empty),
|
||||
extensions)
|
||||
|
@ -1203,4 +1204,22 @@ object SparkSession extends Logging {
|
|||
}
|
||||
extensions
|
||||
}
|
||||
|
||||
/**
|
||||
* Load extensions from [[ServiceLoader]] and use them
|
||||
*/
|
||||
/**
 * Load extensions from [[ServiceLoader]] and use them
 *
 * Discovers every [[SparkSessionExtensionsProvider]] registered under
 * META-INF/services on the context (or Spark) class loader and applies each
 * one to the given [[SparkSessionExtensions]].
 *
 * @param extensions the mutable extensions container to inject into
 */
private def loadExtensions(extensions: SparkSessionExtensions): Unit = {
  val loader = ServiceLoader.load(classOf[SparkSessionExtensionsProvider],
    Utils.getContextOrSparkClassLoader)
  val loadedExts = loader.iterator()

  while (loadedExts.hasNext) {
    try {
      // `next()` itself may throw (e.g. ServiceConfigurationError or a
      // LinkageError from a broken third-party jar), so it stays inside
      // the try block.
      val ext = loadedExts.next()
      ext(extensions)
    } catch {
      // Deliberately broad: a single misbehaving provider must not prevent
      // session creation, so the failure is logged and skipped. A plain
      // `Exception` match would miss the Error subclasses noted above.
      // NOTE(review): this also swallows fatal errors such as
      // OutOfMemoryError — confirm that resilience is intended; otherwise
      // prefer scala.util.control.NonFatal.
      case e: Throwable => logWarning("Failed to load session extension", e)
    }
  }
}
|
||||
}
|
||||
|
|
|
@ -71,7 +71,7 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
|
|||
* {{{
|
||||
* SparkSession.builder()
|
||||
* .master("...")
|
||||
* .config("spark.sql.extensions", "org.example.MyExtensions")
|
||||
* .config("spark.sql.extensions", "org.example.MyExtensions,org.example.YourExtensions")
|
||||
* .getOrCreate()
|
||||
*
|
||||
* class MyExtensions extends Function1[SparkSessionExtensions, Unit] {
|
||||
|
@ -84,6 +84,15 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
|
|||
* }
|
||||
* }
|
||||
* }
|
||||
*
|
||||
* class YourExtensions extends SparkSessionExtensionsProvider {
|
||||
* override def apply(extensions: SparkSessionExtensions): Unit = {
|
||||
* extensions.injectResolutionRule { session =>
|
||||
* ...
|
||||
* }
|
||||
* extensions.injectFunction(...)
|
||||
* }
|
||||
* }
|
||||
* }}}
|
||||
*
|
||||
* Note that none of the injected builders should assume that the [[SparkSession]] is fully
|
||||
|
|
|
@ -0,0 +1,82 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import org.apache.spark.annotation.{DeveloperApi, Since, Unstable}
|
||||
|
||||
// scalastyle:off line.size.limit
/**
 * :: Unstable ::
 *
 * Base trait for implementations used by [[SparkSessionExtensions]].
 *
 * For example, now we have an external function named `Age` to register as an extension for SparkSession:
 *
 * {{{
 * package org.apache.spark.examples.extensions
 *
 * import org.apache.spark.sql.catalyst.expressions.{CurrentDate, Expression, RuntimeReplaceable, SubtractDates}
 *
 * case class Age(birthday: Expression, child: Expression) extends RuntimeReplaceable {
 *
 *   def this(birthday: Expression) = this(birthday, SubtractDates(CurrentDate(), birthday))
 *   override def exprsReplaced: Seq[Expression] = Seq(birthday)
 *   override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
 * }
 * }}}
 *
 * We need to create our extension which inherits [[SparkSessionExtensionsProvider]].
 * Example:
 *
 * {{{
 * package org.apache.spark.examples.extensions
 *
 * import org.apache.spark.sql.{SparkSessionExtensions, SparkSessionExtensionsProvider}
 * import org.apache.spark.sql.catalyst.FunctionIdentifier
 * import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
 *
 * class MyExtensions extends SparkSessionExtensionsProvider {
 *   override def apply(v1: SparkSessionExtensions): Unit = {
 *     v1.injectFunction(
 *       (new FunctionIdentifier("age"),
 *         new ExpressionInfo(classOf[Age].getName, "age"),
 *         (children: Seq[Expression]) => new Age(children.head)))
 *   }
 * }
 * }}}
 *
 * Then, we can inject `MyExtensions` in three ways:
 * <ul>
 *   <li>withExtensions of [[SparkSession.Builder]]</li>
 *   <li>Config - spark.sql.extensions</li>
 *   <li>[[java.util.ServiceLoader]] - Add to src/main/resources/META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider</li>
 * </ul>
 *
 * @see [[SparkSessionExtensions]]
 * @see [[SparkSession.Builder]]
 * @see [[java.util.ServiceLoader]]
 *
 * @since 3.2.0
 */
@DeveloperApi
@Unstable
@Since("3.2.0")
trait SparkSessionExtensionsProvider extends Function1[SparkSessionExtensions, Unit]
// scalastyle:on line.size.limit
|
|
@ -0,0 +1 @@
|
|||
org.apache.spark.sql.YourExtensions
|
|
@ -46,8 +46,8 @@ import org.apache.spark.unsafe.types.UTF8String
|
|||
* Test cases for the [[SparkSessionExtensions]].
|
||||
*/
|
||||
class SparkSessionExtensionSuite extends SparkFunSuite {
|
||||
type ExtensionsBuilder = SparkSessionExtensions => Unit
|
||||
private def create(builder: ExtensionsBuilder): Seq[ExtensionsBuilder] = Seq(builder)
|
||||
private def create(
|
||||
builder: SparkSessionExtensionsProvider): Seq[SparkSessionExtensionsProvider] = Seq(builder)
|
||||
|
||||
private def stop(spark: SparkSession): Unit = {
|
||||
spark.stop()
|
||||
|
@ -55,7 +55,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
|
|||
SparkSession.clearDefaultSession()
|
||||
}
|
||||
|
||||
private def withSession(builders: Seq[ExtensionsBuilder])(f: SparkSession => Unit): Unit = {
|
||||
private def withSession(
|
||||
builders: Seq[SparkSessionExtensionsProvider])(f: SparkSession => Unit): Unit = {
|
||||
val builder = SparkSession.builder().master("local[1]")
|
||||
builders.foreach(builder.withExtensions)
|
||||
val spark = builder.getOrCreate()
|
||||
|
@ -355,6 +356,20 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
|
|||
stop(session)
|
||||
}
|
||||
}
|
||||
|
||||
test("SPARK-35380: Loading extensions from ServiceLoader") {
  val builder = SparkSession.builder().master("local[1]")

  // `get_fake_app_name` must resolve both without any explicit config
  // (pure ServiceLoader discovery of YourExtensions) and with the
  // extension also listed in spark.sql.extensions.
  Seq(None, Some(classOf[YourExtensions].getName)).foreach { ext =>
    ext match {
      case Some(className) => builder.config(SPARK_SESSION_EXTENSIONS.key, className)
      case None => // rely purely on META-INF/services discovery
    }
    val session = builder.getOrCreate()
    try {
      assert(session.sql("select get_fake_app_name()").head().getString(0) === "Fake App Name")
    } finally {
      stop(session)
    }
  }
}
|
||||
}
|
||||
|
||||
case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
|
||||
|
@ -959,3 +974,16 @@ class MyExtensions2Duplicate extends (SparkSessionExtensions => Unit) {
|
|||
e.injectFunction(MyExtensions2Duplicate.myFunction)
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Test extension registered both via META-INF/services and via config in the
 * suite above. Injects `get_fake_app_name`, a zero-argument function that
 * always returns the literal "Fake App Name".
 */
class YourExtensions extends SparkSessionExtensionsProvider {
  // (identifier, info, builder) triple describing the injected function.
  val getAppName = {
    val info = new ExpressionInfo(
      "zzz.zzz.zzz",
      "",
      "get_fake_app_name")
    val funcBuilder = (_: Seq[Expression]) => Literal("Fake App Name")
    (FunctionIdentifier("get_fake_app_name"), info, funcBuilder)
  }

  override def apply(v1: SparkSessionExtensions): Unit = {
    v1.injectFunction(getAppName)
  }
}
|
||||
|
|
Loading…
Reference in a new issue