[SPARK-22370][SQL][PYSPARK] Config values should be captured in Driver.
## What changes were proposed in this pull request? `ArrowEvalPythonExec` and `FlatMapGroupsInPandasExec` are referring to config values of `SQLConf` inside the functions passed to `mapPartitions`/`mapPartitionsInternal`, which run on Executors, but we should capture them in the Driver. ## How was this patch tested? Added a test and ran the existing tests. Author: Takuya UESHIN <ueshin@databricks.com> Closes #19587 from ueshin/issues/SPARK-22370.
This commit is contained in:
parent
683ffe0620
commit
4c5269f1aa
|
@ -3476,6 +3476,26 @@ class VectorizedUDFTests(ReusedPySparkTestCase):
|
|||
expected = spark_ts_t.fromInternal(spark_ts_t.toInternal(ts_tz))
|
||||
self.assertEquals(expected, ts)
|
||||
|
||||
def test_vectorized_udf_check_config(self):
|
||||
from pyspark.sql.functions import pandas_udf, col
|
||||
orig_value = self.spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", None)
|
||||
self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 3)
|
||||
try:
|
||||
df = self.spark.range(10, numPartitions=1)
|
||||
|
||||
@pandas_udf(returnType=LongType())
|
||||
def check_records_per_batch(x):
|
||||
self.assertTrue(x.size <= 3)
|
||||
return x
|
||||
|
||||
result = df.select(check_records_per_batch(col("id")))
|
||||
self.assertEquals(df.collect(), result.collect())
|
||||
finally:
|
||||
if orig_value is None:
|
||||
self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch")
|
||||
else:
|
||||
self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value)
|
||||
|
||||
|
||||
@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
|
||||
class GroupbyApplyTests(ReusedPySparkTestCase):
|
||||
|
|
|
@ -25,6 +25,12 @@ import org.apache.spark.sql.types.{DataType, StructType}
|
|||
abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {
|
||||
self: PlanType =>
|
||||
|
||||
/**
|
||||
* The active config object within the current scope.
|
||||
* Note that if you want to refer to config values during execution, you have to capture them
|
||||
* in Driver and use the captured values in Executors.
|
||||
* See [[SQLConf.get]] for more information.
|
||||
*/
|
||||
def conf: SQLConf = SQLConf.get
|
||||
|
||||
def output: Seq[Attribute]
|
||||
|
|
|
@ -61,6 +61,9 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int)
|
|||
case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
|
||||
extends EvalPythonExec(udfs, output, child) {
|
||||
|
||||
private val batchSize = conf.arrowMaxRecordsPerBatch
|
||||
private val sessionLocalTimeZone = conf.sessionLocalTimeZone
|
||||
|
||||
protected override def evaluate(
|
||||
funcs: Seq[ChainedPythonFunctions],
|
||||
bufferSize: Int,
|
||||
|
@ -73,13 +76,12 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
|
|||
val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex
|
||||
.map { case (attr, i) => attr.withName(s"_$i") })
|
||||
|
||||
val batchSize = conf.arrowMaxRecordsPerBatch
|
||||
// DO NOT use iter.grouped(). See BatchIterator.
|
||||
val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter)
|
||||
|
||||
val columnarBatchIter = new ArrowPythonRunner(
|
||||
funcs, bufferSize, reuseWorker,
|
||||
PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema, conf.sessionLocalTimeZone)
|
||||
PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema, sessionLocalTimeZone)
|
||||
.compute(batchIter, context.partitionId(), context)
|
||||
|
||||
new Iterator[InternalRow] {
|
||||
|
|
|
@ -77,6 +77,7 @@ case class FlatMapGroupsInPandasExec(
|
|||
val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
|
||||
val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray)
|
||||
val schema = StructType(child.schema.drop(groupingAttributes.length))
|
||||
val sessionLocalTimeZone = conf.sessionLocalTimeZone
|
||||
|
||||
inputRDD.mapPartitionsInternal { iter =>
|
||||
val grouped = if (groupingAttributes.isEmpty) {
|
||||
|
@ -94,7 +95,7 @@ case class FlatMapGroupsInPandasExec(
|
|||
|
||||
val columnarBatchIter = new ArrowPythonRunner(
|
||||
chainedFunc, bufferSize, reuseWorker,
|
||||
PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema, conf.sessionLocalTimeZone)
|
||||
PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema, sessionLocalTimeZone)
|
||||
.compute(grouped, context.partitionId(), context)
|
||||
|
||||
columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output))
|
||||
|
|
Loading…
Reference in a new issue