spark-instrumented-optimizer/python/pyspark/taskcontext.py


#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.java_gateway import local_connect_and_auth
from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer


class TaskContext(object):

    """
    Contextual information about a task which can be read or mutated during
    execution. To access the TaskContext for a running task, use:
    :meth:`TaskContext.get`.
    """

    _taskContext = None

    _attemptNumber = None
    _partitionId = None
    _stageId = None
    _taskAttemptId = None
    _localProperties = None
    _resources = None

    def __new__(cls):
        """Even if users construct TaskContext instead of using get, give them the singleton."""
        taskContext = cls._taskContext
        if taskContext is not None:
            return taskContext
        cls._taskContext = taskContext = object.__new__(cls)
        return taskContext

    @classmethod
    def _getOrCreate(cls):
        """Internal function to get or create global TaskContext."""
        if cls._taskContext is None:
            cls._taskContext = TaskContext()
        return cls._taskContext
[SPARK-29582][PYSPARK] Support `TaskContext.get()` in a barrier task from Python side ### What changes were proposed in this pull request? Add support of `TaskContext.get()` in a barrier task from Python side, this makes it easier to migrate legacy user code to barrier execution mode. ### Why are the changes needed? In Spark Core, there is a `TaskContext` object which is a singleton. We set a task context instance which can be TaskContext or BarrierTaskContext before the task function startup, and unset it to none after the function end. So we can both get TaskContext and BarrierTaskContext with the object. However we can only get the BarrierTaskContext with `BarrierTaskContext`, we will get `None` if we get it by `TaskContext.get` in a barrier stage. This is useful when people switch from normal code to barrier code, and only need a little update. ### Does this PR introduce any user-facing change? Yes. Previously: ```python def func(iterator): task_context = TaskContext.get() . # this could be None. barrier_task_context = BarrierTaskContext.get() # get the BarrierTaskContext instance ... rdd.barrier().mapPartitions(func) ``` Proposed: ```python def func(iterator): task_context = TaskContext.get() . # this could also get the BarrierTaskContext instance which is same as barrier_task_context barrier_task_context = BarrierTaskContext.get() # get the BarrierTaskContext instance ... rdd.barrier().mapPartitions(func) ``` ### How was this patch tested? New UT tests. Closes #26239 from ConeyLiu/barrier_task_context. Authored-by: Xianyang Liu <xianyang.liu@intel.com> Signed-off-by: HyukjinKwon <gurwls223@apache.org>
2019-10-31 00:10:44 -04:00
    @classmethod
    def _setTaskContext(cls, taskContext):
        cls._taskContext = taskContext

    @classmethod
    def get(cls):
        """
        Return the currently active TaskContext. This can be called inside of
        user functions to access contextual information about running tasks.

        Notes
        -----
        Must be called on the worker, not the driver. Returns None if not initialized.
        """
        return cls._taskContext

    def stageId(self):
        """The ID of the stage that this task belongs to."""
        return self._stageId

    def partitionId(self):
        """
        The ID of the RDD partition that is computed by this task.
        """
        return self._partitionId

    def attemptNumber(self):
        """
        How many times this task has been attempted. The first task attempt will be assigned
        attemptNumber = 0, and subsequent attempts will have increasing attempt numbers.
        """
        return self._attemptNumber

    def taskAttemptId(self):
        """
        An ID that is unique to this task attempt (within the same SparkContext, no two task
        attempts will share the same attempt ID). This is roughly equivalent to Hadoop's
        TaskAttemptID.
        """
        return self._taskAttemptId

    def getLocalProperty(self, key):
        """
        Get a local property set upstream in the driver, or None if it is missing.
        """
        return self._localProperties.get(key, None)

    def resources(self):
        """
        Resources allocated to the task. The key is the resource name and the value is information
        about the resource.
        """
        return self._resources
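

# The singleton behaviour of TaskContext.__new__ and TaskContext.get can be
# illustrated in isolation. The block below is a minimal sketch only:
# DemoContext and its _instance attribute are hypothetical stand-ins and are
# not part of pyspark.

```python
# Standalone sketch; DemoContext is a hypothetical stand-in, not a pyspark class.
class DemoContext(object):
    _instance = None

    def __new__(cls):
        # Reuse the cached instance if one exists, otherwise create and cache it.
        instance = cls._instance
        if instance is not None:
            return instance
        cls._instance = instance = object.__new__(cls)
        return instance

    @classmethod
    def get(cls):
        # Returns None until something has constructed or set the instance.
        return cls._instance


before = DemoContext.get()   # nothing has been created yet
first = DemoContext()        # creates and caches the instance
second = DemoContext()       # returns the cached instance
```

# In this sketch, constructing the class twice yields the same object, and
# get() only reads the cached reference without creating anything.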
BARRIER_FUNCTION = 1
[SPARK-30667][CORE] Add all gather method to BarrierTaskContext Fix for #27395 ### What changes were proposed in this pull request? The `allGather` method is added to the `BarrierTaskContext`. This method contains the same functionality as the `BarrierTaskContext.barrier` method; it blocks the task until all tasks make the call, at which time they may continue execution. In addition, the `allGather` method takes an input message. Upon returning from the `allGather` the task receives a list of all the messages sent by all the tasks that made the `allGather` call. ### Why are the changes needed? There are many situations where having the tasks communicate in a synchronized way is useful. One simple example is if each task needs to start a server to serve requests from one another; first the tasks must find a free port (the result of which is undetermined beforehand) and then start making requests, but to do so they each must know the port chosen by the other task. An `allGather` method would allow them to inform each other of the port they will run on. ### Does this PR introduce any user-facing change? Yes, an `BarrierTaskContext.allGather` method will be available through the Scala, Java, and Python APIs. ### How was this patch tested? Most of the code path is already covered by tests to the `barrier` method, since this PR includes a refactor so that much code is shared by the `barrier` and `allGather` methods. However, a test is added to assert that an all gather on each tasks partition ID will return a list of every partition ID. An example through the Python API: ```python >>> from pyspark import BarrierTaskContext >>> >>> def f(iterator): ... context = BarrierTaskContext.get() ... return [context.allGather('{}'.format(context.partitionId()))] ... >>> sc.parallelize(range(4), 4).barrier().mapPartitions(f).collect()[0] [u'3', u'1', u'0', u'2'] ``` Closes #27640 from sarthfrey/master. 
Lead-authored-by: sarthfrey-db <sarth.frey@databricks.com> Co-authored-by: sarthfrey <sarth.frey@gmail.com> Signed-off-by: Xingbo Jiang <xingbo.jiang@databricks.com>
2020-02-21 14:40:28 -05:00
ALL_GATHER_FUNCTION = 2


def _load_from_socket(port, auth_secret, function, all_gather_message=None):
    """
    Load data from a given socket. This is a blocking method, so it only returns when the
    socket connection has been closed.
    """
    (sockfile, sock) = local_connect_and_auth(port, auth_secret)
    # The call may block forever, so no timeout
    sock.settimeout(None)
    if function == BARRIER_FUNCTION:
        # Make a barrier() function call.
        write_int(function, sockfile)
    elif function == ALL_GATHER_FUNCTION:
        # Make an all_gather() function call.
        write_int(function, sockfile)
        write_with_length(all_gather_message.encode("utf-8"), sockfile)
    else:
        raise ValueError("Unrecognized function type")
    sockfile.flush()

    # Collect result. The count is named `length` to avoid shadowing the built-in len().
    length = read_int(sockfile)
    res = []
    for i in range(length):
        res.append(UTF8Deserializer().loads(sockfile))

    # Release resources.
    sockfile.close()
    sock.close()

    return res
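

# The request/response exchange performed by _load_from_socket can be sketched
# without a socket. This is an illustration under stated assumptions, not the
# pyspark implementation: it assumes write_int/read_int use 4-byte big-endian
# integers and that write_with_length/UTF8Deserializer use a 4-byte length
# prefix. encode_request and decode_response are hypothetical helper names.

```python
import io
import struct

BARRIER_FUNCTION = 1
ALL_GATHER_FUNCTION = 2


def encode_request(function, message=None):
    """Build the bytes a task would send for a barrier or all-gather call."""
    out = io.BytesIO()
    # Function code as a 4-byte big-endian int (assumed framing).
    out.write(struct.pack("!i", function))
    if function == ALL_GATHER_FUNCTION:
        payload = message.encode("utf-8")
        out.write(struct.pack("!i", len(payload)))  # length prefix
        out.write(payload)
    elif function != BARRIER_FUNCTION:
        raise ValueError("Unrecognized function type")
    return out.getvalue()


def decode_response(stream):
    """Read a count followed by that many length-prefixed UTF-8 strings."""
    (count,) = struct.unpack("!i", stream.read(4))
    messages = []
    for _ in range(count):
        (size,) = struct.unpack("!i", stream.read(4))
        messages.append(stream.read(size).decode("utf-8"))
    return messages


# Simulate a reply carrying two gathered messages.
reply = io.BytesIO()
reply.write(struct.pack("!i", 2))
for text in ("0", "1"):
    data = text.encode("utf-8")
    reply.write(struct.pack("!i", len(data)) + data)
reply.seek(0)
gathered = decode_response(reply)
```

# Under these assumptions a barrier() request is only the function code, while
# an allGather() request appends the length-prefixed message.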


class BarrierTaskContext(TaskContext):

    """
    A :class:`TaskContext` with extra contextual info and tooling for tasks in a barrier stage.
    Use :func:`BarrierTaskContext.get` to obtain the barrier context for a running barrier task.

    .. versionadded:: 2.4.0

    Notes
    -----
    This API is experimental
    """

    _port = None
    _secret = None

    @classmethod
    def _getOrCreate(cls):
        """
        Internal function to get or create global BarrierTaskContext. We need to make sure
        BarrierTaskContext is returned from here because it is needed in python worker reuse
        scenario, see SPARK-25921 for more details.
        """
        if not isinstance(cls._taskContext, BarrierTaskContext):
            cls._taskContext = object.__new__(cls)
        return cls._taskContext

    @classmethod
    def get(cls):
        """
        Return the currently active :class:`BarrierTaskContext`.
        This can be called inside of user functions to access contextual information about
        running tasks.

        Notes
        -----
        Must be called on the worker, not the driver. Raises a RuntimeError if the current
        task context is not a :class:`BarrierTaskContext`, which includes the case where no
        context has been initialized.

        This API is experimental
        """
        if not isinstance(cls._taskContext, BarrierTaskContext):
            raise RuntimeError('It is not in a barrier stage')
        return cls._taskContext

    @classmethod
    def _initialize(cls, port, secret):
        """
        Initialize BarrierTaskContext, other methods within BarrierTaskContext can only be called
        after BarrierTaskContext is initialized.
        """
        cls._port = port
        cls._secret = secret

    def barrier(self):
        """
        Sets a global barrier and waits until all tasks in this stage hit this barrier.
        Similar to `MPI_Barrier` function in MPI, this function blocks until all tasks
        in the same stage have reached this routine.

        .. versionadded:: 2.4.0

        .. warning:: In a barrier stage, each task must have the same number of `barrier()`
            calls, in all possible code branches.
            Otherwise, you may get the job hanging or a SparkException after timeout.

        Notes
        -----
        This API is experimental
        """
        if self._port is None or self._secret is None:
            raise RuntimeError("Not supported to call barrier() before initialize " +
                               "BarrierTaskContext.")
        else:
            _load_from_socket(self._port, self._secret, BARRIER_FUNCTION)

    def allGather(self, message=""):
        """
        This function blocks until all tasks in the same stage have reached this routine.
        Each task passes in a message and returns with a list of all the messages passed in
        by each of those tasks.

        .. versionadded:: 3.0.0

        .. warning:: In a barrier stage, each task must have the same number of `allGather()`
            calls, in all possible code branches.
            Otherwise, you may get the job hanging or a SparkException after timeout.

        Notes
        -----
        This API is experimental
        """
        if not isinstance(message, str):
            raise TypeError("Argument `message` must be of type `str`")
        elif self._port is None or self._secret is None:
            raise RuntimeError("Not supported to call allGather() before initialize " +
                               "BarrierTaskContext.")
        else:
            return _load_from_socket(self._port, self._secret, ALL_GATHER_FUNCTION, message)

    def getTaskInfos(self):
        """
        Returns :class:`BarrierTaskInfo` for all tasks in this barrier stage,
        ordered by partition ID.

        .. versionadded:: 2.4.0

        Notes
        -----
        This API is experimental
        """
        if self._port is None or self._secret is None:
            raise RuntimeError("Not supported to call getTaskInfos() before initialize " +
                               "BarrierTaskContext.")
        else:
            addresses = self._localProperties.get("addresses", "")
            return [BarrierTaskInfo(h.strip()) for h in addresses.split(",")]


class BarrierTaskInfo(object):
    """
    Carries all task infos of a barrier task.

    .. versionadded:: 2.4.0

    Attributes
    ----------
    address : str
        The IPv4 address (host:port) of the executor that the barrier task is running on

    Notes
    -----
    This API is experimental
    """

    def __init__(self, address):
        self.address = address
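

# How getTaskInfos turns the comma-separated "addresses" local property into
# BarrierTaskInfo objects can be tried outside Spark. The block below is a
# minimal sketch: DemoTaskInfo and parse_addresses are hypothetical names, and
# the sample addresses are made up for illustration.

```python
class DemoTaskInfo(object):
    """Hypothetical stand-in for BarrierTaskInfo; holds a single address."""

    def __init__(self, address):
        self.address = address


def parse_addresses(addresses):
    # Same split-and-strip expression that getTaskInfos applies to the property.
    return [DemoTaskInfo(h.strip()) for h in addresses.split(",")]


infos = parse_addresses("10.0.0.1:4040, 10.0.0.2:4041")
empty = parse_addresses("")
```

# Note that splitting an empty string still yields one element, so in this
# sketch a missing property produces a single entry whose address is "".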