e11a24c1ba
### What changes were proposed in this pull request? This PR updates PySpark to officially support Python 3.9. The main code already works; we only need to declare that Python 3.9 is supported. This PR also makes some minor fixes to the test code: - `Thread.isAlive` was removed in Python 3.9, whereas `Thread.is_alive` exists in Python 3.6+ (see https://docs.python.org/3/whatsnew/3.9.html#removed). - Made `TaskContextTestsWithWorkerReuse.test_barrier_with_python_worker_reuse` and `TaskContextTests.test_barrier` less flaky; they became noticeably flakier under Python 3.9 for reasons that are not yet clear. NOTE that PyArrow does not support Python 3.9 yet. ### Why are the changes needed? To officially support Python 3.9. ### Does this PR introduce _any_ user-facing change? Yes, PySpark now officially supports Python 3.9. ### How was this patch tested? Manually ran the tests: ``` $ ./run-tests --python-executable=python Running PySpark tests. Output is in /.../spark/python/unit-tests.log Will test against the following Python executables: ['python'] Will test the following Python modules: ['pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-resource', 'pyspark-sql', 'pyspark-streaming'] python python_implementation is CPython python version is: Python 3.9.0 Starting test(python): pyspark.ml.tests.test_base Starting test(python): pyspark.ml.tests.test_evaluation Starting test(python): pyspark.ml.tests.test_algorithms Starting test(python): pyspark.ml.tests.test_feature Finished test(python): pyspark.ml.tests.test_base (12s) Starting test(python): pyspark.ml.tests.test_image Finished test(python): pyspark.ml.tests.test_evaluation (15s) Starting test(python): pyspark.ml.tests.test_linalg Finished test(python): pyspark.ml.tests.test_feature (25s) Starting test(python): pyspark.ml.tests.test_param Finished test(python): pyspark.ml.tests.test_image (17s) Starting test(python): pyspark.ml.tests.test_persistence Finished test(python): pyspark.ml.tests.test_param (17s) Starting test(python):
pyspark.ml.tests.test_pipeline Finished test(python): pyspark.ml.tests.test_linalg (30s) Starting test(python): pyspark.ml.tests.test_stat Finished test(python): pyspark.ml.tests.test_pipeline (6s) Starting test(python): pyspark.ml.tests.test_training_summary Finished test(python): pyspark.ml.tests.test_stat (12s) Starting test(python): pyspark.ml.tests.test_tuning Finished test(python): pyspark.ml.tests.test_algorithms (68s) Starting test(python): pyspark.ml.tests.test_wrapper Finished test(python): pyspark.ml.tests.test_persistence (51s) Starting test(python): pyspark.mllib.tests.test_algorithms Finished test(python): pyspark.ml.tests.test_training_summary (33s) Starting test(python): pyspark.mllib.tests.test_feature Finished test(python): pyspark.ml.tests.test_wrapper (19s) Starting test(python): pyspark.mllib.tests.test_linalg Finished test(python): pyspark.mllib.tests.test_feature (26s) Starting test(python): pyspark.mllib.tests.test_stat Finished test(python): pyspark.mllib.tests.test_stat (22s) Starting test(python): pyspark.mllib.tests.test_streaming_algorithms Finished test(python): pyspark.mllib.tests.test_algorithms (53s) Starting test(python): pyspark.mllib.tests.test_util Finished test(python): pyspark.mllib.tests.test_linalg (54s) Starting test(python): pyspark.sql.tests.test_arrow Finished test(python): pyspark.sql.tests.test_arrow (0s) ... 61 tests were skipped Starting test(python): pyspark.sql.tests.test_catalog Finished test(python): pyspark.mllib.tests.test_util (11s) Starting test(python): pyspark.sql.tests.test_column Finished test(python): pyspark.sql.tests.test_catalog (16s) Starting test(python): pyspark.sql.tests.test_conf Finished test(python): pyspark.sql.tests.test_column (17s) Starting test(python): pyspark.sql.tests.test_context Finished test(python): pyspark.sql.tests.test_context (6s) ... 
3 tests were skipped Starting test(python): pyspark.sql.tests.test_dataframe Finished test(python): pyspark.sql.tests.test_conf (11s) Starting test(python): pyspark.sql.tests.test_datasources Finished test(python): pyspark.sql.tests.test_datasources (19s) Starting test(python): pyspark.sql.tests.test_functions Finished test(python): pyspark.sql.tests.test_dataframe (35s) ... 3 tests were skipped Starting test(python): pyspark.sql.tests.test_group Finished test(python): pyspark.sql.tests.test_functions (32s) Starting test(python): pyspark.sql.tests.test_pandas_cogrouped_map Finished test(python): pyspark.sql.tests.test_pandas_cogrouped_map (1s) ... 15 tests were skipped Starting test(python): pyspark.sql.tests.test_pandas_grouped_map Finished test(python): pyspark.sql.tests.test_group (19s) Starting test(python): pyspark.sql.tests.test_pandas_map Finished test(python): pyspark.sql.tests.test_pandas_grouped_map (0s) ... 21 tests were skipped Starting test(python): pyspark.sql.tests.test_pandas_udf Finished test(python): pyspark.sql.tests.test_pandas_map (0s) ... 6 tests were skipped Starting test(python): pyspark.sql.tests.test_pandas_udf_grouped_agg Finished test(python): pyspark.sql.tests.test_pandas_udf (0s) ... 6 tests were skipped Starting test(python): pyspark.sql.tests.test_pandas_udf_scalar Finished test(python): pyspark.sql.tests.test_pandas_udf_grouped_agg (0s) ... 13 tests were skipped Starting test(python): pyspark.sql.tests.test_pandas_udf_typehints Finished test(python): pyspark.sql.tests.test_pandas_udf_scalar (0s) ... 50 tests were skipped Starting test(python): pyspark.sql.tests.test_pandas_udf_window Finished test(python): pyspark.sql.tests.test_pandas_udf_typehints (0s) ... 10 tests were skipped Starting test(python): pyspark.sql.tests.test_readwriter Finished test(python): pyspark.sql.tests.test_pandas_udf_window (0s) ... 
14 tests were skipped Starting test(python): pyspark.sql.tests.test_serde Finished test(python): pyspark.sql.tests.test_serde (19s) Starting test(python): pyspark.sql.tests.test_session Finished test(python): pyspark.mllib.tests.test_streaming_algorithms (120s) Starting test(python): pyspark.sql.tests.test_streaming Finished test(python): pyspark.sql.tests.test_readwriter (25s) Starting test(python): pyspark.sql.tests.test_types Finished test(python): pyspark.ml.tests.test_tuning (208s) Starting test(python): pyspark.sql.tests.test_udf Finished test(python): pyspark.sql.tests.test_session (31s) Starting test(python): pyspark.sql.tests.test_utils Finished test(python): pyspark.sql.tests.test_streaming (35s) Starting test(python): pyspark.streaming.tests.test_context Finished test(python): pyspark.sql.tests.test_types (34s) Starting test(python): pyspark.streaming.tests.test_dstream Finished test(python): pyspark.sql.tests.test_utils (14s) Starting test(python): pyspark.streaming.tests.test_kinesis Finished test(python): pyspark.streaming.tests.test_kinesis (0s) ... 
2 tests were skipped Starting test(python): pyspark.streaming.tests.test_listener Finished test(python): pyspark.streaming.tests.test_listener (11s) Starting test(python): pyspark.tests.test_appsubmit Finished test(python): pyspark.sql.tests.test_udf (39s) Starting test(python): pyspark.tests.test_broadcast Finished test(python): pyspark.streaming.tests.test_context (23s) Starting test(python): pyspark.tests.test_conf Finished test(python): pyspark.tests.test_conf (15s) Starting test(python): pyspark.tests.test_context Finished test(python): pyspark.tests.test_broadcast (33s) Starting test(python): pyspark.tests.test_daemon Finished test(python): pyspark.tests.test_daemon (5s) Starting test(python): pyspark.tests.test_install_spark Finished test(python): pyspark.tests.test_context (44s) Starting test(python): pyspark.tests.test_join Finished test(python): pyspark.tests.test_appsubmit (68s) Starting test(python): pyspark.tests.test_profiler Finished test(python): pyspark.tests.test_join (7s) Starting test(python): pyspark.tests.test_rdd Finished test(python): pyspark.tests.test_profiler (9s) Starting test(python): pyspark.tests.test_rddbarrier Finished test(python): pyspark.tests.test_rddbarrier (7s) Starting test(python): pyspark.tests.test_readwrite Finished test(python): pyspark.streaming.tests.test_dstream (107s) Starting test(python): pyspark.tests.test_serializers Finished test(python): pyspark.tests.test_serializers (8s) Starting test(python): pyspark.tests.test_shuffle Finished test(python): pyspark.tests.test_readwrite (14s) Starting test(python): pyspark.tests.test_taskcontext Finished test(python): pyspark.tests.test_install_spark (65s) Starting test(python): pyspark.tests.test_util Finished test(python): pyspark.tests.test_shuffle (8s) Starting test(python): pyspark.tests.test_worker Finished test(python): pyspark.tests.test_util (5s) Starting test(python): pyspark.accumulators Finished test(python): pyspark.accumulators (5s) Starting test(python): 
pyspark.broadcast Finished test(python): pyspark.broadcast (6s) Starting test(python): pyspark.conf Finished test(python): pyspark.tests.test_worker (14s) Starting test(python): pyspark.context Finished test(python): pyspark.conf (4s) Starting test(python): pyspark.ml.classification Finished test(python): pyspark.tests.test_rdd (60s) Starting test(python): pyspark.ml.clustering Finished test(python): pyspark.context (21s) Starting test(python): pyspark.ml.evaluation Finished test(python): pyspark.tests.test_taskcontext (69s) Starting test(python): pyspark.ml.feature Finished test(python): pyspark.ml.evaluation (26s) Starting test(python): pyspark.ml.fpm Finished test(python): pyspark.ml.clustering (45s) Starting test(python): pyspark.ml.functions Finished test(python): pyspark.ml.fpm (24s) Starting test(python): pyspark.ml.image Finished test(python): pyspark.ml.functions (17s) Starting test(python): pyspark.ml.linalg.__init__ Finished test(python): pyspark.ml.linalg.__init__ (0s) Starting test(python): pyspark.ml.recommendation Finished test(python): pyspark.ml.classification (74s) Starting test(python): pyspark.ml.regression Finished test(python): pyspark.ml.image (8s) Starting test(python): pyspark.ml.stat Finished test(python): pyspark.ml.stat (29s) Starting test(python): pyspark.ml.tuning Finished test(python): pyspark.ml.regression (53s) Starting test(python): pyspark.mllib.classification Finished test(python): pyspark.ml.tuning (35s) Starting test(python): pyspark.mllib.clustering Finished test(python): pyspark.ml.feature (103s) Starting test(python): pyspark.mllib.evaluation Finished test(python): pyspark.mllib.classification (33s) Starting test(python): pyspark.mllib.feature Finished test(python): pyspark.mllib.evaluation (21s) Starting test(python): pyspark.mllib.fpm Finished test(python): pyspark.ml.recommendation (103s) Starting test(python): pyspark.mllib.linalg.__init__ Finished test(python): pyspark.mllib.linalg.__init__ (1s) Starting test(python): 
pyspark.mllib.linalg.distributed Finished test(python): pyspark.mllib.feature (26s) Starting test(python): pyspark.mllib.random Finished test(python): pyspark.mllib.fpm (23s) Starting test(python): pyspark.mllib.recommendation Finished test(python): pyspark.mllib.clustering (50s) Starting test(python): pyspark.mllib.regression Finished test(python): pyspark.mllib.random (13s) Starting test(python): pyspark.mllib.stat.KernelDensity Finished test(python): pyspark.mllib.stat.KernelDensity (1s) Starting test(python): pyspark.mllib.stat._statistics Finished test(python): pyspark.mllib.linalg.distributed (42s) Starting test(python): pyspark.mllib.tree Finished test(python): pyspark.mllib.stat._statistics (19s) Starting test(python): pyspark.mllib.util Finished test(python): pyspark.mllib.regression (33s) Starting test(python): pyspark.profiler Finished test(python): pyspark.mllib.recommendation (36s) Starting test(python): pyspark.rdd Finished test(python): pyspark.profiler (9s) Starting test(python): pyspark.resource.tests.test_resources Finished test(python): pyspark.mllib.tree (19s) Starting test(python): pyspark.serializers Finished test(python): pyspark.mllib.util (21s) Starting test(python): pyspark.shuffle Finished test(python): pyspark.resource.tests.test_resources (9s) Starting test(python): pyspark.sql.avro.functions Finished test(python): pyspark.shuffle (1s) Starting test(python): pyspark.sql.catalog Finished test(python): pyspark.rdd (22s) Starting test(python): pyspark.sql.column Finished test(python): pyspark.serializers (12s) Starting test(python): pyspark.sql.conf Finished test(python): pyspark.sql.conf (6s) Starting test(python): pyspark.sql.context Finished test(python): pyspark.sql.catalog (14s) Starting test(python): pyspark.sql.dataframe Finished test(python): pyspark.sql.avro.functions (15s) Starting test(python): pyspark.sql.functions Finished test(python): pyspark.sql.column (24s) Starting test(python): pyspark.sql.group Finished test(python): 
pyspark.sql.context (20s) Starting test(python): pyspark.sql.pandas.conversion Finished test(python): pyspark.sql.pandas.conversion (13s) Starting test(python): pyspark.sql.pandas.group_ops Finished test(python): pyspark.sql.group (36s) Starting test(python): pyspark.sql.pandas.map_ops Finished test(python): pyspark.sql.pandas.group_ops (21s) Starting test(python): pyspark.sql.pandas.serializers Finished test(python): pyspark.sql.pandas.serializers (0s) Starting test(python): pyspark.sql.pandas.typehints Finished test(python): pyspark.sql.pandas.typehints (0s) Starting test(python): pyspark.sql.pandas.types Finished test(python): pyspark.sql.pandas.types (0s) Starting test(python): pyspark.sql.pandas.utils Finished test(python): pyspark.sql.pandas.utils (0s) Starting test(python): pyspark.sql.readwriter Finished test(python): pyspark.sql.dataframe (56s) Starting test(python): pyspark.sql.session Finished test(python): pyspark.sql.functions (57s) Starting test(python): pyspark.sql.streaming Finished test(python): pyspark.sql.pandas.map_ops (12s) Starting test(python): pyspark.sql.types Finished test(python): pyspark.sql.types (10s) Starting test(python): pyspark.sql.udf Finished test(python): pyspark.sql.streaming (16s) Starting test(python): pyspark.sql.window Finished test(python): pyspark.sql.session (19s) Starting test(python): pyspark.streaming.util Finished test(python): pyspark.streaming.util (0s) Starting test(python): pyspark.util Finished test(python): pyspark.util (0s) Finished test(python): pyspark.sql.readwriter (24s) Finished test(python): pyspark.sql.udf (13s) Finished test(python): pyspark.sql.window (14s) Tests passed in 780 seconds ``` Closes #30277 from HyukjinKwon/SPARK-33371. Authored-by: HyukjinKwon <gurwls223@apache.org> Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
332 lines
13 KiB
Python
332 lines
13 KiB
Python
#
|
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
|
# contributor license agreements. See the NOTICE file distributed with
|
|
# this work for additional information regarding copyright ownership.
|
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
|
# (the "License"); you may not use this file except in compliance with
|
|
# the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
import os
|
|
import random
|
|
import stat
|
|
import sys
|
|
import tempfile
|
|
import time
|
|
import unittest
|
|
|
|
from pyspark import SparkConf, SparkContext, TaskContext, BarrierTaskContext
|
|
from pyspark.testing.utils import PySparkTestCase, SPARK_HOME
|
|
|
|
|
|
class TaskContextTests(PySparkTestCase):
    """Tests for TaskContext and BarrierTaskContext on a local SparkContext."""

    def setUp(self):
        # NOTE(review): presumably restored (and self.sc stopped) by
        # PySparkTestCase.tearDown, which is defined outside this file.
        self._old_sys_path = list(sys.path)
        class_name = self.__class__.__name__
        # Allow retries even though they are normally disabled in local mode:
        # "local[4, 2]" is 4 worker threads with a task failure limit of 2,
        # which test_attempt_number relies on to retry a failed task.
        self.sc = SparkContext('local[4, 2]', class_name)

    def test_stage_id(self):
        """Test the stage ids are available and incrementing as expected."""
        rdd = self.sc.parallelize(range(10))
        stage1 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0]
        stage2 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0]
        # Test using the constructor directly rather than the get()
        stage3 = rdd.map(lambda x: TaskContext().stageId()).take(1)[0]
        self.assertEqual(stage1 + 1, stage2)
        self.assertEqual(stage1 + 2, stage3)
        self.assertEqual(stage2 + 1, stage3)

    def test_resources(self):
        """Test the resources are empty by default."""
        rdd = self.sc.parallelize(range(10))
        resources1 = rdd.map(lambda x: TaskContext.get().resources()).take(1)[0]
        # Test using the constructor directly rather than the get()
        resources2 = rdd.map(lambda x: TaskContext().resources()).take(1)[0]
        self.assertEqual(len(resources1), 0)
        self.assertEqual(len(resources2), 0)

    def test_partition_id(self):
        """Test the partition id."""
        rdd1 = self.sc.parallelize(range(10), 1)
        rdd2 = self.sc.parallelize(range(10), 2)
        pids1 = rdd1.map(lambda x: TaskContext.get().partitionId()).collect()
        pids2 = rdd2.map(lambda x: TaskContext.get().partitionId()).collect()
        self.assertEqual(0, pids1[0])
        self.assertEqual(0, pids1[9])
        self.assertEqual(0, pids2[0])
        self.assertEqual(1, pids2[9])

    def test_attempt_number(self):
        """Verify the attempt numbers are correctly reported."""
        rdd = self.sc.parallelize(range(10))
        # Verify a simple job with no failures. The checks in this test use
        # explicit loops: map()/filter() are lazy in Python 3, so assertions
        # wrapped in them would silently never execute.
        attempt_numbers = rdd.map(lambda x: TaskContext.get().attemptNumber()).collect()
        for attempt in attempt_numbers:
            self.assertEqual(0, attempt)

        def fail_on_first(x):
            """Fail on the first attempt so we get a positive attempt number"""
            tc = TaskContext.get()
            attempt_number = tc.attemptNumber()
            partition_id = tc.partitionId()
            attempt_id = tc.taskAttemptId()
            if attempt_number == 0 and partition_id == 0:
                raise Exception("Failing on first attempt")
            return [x, partition_id, attempt_number, attempt_id]

        result = rdd.map(fail_on_first).collect()
        # Only the first partition is re-submitted (attempt 1); every other
        # partition should succeed on attempt 0.
        self.assertEqual([0, 0, 1], result[0][0:3])
        self.assertEqual([9, 3, 0], result[9][0:3])
        for _, partition_id, attempt_number, _ in result:
            expected_attempt = 1 if partition_id == 0 else 0
            self.assertEqual(expected_attempt, attempt_number)
        # The task attempt id should be different
        self.assertNotEqual(result[0][3], result[9][3])

    def test_tc_on_driver(self):
        """Verify that getting the TaskContext on the driver returns None."""
        self.assertIsNone(TaskContext.get())

    def test_get_local_property(self):
        """Verify that local properties set on the driver are available in TaskContext."""
        key = "testkey"
        value = "testvalue"
        self.sc.setLocalProperty(key, value)
        try:
            rdd = self.sc.parallelize(range(1), 1)
            prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0]
            self.assertEqual(prop1, value)
            prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0]
            self.assertIsNone(prop2)
        finally:
            # Unset the property so it does not leak into other tests.
            self.sc.setLocalProperty(key, None)

    def test_barrier(self):
        """
        Verify that BarrierTaskContext.barrier() performs global sync among all barrier tasks
        within a stage.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            yield sum(iterator)

        def context_barrier(x):
            tc = BarrierTaskContext.get()
            # Sleep 2, 4, 6, 8 or 10 seconds, so tasks drawing different values
            # would finish roughly 2s or more apart if the barrier did not sync them.
            time.sleep(random.randint(1, 5) * 2)
            tc.barrier()
            return time.time()

        times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
        # All tasks leave the barrier together, so their timestamps must be close.
        self.assertLess(max(times) - min(times), 2)

    def test_all_gather(self):
        """
        Verify that BarrierTaskContext.allGather() performs global sync among all barrier tasks
        within a stage and passes messages properly.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            yield sum(iterator)

        def context_barrier(x):
            tc = BarrierTaskContext.get()
            # Stagger arrival so allGather has to wait for slower tasks.
            time.sleep(random.randint(1, 10))
            out = tc.allGather(str(tc.partitionId()))
            pids = [int(e) for e in out]
            return pids

        pids = rdd.barrier().mapPartitions(f).map(context_barrier).collect()[0]
        self.assertEqual(pids, [0, 1, 2, 3])

    def test_barrier_infos(self):
        """
        Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the
        barrier stage.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            yield sum(iterator)

        task_infos = rdd.barrier().mapPartitions(f).map(
            lambda x: BarrierTaskContext.get().getTaskInfos()).collect()
        # One result per partition, each listing all 4 tasks of the barrier stage.
        self.assertEqual(len(task_infos), 4)
        self.assertEqual(len(task_infos[0]), 4)

    def test_context_get(self):
        """
        Verify that TaskContext.get() works both in or not in a barrier stage.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            taskContext = TaskContext.get()
            # Encode which context type was returned into the partition's output.
            if isinstance(taskContext, BarrierTaskContext):
                yield taskContext.partitionId() + 1
            elif isinstance(taskContext, TaskContext):
                yield taskContext.partitionId() + 2
            else:
                yield -1

        # for normal stage
        result1 = rdd.mapPartitions(f).collect()
        self.assertEqual(result1, [2, 3, 4, 5])
        # for barrier stage
        result2 = rdd.barrier().mapPartitions(f).collect()
        self.assertEqual(result2, [1, 2, 3, 4])

    def test_barrier_context_get(self):
        """
        Verify that BarrierTaskContext.get() only works in a barrier stage.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            try:
                taskContext = BarrierTaskContext.get()
            except Exception:
                yield -1
            else:
                yield taskContext.partitionId()

        # for normal stage
        result1 = rdd.mapPartitions(f).collect()
        self.assertEqual(result1, [-1, -1, -1, -1])
        # for barrier stage
        result2 = rdd.barrier().mapPartitions(f).collect()
        self.assertEqual(result2, [0, 1, 2, 3])
|
|
|
|
|
|
class TaskContextTestsWithWorkerReuse(unittest.TestCase):
    """Tests for TaskContext/BarrierTaskContext with spark.python.worker.reuse enabled."""

    def setUp(self):
        class_name = self.__class__.__name__
        conf = SparkConf().set("spark.python.worker.reuse", "true")
        self.sc = SparkContext('local[2]', class_name, conf=conf)

    def _assert_context_result(self, result, expected_bps, worker_pids):
        """Check (TaskContext pid, BarrierTaskContext pid, os pid) tuples from a 2-partition job.

        TaskContext partition ids must be (0, 1), barrier partition ids must equal
        ``expected_bps``, and every task must have run in a previously started worker.
        """
        tps, bps, pids = zip(*result)
        self.assertEqual(tps, (0, 1))
        self.assertEqual(bps, expected_bps)
        for pid in pids:
            self.assertIn(pid, worker_pids)

    def test_barrier_with_python_worker_reuse(self):
        """
        Regression test for SPARK-25921: verify that BarrierTaskContext.barrier() with
        reused python worker.
        """
        # start a normal job first to start all workers and get all worker pids
        worker_pids = self.sc.parallelize(range(2), 2).map(lambda x: os.getpid()).collect()
        # the worker will reuse in this barrier job
        rdd = self.sc.parallelize(range(10), 2)

        def f(iterator):
            yield sum(iterator)

        def context_barrier(x):
            tc = BarrierTaskContext.get()
            # Sleep 2-10s in 2s steps so unsynchronized tasks would finish apart.
            time.sleep(random.randint(1, 5) * 2)
            tc.barrier()
            return (time.time(), os.getpid())

        result = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
        times = [finished_at for finished_at, _ in result]
        pids = [pid for _, pid in result]
        # check both barrier and worker reuse effect
        self.assertLess(max(times) - min(times), 2)
        for pid in pids:
            self.assertIn(pid, worker_pids)

    def test_task_context_correct_with_python_worker_reuse(self):
        """Verify the task context correct when reused python worker"""
        # start a normal job first to start all workers and get all worker pids
        worker_pids = self.sc.parallelize(range(2), 2).map(lambda x: os.getpid()).collect()
        # the workers started above should be reused by every job below
        rdd = self.sc.parallelize(range(10), 2)

        def context(iterator):
            tp = TaskContext.get().partitionId()
            try:
                bp = BarrierTaskContext.get().partitionId()
            except Exception:
                # Not in a barrier stage.
                bp = -1

            yield (tp, bp, os.getpid())

        # normal stage after normal stage
        normal_result = rdd.mapPartitions(context).collect()
        self._assert_context_result(normal_result, (-1, -1), worker_pids)
        # barrier stage after normal stage
        barrier_result = rdd.barrier().mapPartitions(context).collect()
        self._assert_context_result(barrier_result, (0, 1), worker_pids)
        # normal stage after barrier stage
        normal_result2 = rdd.mapPartitions(context).collect()
        self._assert_context_result(normal_result2, (-1, -1), worker_pids)

    def tearDown(self):
        self.sc.stop()
|
|
|
|
|
|
class TaskContextTestsWithResources(unittest.TestCase):
    """Tests that TaskContext.resources() reports resources found by a discovery script."""

    def setUp(self):
        class_name = self.__class__.__name__
        # Discovery script: a shell one-liner echoing one "gpu" resource with address "0".
        self.tempFile = tempfile.NamedTemporaryFile(delete=False)
        self.tempFile.write(b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\"]}')
        self.tempFile.close()
        # Registered as a cleanup (which runs after tearDown) so the script is also
        # removed when SparkContext creation below fails and tearDown is skipped.
        self.addCleanup(os.unlink, self.tempFile.name)
        # NOTE(review): self.tempdir is not read anywhere else in this class; it only
        # reserves a unique path (the file is deleted immediately). The original intent
        # was Worker resources coordination -- possibly a leftover, confirm before removing.
        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
        # Close the handle before unlinking so the open file object is not leaked.
        self.tempdir.close()
        os.unlink(self.tempdir.name)
        # Make the script executable: rwx for the owner, r-x for group and others.
        os.chmod(self.tempFile.name, stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP |
                 stat.S_IROTH | stat.S_IXOTH)
        conf = SparkConf().set("spark.test.home", SPARK_HOME)
        conf = conf.set("spark.worker.resource.gpu.discoveryScript", self.tempFile.name)
        conf = conf.set("spark.worker.resource.gpu.amount", 1)
        conf = conf.set("spark.task.resource.gpu.amount", "1")
        conf = conf.set("spark.executor.resource.gpu.amount", "1")
        self.sc = SparkContext('local-cluster[2,1,1024]', class_name, conf=conf)

    def test_resources(self):
        """Test the resources are available."""
        rdd = self.sc.parallelize(range(10))
        resources = rdd.map(lambda x: TaskContext.get().resources()).take(1)[0]
        self.assertEqual(len(resources), 1)
        self.assertIn('gpu', resources)
        self.assertEqual(resources['gpu'].name, 'gpu')
        self.assertEqual(resources['gpu'].addresses, ['0'])

    def tearDown(self):
        # The discovery script itself is removed by the cleanup registered in setUp.
        self.sc.stop()
|
|
|
|
if __name__ == "__main__":
    import unittest
    from pyspark.tests.test_taskcontext import *  # noqa: F401

    # Write XML test reports under target/test-reports when xmlrunner can be
    # imported; otherwise pass None so unittest uses its default text runner.
    testRunner = None
    try:
        import xmlrunner  # type: ignore[import]
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)
|