[SPARK-22370][SQL][PYSPARK] Config values should be captured in Driver.

## What changes were proposed in this pull request?

`ArrowEvalPythonExec` and `FlatMapGroupsInPandasExec` currently read `SQLConf` values inside the functions passed to `mapPartitions`/`mapPartitionsInternal`, which run on executors; these values should instead be captured on the driver and the captured values used on the executors, as sketched below.
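
To make the fix concrete, here is a minimal sketch of the intended pattern (the operator `ExampleEvalExec` is hypothetical, not code from this patch): the config value is read into a local `val` while `doExecute()` still runs on the driver, and only that captured value is referenced inside the closure that `mapPartitionsInternal` ships to executors.

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}

// Hypothetical operator, for illustration only.
case class ExampleEvalExec(child: SparkPlan) extends UnaryExecNode {
  override def output: Seq[Attribute] = child.output

  override protected def doExecute(): RDD[InternalRow] = {
    // Driver side: `conf` (i.e. SQLConf.get) reflects the active session here,
    // so the value is captured into a plain local before the closure is built.
    val sessionLocalTimeZone = conf.sessionLocalTimeZone

    child.execute().mapPartitionsInternal { iter =>
      // Executor side: use the captured value. Calling `conf.sessionLocalTimeZone`
      // here would evaluate SQLConf.get on the executor, which is not guaranteed
      // to see the session's setting.
      iter.map { row => row /* ... process rows using sessionLocalTimeZone ... */ }
    }
  }
}
```

This is the shape of the changes below: `ArrowEvalPythonExec` captures `batchSize` and `sessionLocalTimeZone` as fields of the operator, and `FlatMapGroupsInPandasExec` captures `sessionLocalTimeZone` as a local `val` in `doExecute` before calling `mapPartitionsInternal`.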

## How was this patch tested?

Added a new test; existing tests also cover this change.

Author: Takuya UESHIN <ueshin@databricks.com>

Closes #19587 from ueshin/issues/SPARK-22370.
Takuya UESHIN authored on 2017-10-28 18:33:09 +01:00; committed by Wenchen Fan
parent 683ffe0620
commit 4c5269f1aa
4 changed files with 32 additions and 3 deletions


@@ -3476,6 +3476,26 @@ class VectorizedUDFTests(ReusedPySparkTestCase):
             expected = spark_ts_t.fromInternal(spark_ts_t.toInternal(ts_tz))
             self.assertEquals(expected, ts)

+    def test_vectorized_udf_check_config(self):
+        from pyspark.sql.functions import pandas_udf, col
+        orig_value = self.spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", None)
+        self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 3)
+        try:
+            df = self.spark.range(10, numPartitions=1)
+
+            @pandas_udf(returnType=LongType())
+            def check_records_per_batch(x):
+                self.assertTrue(x.size <= 3)
+                return x
+
+            result = df.select(check_records_per_batch(col("id")))
+            self.assertEquals(df.collect(), result.collect())
+        finally:
+            if orig_value is None:
+                self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch")
+            else:
+                self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value)
+

 @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
 class GroupbyApplyTests(ReusedPySparkTestCase):


@@ -25,6 +25,12 @@ import org.apache.spark.sql.types.{DataType, StructType}
 abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {
   self: PlanType =>

+  /**
+   * The active config object within the current scope.
+   * Note that if you want to refer config values during execution, you have to capture them
+   * in Driver and use the captured values in Executors.
+   * See [[SQLConf.get]] for more information.
+   */
   def conf: SQLConf = SQLConf.get

   def output: Seq[Attribute]
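
The same caveat applies to code that calls `SQLConf.get` directly rather than going through a plan's `conf`. A minimal sketch of that guidance (the helper below is hypothetical, not part of this patch):

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.internal.SQLConf

// Hypothetical helper, for illustration only: the conf value is read on the
// driver, so only the resulting String is serialized into the task closure.
def tagWithSessionTimeZone(rdd: RDD[String]): RDD[String] = {
  val timeZone = SQLConf.get.sessionLocalTimeZone  // evaluated on the driver
  rdd.mapPartitions { iter =>
    iter.map(s => s"$s [$timeZone]")  // executors only see the captured value
  }
}
```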


@@ -61,6 +61,9 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int)
 case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
   extends EvalPythonExec(udfs, output, child) {

+  private val batchSize = conf.arrowMaxRecordsPerBatch
+  private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+
   protected override def evaluate(
       funcs: Seq[ChainedPythonFunctions],
       bufferSize: Int,
@@ -73,13 +76,12 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
     val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex
       .map { case (attr, i) => attr.withName(s"_$i") })

-    val batchSize = conf.arrowMaxRecordsPerBatch
     // DO NOT use iter.grouped(). See BatchIterator.
     val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter)

     val columnarBatchIter = new ArrowPythonRunner(
       funcs, bufferSize, reuseWorker,
-      PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema, conf.sessionLocalTimeZone)
+      PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema, sessionLocalTimeZone)
       .compute(batchIter, context.partitionId(), context)

     new Iterator[InternalRow] {


@@ -77,6 +77,7 @@ case class FlatMapGroupsInPandasExec(
     val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
     val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray)
     val schema = StructType(child.schema.drop(groupingAttributes.length))
+    val sessionLocalTimeZone = conf.sessionLocalTimeZone

     inputRDD.mapPartitionsInternal { iter =>
       val grouped = if (groupingAttributes.isEmpty) {
@@ -94,7 +95,7 @@
       val columnarBatchIter = new ArrowPythonRunner(
         chainedFunc, bufferSize, reuseWorker,
-        PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema, conf.sessionLocalTimeZone)
+        PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema, sessionLocalTimeZone)
         .compute(grouped, context.partitionId(), context)

       columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output))