spark-instrumented-optimizer/python/pyspark/tests/test_taskcontext.py
HyukjinKwon e11a24c1ba [SPARK-33371][PYTHON] Update setup.py and tests for Python 3.9
### What changes were proposed in this pull request?

This PR proposes to fix PySpark to officially support Python 3.9. The main code already works; we just need to note that Python 3.9 is supported.
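For illustration, declaring support for a new Python version in `setup.py` usually amounts to adding a trove classifier. The sketch below shows the general shape only; the package name, version, and exact classifier list are assumptions, not the actual Spark `setup.py` contents.

```
# Minimal, illustrative setup.py sketch; names and versions here are
# assumptions, not the actual diff from this PR.
from setuptools import setup

setup(
    name="example-package",
    version="0.0.1",
    python_requires=">=3.6",
    classifiers=[
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",  # newly declared support
    ],
)
```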

This PR also makes some minor fixes to the test code:
- `Thread.isAlive` was removed in Python 3.9, while `Thread.is_alive` is available in all supported Python versions (3.6+); see https://docs.python.org/3/whatsnew/3.9.html#removed. A sketch of the rename is shown after this list.
- Fixed `TaskContextTestsWithWorkerReuse.test_barrier_with_python_worker_reuse` and `TaskContextTests.test_barrier` to be less flaky. These tests become flakier in Python 3.9 for some reason.
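A minimal sketch of the `isAlive` → `is_alive` rename (illustrative only, not the exact patch in this PR):

```
import threading

t = threading.Thread(target=lambda: None)
t.start()
t.join()

# Deprecated camelCase alias, removed in Python 3.9:
#     t.isAlive()      # raises AttributeError on 3.9+
# Portable spelling, available on every Python version PySpark supports:
print(t.is_alive())    # False once the thread has finished
```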

NOTE that PyArrow does not support Python 3.9 yet.
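This is why the Arrow and pandas UDF test modules in the log below finish with all of their tests skipped. A minimal sketch of the kind of guard involved is shown here; the helper names are assumptions for illustration, not PySpark's actual utilities.

```
# Illustrative only: Arrow/pandas-dependent tests are skipped when pyarrow
# cannot be imported (e.g. on Python 3.9 before PyArrow supports it).
import unittest

try:
    import pyarrow  # noqa: F401
    have_pyarrow = True
except ImportError:
    have_pyarrow = False


@unittest.skipIf(not have_pyarrow, "pyarrow is not installed")
class ArrowDependentTests(unittest.TestCase):
    def test_roundtrip(self):
        self.assertTrue(have_pyarrow)


if __name__ == "__main__":
    unittest.main()
```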

### Why are the changes needed?

To officially support Python 3.9.

### Does this PR introduce _any_ user-facing change?

Yes, it officially supports Python 3.9.

### How was this patch tested?

Manually ran the tests:

```
$  ./run-tests --python-executable=python
Running PySpark tests. Output is in /.../spark/python/unit-tests.log
Will test against the following Python executables: ['python']
Will test the following Python modules: ['pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-resource', 'pyspark-sql', 'pyspark-streaming']
python python_implementation is CPython
python version is: Python 3.9.0
Starting test(python): pyspark.ml.tests.test_base
Starting test(python): pyspark.ml.tests.test_evaluation
Starting test(python): pyspark.ml.tests.test_algorithms
Starting test(python): pyspark.ml.tests.test_feature
Finished test(python): pyspark.ml.tests.test_base (12s)
Starting test(python): pyspark.ml.tests.test_image
Finished test(python): pyspark.ml.tests.test_evaluation (15s)
Starting test(python): pyspark.ml.tests.test_linalg
Finished test(python): pyspark.ml.tests.test_feature (25s)
Starting test(python): pyspark.ml.tests.test_param
Finished test(python): pyspark.ml.tests.test_image (17s)
Starting test(python): pyspark.ml.tests.test_persistence
Finished test(python): pyspark.ml.tests.test_param (17s)
Starting test(python): pyspark.ml.tests.test_pipeline
Finished test(python): pyspark.ml.tests.test_linalg (30s)
Starting test(python): pyspark.ml.tests.test_stat
Finished test(python): pyspark.ml.tests.test_pipeline (6s)
Starting test(python): pyspark.ml.tests.test_training_summary
Finished test(python): pyspark.ml.tests.test_stat (12s)
Starting test(python): pyspark.ml.tests.test_tuning
Finished test(python): pyspark.ml.tests.test_algorithms (68s)
Starting test(python): pyspark.ml.tests.test_wrapper
Finished test(python): pyspark.ml.tests.test_persistence (51s)
Starting test(python): pyspark.mllib.tests.test_algorithms
Finished test(python): pyspark.ml.tests.test_training_summary (33s)
Starting test(python): pyspark.mllib.tests.test_feature
Finished test(python): pyspark.ml.tests.test_wrapper (19s)
Starting test(python): pyspark.mllib.tests.test_linalg
Finished test(python): pyspark.mllib.tests.test_feature (26s)
Starting test(python): pyspark.mllib.tests.test_stat
Finished test(python): pyspark.mllib.tests.test_stat (22s)
Starting test(python): pyspark.mllib.tests.test_streaming_algorithms
Finished test(python): pyspark.mllib.tests.test_algorithms (53s)
Starting test(python): pyspark.mllib.tests.test_util
Finished test(python): pyspark.mllib.tests.test_linalg (54s)
Starting test(python): pyspark.sql.tests.test_arrow
Finished test(python): pyspark.sql.tests.test_arrow (0s) ... 61 tests were skipped
Starting test(python): pyspark.sql.tests.test_catalog
Finished test(python): pyspark.mllib.tests.test_util (11s)
Starting test(python): pyspark.sql.tests.test_column
Finished test(python): pyspark.sql.tests.test_catalog (16s)
Starting test(python): pyspark.sql.tests.test_conf
Finished test(python): pyspark.sql.tests.test_column (17s)
Starting test(python): pyspark.sql.tests.test_context
Finished test(python): pyspark.sql.tests.test_context (6s) ... 3 tests were skipped
Starting test(python): pyspark.sql.tests.test_dataframe
Finished test(python): pyspark.sql.tests.test_conf (11s)
Starting test(python): pyspark.sql.tests.test_datasources
Finished test(python): pyspark.sql.tests.test_datasources (19s)
Starting test(python): pyspark.sql.tests.test_functions
Finished test(python): pyspark.sql.tests.test_dataframe (35s) ... 3 tests were skipped
Starting test(python): pyspark.sql.tests.test_group
Finished test(python): pyspark.sql.tests.test_functions (32s)
Starting test(python): pyspark.sql.tests.test_pandas_cogrouped_map
Finished test(python): pyspark.sql.tests.test_pandas_cogrouped_map (1s) ... 15 tests were skipped
Starting test(python): pyspark.sql.tests.test_pandas_grouped_map
Finished test(python): pyspark.sql.tests.test_group (19s)
Starting test(python): pyspark.sql.tests.test_pandas_map
Finished test(python): pyspark.sql.tests.test_pandas_grouped_map (0s) ... 21 tests were skipped
Starting test(python): pyspark.sql.tests.test_pandas_udf
Finished test(python): pyspark.sql.tests.test_pandas_map (0s) ... 6 tests were skipped
Starting test(python): pyspark.sql.tests.test_pandas_udf_grouped_agg
Finished test(python): pyspark.sql.tests.test_pandas_udf (0s) ... 6 tests were skipped
Starting test(python): pyspark.sql.tests.test_pandas_udf_scalar
Finished test(python): pyspark.sql.tests.test_pandas_udf_grouped_agg (0s) ... 13 tests were skipped
Starting test(python): pyspark.sql.tests.test_pandas_udf_typehints
Finished test(python): pyspark.sql.tests.test_pandas_udf_scalar (0s) ... 50 tests were skipped
Starting test(python): pyspark.sql.tests.test_pandas_udf_window
Finished test(python): pyspark.sql.tests.test_pandas_udf_typehints (0s) ... 10 tests were skipped
Starting test(python): pyspark.sql.tests.test_readwriter
Finished test(python): pyspark.sql.tests.test_pandas_udf_window (0s) ... 14 tests were skipped
Starting test(python): pyspark.sql.tests.test_serde
Finished test(python): pyspark.sql.tests.test_serde (19s)
Starting test(python): pyspark.sql.tests.test_session
Finished test(python): pyspark.mllib.tests.test_streaming_algorithms (120s)
Starting test(python): pyspark.sql.tests.test_streaming
Finished test(python): pyspark.sql.tests.test_readwriter (25s)
Starting test(python): pyspark.sql.tests.test_types
Finished test(python): pyspark.ml.tests.test_tuning (208s)
Starting test(python): pyspark.sql.tests.test_udf
Finished test(python): pyspark.sql.tests.test_session (31s)
Starting test(python): pyspark.sql.tests.test_utils
Finished test(python): pyspark.sql.tests.test_streaming (35s)
Starting test(python): pyspark.streaming.tests.test_context
Finished test(python): pyspark.sql.tests.test_types (34s)
Starting test(python): pyspark.streaming.tests.test_dstream
Finished test(python): pyspark.sql.tests.test_utils (14s)
Starting test(python): pyspark.streaming.tests.test_kinesis
Finished test(python): pyspark.streaming.tests.test_kinesis (0s) ... 2 tests were skipped
Starting test(python): pyspark.streaming.tests.test_listener
Finished test(python): pyspark.streaming.tests.test_listener (11s)
Starting test(python): pyspark.tests.test_appsubmit
Finished test(python): pyspark.sql.tests.test_udf (39s)
Starting test(python): pyspark.tests.test_broadcast
Finished test(python): pyspark.streaming.tests.test_context (23s)
Starting test(python): pyspark.tests.test_conf
Finished test(python): pyspark.tests.test_conf (15s)
Starting test(python): pyspark.tests.test_context
Finished test(python): pyspark.tests.test_broadcast (33s)
Starting test(python): pyspark.tests.test_daemon
Finished test(python): pyspark.tests.test_daemon (5s)
Starting test(python): pyspark.tests.test_install_spark
Finished test(python): pyspark.tests.test_context (44s)
Starting test(python): pyspark.tests.test_join
Finished test(python): pyspark.tests.test_appsubmit (68s)
Starting test(python): pyspark.tests.test_profiler
Finished test(python): pyspark.tests.test_join (7s)
Starting test(python): pyspark.tests.test_rdd
Finished test(python): pyspark.tests.test_profiler (9s)
Starting test(python): pyspark.tests.test_rddbarrier
Finished test(python): pyspark.tests.test_rddbarrier (7s)
Starting test(python): pyspark.tests.test_readwrite
Finished test(python): pyspark.streaming.tests.test_dstream (107s)
Starting test(python): pyspark.tests.test_serializers
Finished test(python): pyspark.tests.test_serializers (8s)
Starting test(python): pyspark.tests.test_shuffle
Finished test(python): pyspark.tests.test_readwrite (14s)
Starting test(python): pyspark.tests.test_taskcontext
Finished test(python): pyspark.tests.test_install_spark (65s)
Starting test(python): pyspark.tests.test_util
Finished test(python): pyspark.tests.test_shuffle (8s)
Starting test(python): pyspark.tests.test_worker
Finished test(python): pyspark.tests.test_util (5s)
Starting test(python): pyspark.accumulators
Finished test(python): pyspark.accumulators (5s)
Starting test(python): pyspark.broadcast
Finished test(python): pyspark.broadcast (6s)
Starting test(python): pyspark.conf
Finished test(python): pyspark.tests.test_worker (14s)
Starting test(python): pyspark.context
Finished test(python): pyspark.conf (4s)
Starting test(python): pyspark.ml.classification
Finished test(python): pyspark.tests.test_rdd (60s)
Starting test(python): pyspark.ml.clustering
Finished test(python): pyspark.context (21s)
Starting test(python): pyspark.ml.evaluation
Finished test(python): pyspark.tests.test_taskcontext (69s)
Starting test(python): pyspark.ml.feature
Finished test(python): pyspark.ml.evaluation (26s)
Starting test(python): pyspark.ml.fpm
Finished test(python): pyspark.ml.clustering (45s)
Starting test(python): pyspark.ml.functions
Finished test(python): pyspark.ml.fpm (24s)
Starting test(python): pyspark.ml.image
Finished test(python): pyspark.ml.functions (17s)
Starting test(python): pyspark.ml.linalg.__init__
Finished test(python): pyspark.ml.linalg.__init__ (0s)
Starting test(python): pyspark.ml.recommendation
Finished test(python): pyspark.ml.classification (74s)
Starting test(python): pyspark.ml.regression
Finished test(python): pyspark.ml.image (8s)
Starting test(python): pyspark.ml.stat
Finished test(python): pyspark.ml.stat (29s)
Starting test(python): pyspark.ml.tuning
Finished test(python): pyspark.ml.regression (53s)
Starting test(python): pyspark.mllib.classification
Finished test(python): pyspark.ml.tuning (35s)
Starting test(python): pyspark.mllib.clustering
Finished test(python): pyspark.ml.feature (103s)
Starting test(python): pyspark.mllib.evaluation
Finished test(python): pyspark.mllib.classification (33s)
Starting test(python): pyspark.mllib.feature
Finished test(python): pyspark.mllib.evaluation (21s)
Starting test(python): pyspark.mllib.fpm
Finished test(python): pyspark.ml.recommendation (103s)
Starting test(python): pyspark.mllib.linalg.__init__
Finished test(python): pyspark.mllib.linalg.__init__ (1s)
Starting test(python): pyspark.mllib.linalg.distributed
Finished test(python): pyspark.mllib.feature (26s)
Starting test(python): pyspark.mllib.random
Finished test(python): pyspark.mllib.fpm (23s)
Starting test(python): pyspark.mllib.recommendation
Finished test(python): pyspark.mllib.clustering (50s)
Starting test(python): pyspark.mllib.regression
Finished test(python): pyspark.mllib.random (13s)
Starting test(python): pyspark.mllib.stat.KernelDensity
Finished test(python): pyspark.mllib.stat.KernelDensity (1s)
Starting test(python): pyspark.mllib.stat._statistics
Finished test(python): pyspark.mllib.linalg.distributed (42s)
Starting test(python): pyspark.mllib.tree
Finished test(python): pyspark.mllib.stat._statistics (19s)
Starting test(python): pyspark.mllib.util
Finished test(python): pyspark.mllib.regression (33s)
Starting test(python): pyspark.profiler
Finished test(python): pyspark.mllib.recommendation (36s)
Starting test(python): pyspark.rdd
Finished test(python): pyspark.profiler (9s)
Starting test(python): pyspark.resource.tests.test_resources
Finished test(python): pyspark.mllib.tree (19s)
Starting test(python): pyspark.serializers
Finished test(python): pyspark.mllib.util (21s)
Starting test(python): pyspark.shuffle
Finished test(python): pyspark.resource.tests.test_resources (9s)
Starting test(python): pyspark.sql.avro.functions
Finished test(python): pyspark.shuffle (1s)
Starting test(python): pyspark.sql.catalog
Finished test(python): pyspark.rdd (22s)
Starting test(python): pyspark.sql.column
Finished test(python): pyspark.serializers (12s)
Starting test(python): pyspark.sql.conf
Finished test(python): pyspark.sql.conf (6s)
Starting test(python): pyspark.sql.context
Finished test(python): pyspark.sql.catalog (14s)
Starting test(python): pyspark.sql.dataframe
Finished test(python): pyspark.sql.avro.functions (15s)
Starting test(python): pyspark.sql.functions
Finished test(python): pyspark.sql.column (24s)
Starting test(python): pyspark.sql.group
Finished test(python): pyspark.sql.context (20s)
Starting test(python): pyspark.sql.pandas.conversion
Finished test(python): pyspark.sql.pandas.conversion (13s)
Starting test(python): pyspark.sql.pandas.group_ops
Finished test(python): pyspark.sql.group (36s)
Starting test(python): pyspark.sql.pandas.map_ops
Finished test(python): pyspark.sql.pandas.group_ops (21s)
Starting test(python): pyspark.sql.pandas.serializers
Finished test(python): pyspark.sql.pandas.serializers (0s)
Starting test(python): pyspark.sql.pandas.typehints
Finished test(python): pyspark.sql.pandas.typehints (0s)
Starting test(python): pyspark.sql.pandas.types
Finished test(python): pyspark.sql.pandas.types (0s)
Starting test(python): pyspark.sql.pandas.utils
Finished test(python): pyspark.sql.pandas.utils (0s)
Starting test(python): pyspark.sql.readwriter
Finished test(python): pyspark.sql.dataframe (56s)
Starting test(python): pyspark.sql.session
Finished test(python): pyspark.sql.functions (57s)
Starting test(python): pyspark.sql.streaming
Finished test(python): pyspark.sql.pandas.map_ops (12s)
Starting test(python): pyspark.sql.types
Finished test(python): pyspark.sql.types (10s)
Starting test(python): pyspark.sql.udf
Finished test(python): pyspark.sql.streaming (16s)
Starting test(python): pyspark.sql.window
Finished test(python): pyspark.sql.session (19s)
Starting test(python): pyspark.streaming.util
Finished test(python): pyspark.streaming.util (0s)
Starting test(python): pyspark.util
Finished test(python): pyspark.util (0s)
Finished test(python): pyspark.sql.readwriter (24s)
Finished test(python): pyspark.sql.udf (13s)
Finished test(python): pyspark.sql.window (14s)
Tests passed in 780 seconds

```

Closes #30277 from HyukjinKwon/SPARK-33371.

Authored-by: HyukjinKwon <gurwls223@apache.org>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
2020-11-06 15:05:37 -08:00

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import random
import stat
import sys
import tempfile
import time
import unittest
from pyspark import SparkConf, SparkContext, TaskContext, BarrierTaskContext
from pyspark.testing.utils import PySparkTestCase, SPARK_HOME


class TaskContextTests(PySparkTestCase):

    def setUp(self):
        self._old_sys_path = list(sys.path)
        class_name = self.__class__.__name__
        # Allow retries even though they are normally disabled in local mode
        self.sc = SparkContext('local[4, 2]', class_name)

    def test_stage_id(self):
        """Test the stage ids are available and incrementing as expected."""
        rdd = self.sc.parallelize(range(10))
        stage1 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0]
        stage2 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0]
        # Test using the constructor directly rather than the get()
        stage3 = rdd.map(lambda x: TaskContext().stageId()).take(1)[0]
        self.assertEqual(stage1 + 1, stage2)
        self.assertEqual(stage1 + 2, stage3)
        self.assertEqual(stage2 + 1, stage3)

    def test_resources(self):
        """Test the resources are empty by default."""
        rdd = self.sc.parallelize(range(10))
        resources1 = rdd.map(lambda x: TaskContext.get().resources()).take(1)[0]
        # Test using the constructor directly rather than the get()
        resources2 = rdd.map(lambda x: TaskContext().resources()).take(1)[0]
        self.assertEqual(len(resources1), 0)
        self.assertEqual(len(resources2), 0)

    def test_partition_id(self):
        """Test the partition id."""
        rdd1 = self.sc.parallelize(range(10), 1)
        rdd2 = self.sc.parallelize(range(10), 2)
        pids1 = rdd1.map(lambda x: TaskContext.get().partitionId()).collect()
        pids2 = rdd2.map(lambda x: TaskContext.get().partitionId()).collect()
        self.assertEqual(0, pids1[0])
        self.assertEqual(0, pids1[9])
        self.assertEqual(0, pids2[0])
        self.assertEqual(1, pids2[9])

    def test_attempt_number(self):
        """Verify the attempt numbers are correctly reported."""
        rdd = self.sc.parallelize(range(10))
        # Verify a simple job with no failures
        attempt_numbers = rdd.map(lambda x: TaskContext.get().attemptNumber()).collect()
        map(lambda attempt: self.assertEqual(0, attempt), attempt_numbers)

        def fail_on_first(x):
            """Fail on the first attempt so we get a positive attempt number"""
            tc = TaskContext.get()
            attempt_number = tc.attemptNumber()
            partition_id = tc.partitionId()
            attempt_id = tc.taskAttemptId()
            if attempt_number == 0 and partition_id == 0:
                raise Exception("Failing on first attempt")
            else:
                return [x, partition_id, attempt_number, attempt_id]
        result = rdd.map(fail_on_first).collect()
        # The first partition should be re-submitted, while other partitions stay at attempt 0
        self.assertEqual([0, 0, 1], result[0][0:3])
        self.assertEqual([9, 3, 0], result[9][0:3])
        first_partition = filter(lambda x: x[1] == 0, result)
        map(lambda x: self.assertEqual(1, x[2]), first_partition)
        other_partitions = filter(lambda x: x[1] != 0, result)
        map(lambda x: self.assertEqual(0, x[2]), other_partitions)
        # The task attempt id should be different
        self.assertTrue(result[0][3] != result[9][3])

    def test_tc_on_driver(self):
        """Verify that getting the TaskContext on the driver returns None."""
        tc = TaskContext.get()
        self.assertTrue(tc is None)

    def test_get_local_property(self):
        """Verify that local properties set on the driver are available in TaskContext."""
        key = "testkey"
        value = "testvalue"
        self.sc.setLocalProperty(key, value)
        try:
            rdd = self.sc.parallelize(range(1), 1)
            prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0]
            self.assertEqual(prop1, value)
            prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0]
            self.assertTrue(prop2 is None)
        finally:
            self.sc.setLocalProperty(key, None)

    def test_barrier(self):
        """
        Verify that BarrierTaskContext.barrier() performs global sync among all barrier tasks
        within a stage.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            yield sum(iterator)

        def context_barrier(x):
            tc = BarrierTaskContext.get()
            time.sleep(random.randint(1, 5) * 2)
            tc.barrier()
            return time.time()

        times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
        self.assertTrue(max(times) - min(times) < 2)

    def test_all_gather(self):
        """
        Verify that BarrierTaskContext.allGather() performs global sync among all barrier tasks
        within a stage and passes messages properly.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            yield sum(iterator)

        def context_barrier(x):
            tc = BarrierTaskContext.get()
            time.sleep(random.randint(1, 10))
            out = tc.allGather(str(tc.partitionId()))
            pids = [int(e) for e in out]
            return pids

        pids = rdd.barrier().mapPartitions(f).map(context_barrier).collect()[0]
        self.assertEqual(pids, [0, 1, 2, 3])

    def test_barrier_infos(self):
        """
        Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the
        barrier stage.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            yield sum(iterator)

        taskInfos = rdd.barrier().mapPartitions(f).map(lambda x: BarrierTaskContext.get()
                                                       .getTaskInfos()).collect()
        self.assertTrue(len(taskInfos) == 4)
        self.assertTrue(len(taskInfos[0]) == 4)

    def test_context_get(self):
        """
        Verify that TaskContext.get() works both inside and outside a barrier stage.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            taskContext = TaskContext.get()
            if isinstance(taskContext, BarrierTaskContext):
                yield taskContext.partitionId() + 1
            elif isinstance(taskContext, TaskContext):
                yield taskContext.partitionId() + 2
            else:
                yield -1

        # for normal stage
        result1 = rdd.mapPartitions(f).collect()
        self.assertTrue(result1 == [2, 3, 4, 5])
        # for barrier stage
        result2 = rdd.barrier().mapPartitions(f).collect()
        self.assertTrue(result2 == [1, 2, 3, 4])

    def test_barrier_context_get(self):
        """
        Verify that BarrierTaskContext.get() only works in a barrier stage.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            try:
                taskContext = BarrierTaskContext.get()
            except Exception:
                yield -1
            else:
                yield taskContext.partitionId()

        # for normal stage
        result1 = rdd.mapPartitions(f).collect()
        self.assertTrue(result1 == [-1, -1, -1, -1])
        # for barrier stage
        result2 = rdd.barrier().mapPartitions(f).collect()
        self.assertTrue(result2 == [0, 1, 2, 3])


class TaskContextTestsWithWorkerReuse(unittest.TestCase):

    def setUp(self):
        class_name = self.__class__.__name__
        conf = SparkConf().set("spark.python.worker.reuse", "true")
        self.sc = SparkContext('local[2]', class_name, conf=conf)

    def test_barrier_with_python_worker_reuse(self):
        """
        Regression test for SPARK-25921: verify that BarrierTaskContext.barrier() works
        with a reused python worker.
        """
        # start a normal job first to start all workers and get all worker pids
        worker_pids = self.sc.parallelize(range(2), 2).map(lambda x: os.getpid()).collect()
        # the workers will be reused in this barrier job
        rdd = self.sc.parallelize(range(10), 2)

        def f(iterator):
            yield sum(iterator)

        def context_barrier(x):
            tc = BarrierTaskContext.get()
            time.sleep(random.randint(1, 5) * 2)
            tc.barrier()
            return (time.time(), os.getpid())

        result = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
        times = list(map(lambda x: x[0], result))
        pids = list(map(lambda x: x[1], result))
        # check both barrier and worker reuse effect
        self.assertTrue(max(times) - min(times) < 2)
        for pid in pids:
            self.assertTrue(pid in worker_pids)

    def test_task_context_correct_with_python_worker_reuse(self):
        """Verify the task context is correct when the python worker is reused."""
        # start a normal job first to start all workers and get all worker pids
        worker_pids = self.sc.parallelize(range(2), 2).map(lambda x: os.getpid()).collect()
        # the workers will be reused in this barrier job
        rdd = self.sc.parallelize(range(10), 2)

        def context(iterator):
            tp = TaskContext.get().partitionId()
            try:
                bp = BarrierTaskContext.get().partitionId()
            except Exception:
                bp = -1
            yield (tp, bp, os.getpid())

        # normal stage after normal stage
        normal_result = rdd.mapPartitions(context).collect()
        tps, bps, pids = zip(*normal_result)
        print(tps)
        self.assertTrue(tps == (0, 1))
        self.assertTrue(bps == (-1, -1))
        for pid in pids:
            self.assertTrue(pid in worker_pids)
        # barrier stage after normal stage
        barrier_result = rdd.barrier().mapPartitions(context).collect()
        tps, bps, pids = zip(*barrier_result)
        self.assertTrue(tps == (0, 1))
        self.assertTrue(bps == (0, 1))
        for pid in pids:
            self.assertTrue(pid in worker_pids)
        # normal stage after barrier stage
        normal_result2 = rdd.mapPartitions(context).collect()
        tps, bps, pids = zip(*normal_result2)
        self.assertTrue(tps == (0, 1))
        self.assertTrue(bps == (-1, -1))
        for pid in pids:
            self.assertTrue(pid in worker_pids)

    def tearDown(self):
        self.sc.stop()


class TaskContextTestsWithResources(unittest.TestCase):

    def setUp(self):
        class_name = self.__class__.__name__
        self.tempFile = tempfile.NamedTemporaryFile(delete=False)
        self.tempFile.write(b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\"]}')
        self.tempFile.close()
        # create temporary directory for Worker resources coordination
        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
        os.unlink(self.tempdir.name)
        os.chmod(self.tempFile.name, stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP |
                 stat.S_IROTH | stat.S_IXOTH)
        conf = SparkConf().set("spark.test.home", SPARK_HOME)
        conf = conf.set("spark.worker.resource.gpu.discoveryScript", self.tempFile.name)
        conf = conf.set("spark.worker.resource.gpu.amount", 1)
        conf = conf.set("spark.task.resource.gpu.amount", "1")
        conf = conf.set("spark.executor.resource.gpu.amount", "1")
        self.sc = SparkContext('local-cluster[2,1,1024]', class_name, conf=conf)

    def test_resources(self):
        """Test the resources are available."""
        rdd = self.sc.parallelize(range(10))
        resources = rdd.map(lambda x: TaskContext.get().resources()).take(1)[0]
        self.assertEqual(len(resources), 1)
        self.assertTrue('gpu' in resources)
        self.assertEqual(resources['gpu'].name, 'gpu')
        self.assertEqual(resources['gpu'].addresses, ['0'])

    def tearDown(self):
        os.unlink(self.tempFile.name)
        self.sc.stop()


if __name__ == "__main__":
    import unittest
    from pyspark.tests.test_taskcontext import *  # noqa: F401

    try:
        import xmlrunner  # type: ignore[import]
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)