2016-12-20 18:51:21 -05:00
|
|
|
#
|
|
|
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
|
|
|
# contributor license agreements. See the NOTICE file distributed with
|
|
|
|
# this work for additional information regarding copyright ownership.
|
|
|
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
|
|
|
# (the "License"); you may not use this file except in compliance with
|
|
|
|
# the License. You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
#
|
2018-08-28 21:47:38 -04:00
|
|
|
from pyspark.java_gateway import local_connect_and_auth
|
2020-04-17 00:23:32 -04:00
|
|
|
from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer
|
2016-12-20 18:51:21 -05:00
|
|
|
|
|
|
|
|
|
|
|
class TaskContext(object):

    """
    Contextual information about a task which can be read or mutated during
    execution. To access the TaskContext for a running task, use:
    :meth:`TaskContext.get`.
    """

    # Process-wide singleton instance; populated lazily by _getOrCreate()
    # (or replaced by the worker via _setTaskContext()).
    _taskContext = None

    # Task metadata. These class attributes are filled in by the Python worker
    # before user code runs; they stay None on the driver.
    _attemptNumber = None
    _partitionId = None
    _stageId = None
    _taskAttemptId = None
    _localProperties = None
    _resources = None

    def __new__(cls):
        """Even if users construct TaskContext instead of using get, give them the singleton."""
        taskContext = cls._taskContext
        if taskContext is not None:
            return taskContext
        cls._taskContext = taskContext = object.__new__(cls)
        return taskContext

    @classmethod
    def _getOrCreate(cls):
        """Internal function to get or create global TaskContext."""
        if cls._taskContext is None:
            cls._taskContext = TaskContext()
        return cls._taskContext

    @classmethod
    def _setTaskContext(cls, taskContext):
        """Internal function to replace the global TaskContext (used by the worker)."""
        cls._taskContext = taskContext

    @classmethod
    def get(cls):
        """
        Return the currently active TaskContext. This can be called inside of
        user functions to access contextual information about running tasks.

        .. note:: Must be called on the worker, not the driver. Returns None if not initialized.
        """
        return cls._taskContext

    def stageId(self):
        """The ID of the stage that this task belong to."""
        return self._stageId

    def partitionId(self):
        """
        The ID of the RDD partition that is computed by this task.
        """
        return self._partitionId

    def attemptNumber(self):
        """
        How many times this task has been attempted. The first task attempt will be assigned
        attemptNumber = 0, and subsequent attempts will have increasing attempt numbers.
        """
        # NOTE: the original docstring opened with four quotes (""""), leaking a
        # stray '"' into the rendered documentation; fixed here.
        return self._attemptNumber

    def taskAttemptId(self):
        """
        An ID that is unique to this task attempt (within the same SparkContext, no two task
        attempts will share the same attempt ID). This is roughly equivalent to Hadoop's
        TaskAttemptID.
        """
        return self._taskAttemptId

    def getLocalProperty(self, key):
        """
        Get a local property set upstream in the driver, or None if it is missing.
        """
        return self._localProperties.get(key, None)

    def resources(self):
        """
        Resources allocated to the task. The key is the resource name and the value is information
        about the resource.
        """
        return self._resources
|
|
|
|
|
2018-08-21 18:54:30 -04:00
|
|
|
|
|
|
|
# Function codes sent to the barrier coordinator server by _load_from_socket:
# BARRIER_FUNCTION requests a plain barrier() rendezvous; ALL_GATHER_FUNCTION
# additionally sends a message and receives every task's message back.
BARRIER_FUNCTION = 1
ALL_GATHER_FUNCTION = 2
|
2018-08-21 18:54:30 -04:00
|
|
|
|
|
|
|
|
[SPARK-30667][CORE] Add all gather method to BarrierTaskContext
Fix for #27395
### What changes were proposed in this pull request?
The `allGather` method is added to the `BarrierTaskContext`. This method contains the same functionality as the `BarrierTaskContext.barrier` method; it blocks the task until all tasks make the call, at which time they may continue execution. In addition, the `allGather` method takes an input message. Upon returning from the `allGather` the task receives a list of all the messages sent by all the tasks that made the `allGather` call.
### Why are the changes needed?
There are many situations where having the tasks communicate in a synchronized way is useful. One simple example is if each task needs to start a server to serve requests from one another; first the tasks must find a free port (the result of which is undetermined beforehand) and then start making requests, but to do so they each must know the port chosen by the other task. An `allGather` method would allow them to inform each other of the port they will run on.
### Does this PR introduce any user-facing change?
Yes, an `BarrierTaskContext.allGather` method will be available through the Scala, Java, and Python APIs.
### How was this patch tested?
Most of the code path is already covered by tests to the `barrier` method, since this PR includes a refactor so that much code is shared by the `barrier` and `allGather` methods. However, a test is added to assert that an all gather on each tasks partition ID will return a list of every partition ID.
An example through the Python API:
```python
>>> from pyspark import BarrierTaskContext
>>>
>>> def f(iterator):
... context = BarrierTaskContext.get()
... return [context.allGather('{}'.format(context.partitionId()))]
...
>>> sc.parallelize(range(4), 4).barrier().mapPartitions(f).collect()[0]
[u'3', u'1', u'0', u'2']
```
Closes #27640 from sarthfrey/master.
Lead-authored-by: sarthfrey-db <sarth.frey@databricks.com>
Co-authored-by: sarthfrey <sarth.frey@gmail.com>
Signed-off-by: Xingbo Jiang <xingbo.jiang@databricks.com>
2020-02-21 14:40:28 -05:00
|
|
|
def _load_from_socket(port, auth_secret, function, all_gather_message=None):
    """
    Load data from a given socket, this is a blocking method thus only return when the socket
    connection has been closed.

    :param port: port of the barrier coordinator server to connect to.
    :param auth_secret: secret used to authenticate the connection.
    :param function: which server-side operation to invoke; one of
        BARRIER_FUNCTION or ALL_GATHER_FUNCTION.
    :param all_gather_message: message to exchange with the other tasks; only
        used (and required) when ``function`` is ALL_GATHER_FUNCTION.
    :return: list of UTF-8 strings sent back by the server, one per
        length-prefixed message.
    """
    (sockfile, sock) = local_connect_and_auth(port, auth_secret)
    # The call may block forever, so no timeout
    sock.settimeout(None)

    if function == BARRIER_FUNCTION:
        # Make a barrier() function call.
        write_int(function, sockfile)
    elif function == ALL_GATHER_FUNCTION:
        # Make a all_gather() function call.
        write_int(function, sockfile)
        write_with_length(all_gather_message.encode("utf-8"), sockfile)
    else:
        raise ValueError("Unrecognized function type")
    sockfile.flush()

    # Collect result: the server first writes the number of messages, then each
    # message as a length-prefixed UTF-8 string.
    # (Renamed from `len`, which shadowed the builtin.)
    num_messages = read_int(sockfile)
    res = [UTF8Deserializer().loads(sockfile) for _ in range(num_messages)]

    # Release resources.
    sockfile.close()
    sock.close()

    return res
|
|
|
|
|
|
|
|
|
|
|
|
class BarrierTaskContext(TaskContext):

    """
    .. note:: Experimental

    A :class:`TaskContext` with extra contextual info and tooling for tasks in a barrier stage.
    Use :func:`BarrierTaskContext.get` to obtain the barrier context for a running barrier task.

    .. versionadded:: 2.4.0
    """

    # Coordinates of the barrier coordinator server; set by _initialize()
    # before any of barrier()/allGather()/getTaskInfos() may be called.
    _port = None
    _secret = None

    @classmethod
    def _getOrCreate(cls):
        """
        Internal function to get or create global BarrierTaskContext. We need to make sure
        BarrierTaskContext is returned from here because it is needed in python worker reuse
        scenario, see SPARK-25921 for more details.
        """
        if not isinstance(cls._taskContext, BarrierTaskContext):
            cls._taskContext = object.__new__(cls)
        return cls._taskContext

    @classmethod
    def _setTaskContext(cls, taskContext):
        cls._taskContext = taskContext

    @classmethod
    def get(cls):
        """
        .. note:: Experimental

        Return the currently active :class:`BarrierTaskContext`.
        This can be called inside of user functions to access contextual information about
        running tasks.

        .. note:: Must be called on the worker, not the driver. Returns None if not initialized.
        An Exception will raise if it is not in a barrier stage.
        """
        if not isinstance(cls._taskContext, BarrierTaskContext):
            raise Exception('It is not in a barrier stage')
        return cls._taskContext

    @classmethod
    def _initialize(cls, port, secret):
        """
        Initialize BarrierTaskContext, other methods within BarrierTaskContext can only be called
        after BarrierTaskContext is initialized.
        """
        cls._port = port
        cls._secret = secret

    def barrier(self):
        """
        .. note:: Experimental

        Sets a global barrier and waits until all tasks in this stage hit this barrier.
        Similar to `MPI_Barrier` function in MPI, this function blocks until all tasks
        in the same stage have reached this routine.

        .. warning:: In a barrier stage, each task must have the same number of `barrier()`
            calls, in all possible code branches.
            Otherwise, you may get the job hanging or a SparkException after timeout.

        .. versionadded:: 2.4.0
        """
        if self._port is None or self._secret is None:
            raise Exception("Not supported to call barrier() before initialize " +
                            "BarrierTaskContext.")
        else:
            _load_from_socket(self._port, self._secret, BARRIER_FUNCTION)

    def allGather(self, message=""):
        """
        .. note:: Experimental

        This function blocks until all tasks in the same stage have reached this routine.
        Each task passes in a message and returns with a list of all the messages passed in
        by each of those tasks.

        .. warning:: In a barrier stage, each task must have the same number of `allGather()`
            calls, in all possible code branches.
            Otherwise, you may get the job hanging or a SparkException after timeout.

        .. versionadded:: 3.0.0
        """
        if not isinstance(message, str):
            raise ValueError("Argument `message` must be of type `str`")
        elif self._port is None or self._secret is None:
            # BUGFIX: the original error message said "call barrier()" here,
            # which is misleading for users of allGather().
            raise Exception("Not supported to call allGather() before initialize " +
                            "BarrierTaskContext.")
        else:
            return _load_from_socket(self._port, self._secret, ALL_GATHER_FUNCTION, message)

    def getTaskInfos(self):
        """
        .. note:: Experimental

        Returns :class:`BarrierTaskInfo` for all tasks in this barrier stage,
        ordered by partition ID.

        .. versionadded:: 2.4.0
        """
        if self._port is None or self._secret is None:
            raise Exception("Not supported to call getTaskInfos() before initialize " +
                            "BarrierTaskContext.")
        else:
            # "addresses" is a comma-separated host:port list set by the driver
            # in the task's local properties.
            addresses = self._localProperties.get("addresses", "")
            return [BarrierTaskInfo(h.strip()) for h in addresses.split(",")]
|
|
|
|
|
|
|
|
|
|
|
|
class BarrierTaskInfo(object):

    """
    .. note:: Experimental

    Carries all task infos of a barrier task.

    :var address: The IPv4 address (host:port) of the executor that the barrier task is running on

    .. versionadded:: 2.4.0
    """

    def __init__(self, address):
        # Simple value holder: exposes the executor endpoint as a plain attribute.
        self.address = address
|