#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import print_function
import json

from pyspark.java_gateway import local_connect_and_auth
from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer


class TaskContext(object):

    """
    Contextual information about a task which can be read or mutated during
    execution. To access the TaskContext for a running task, use:
    :meth:`TaskContext.get`.
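
    A minimal usage sketch (illustrative only; assumes an existing
    :class:`SparkContext` ``sc``; ``show_context`` is a hypothetical user function)::

        def show_context(iterator):
            ctx = TaskContext.get()
            yield (ctx.partitionId(), ctx.stageId(), ctx.attemptNumber())

        sc.parallelize(range(4), 2).mapPartitions(show_context).collect()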
    """

    _taskContext = None

    _attemptNumber = None
    _partitionId = None
    _stageId = None
    _taskAttemptId = None
    _localProperties = None
    _resources = None

    def __new__(cls):
        """Even if users construct TaskContext instead of using get, give them the singleton."""
        taskContext = cls._taskContext
        if taskContext is not None:
            return taskContext
        cls._taskContext = taskContext = object.__new__(cls)
        return taskContext

    @classmethod
    def _getOrCreate(cls):
        """Internal function to get or create global TaskContext."""
        if cls._taskContext is None:
            cls._taskContext = TaskContext()
        return cls._taskContext

    @classmethod
    def _setTaskContext(cls, taskContext):
        cls._taskContext = taskContext

    @classmethod
    def get(cls):
        """
        Return the currently active TaskContext. This can be called inside of
        user functions to access contextual information about running tasks.

        .. note:: Must be called on the worker, not the driver. Returns None if not initialized.
        """
        return cls._taskContext

    def stageId(self):
        """The ID of the stage that this task belongs to."""
        return self._stageId

    def partitionId(self):
        """
        The ID of the RDD partition that is computed by this task.
        """
        return self._partitionId

    def attemptNumber(self):
        """
        How many times this task has been attempted. The first task attempt will be assigned
        attemptNumber = 0, and subsequent attempts will have increasing attempt numbers.
        """
        return self._attemptNumber

    def taskAttemptId(self):
        """
        An ID that is unique to this task attempt (within the same SparkContext, no two task
        attempts will share the same attempt ID). This is roughly equivalent to Hadoop's
        TaskAttemptID.
        """
        return self._taskAttemptId

    def getLocalProperty(self, key):
        """
        Get a local property set upstream in the driver, or None if it is missing.
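
        A sketch (assuming the driver called ``sc.setLocalProperty("myKey", "myValue")``
        before submitting the job; the key name is illustrative)::

            value = TaskContext.get().getLocalProperty("myKey")  # "myValue", or None if unset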
        """
        return self._localProperties.get(key, None)

    def resources(self):
        """
        Resources allocated to the task. The key is the resource name and the value is information
        about the resource.
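
        A sketch (assuming the task was scheduled with GPU resources; the returned value
        for a resource is expected to expose the assigned addresses)::

            gpus = TaskContext.get().resources().get("gpu")
            if gpus is not None:
                addresses = gpus.addresses  # e.g. ['0', '1']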
        """
        return self._resources


BARRIER_FUNCTION = 1
ALL_GATHER_FUNCTION = 2


def _load_from_socket(port, auth_secret, function, all_gather_message=None):
    """
    Load data from a given socket. This is a blocking method; it returns only when the
    socket connection has been closed.
    """
    (sockfile, sock) = local_connect_and_auth(port, auth_secret)

    # The request may block forever, so no timeout.
    sock.settimeout(None)

    if function == BARRIER_FUNCTION:
        # Make a BARRIER_FUNCTION request.
        write_int(function, sockfile)
    elif function == ALL_GATHER_FUNCTION:
        # Make an ALL_GATHER_FUNCTION request and send this task's message.
        write_int(function, sockfile)
        write_with_length(all_gather_message.encode("utf-8"), sockfile)
    else:
        raise ValueError("Unrecognized function type")
    sockfile.flush()

    # Collect the result: an int count followed by that many UTF-8 strings.
    length = read_int(sockfile)
    res = []
    for i in range(length):
        res.append(UTF8Deserializer().loads(sockfile))

    # Release resources.
    sockfile.close()
    sock.close()

    return res


class BarrierTaskContext(TaskContext):

    """
    .. note:: Experimental

    A :class:`TaskContext` with extra contextual info and tooling for tasks in a barrier stage.
    Use :func:`BarrierTaskContext.get` to obtain the barrier context for a running barrier task.
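
    A minimal sketch of a barrier stage (illustrative; assumes an existing
    :class:`SparkContext` ``sc`` and enough free slots to run all partitions at once)::

        def f(iterator):
            context = BarrierTaskContext.get()
            # ... load or preprocess a partition ...
            context.barrier()  # wait for every task in the stage
            yield sum(iterator)

        sc.parallelize(range(8), 4).barrier().mapPartitions(f).collect()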

    .. versionadded:: 2.4.0
    """

    _port = None
    _secret = None

    @classmethod
    def _getOrCreate(cls):
        """
        Internal function to get or create global BarrierTaskContext. We need to make sure
        BarrierTaskContext is returned from here because it is needed in the Python worker
        reuse scenario; see SPARK-25921 for more details.
        """
        if not isinstance(cls._taskContext, BarrierTaskContext):
            cls._taskContext = object.__new__(cls)
        return cls._taskContext

    @classmethod
    def get(cls):
        """
        .. note:: Experimental

        Return the currently active :class:`BarrierTaskContext`.
        This can be called inside of user functions to access contextual information about
        running tasks.

        .. note:: Must be called on the worker, not the driver. Returns None if not initialized.
            An Exception will be raised if this is not called inside a barrier stage.
        """
        if not isinstance(cls._taskContext, BarrierTaskContext):
            raise Exception('It is not in a barrier stage')
        return cls._taskContext

    @classmethod
    def _initialize(cls, port, secret):
        """
        Initialize BarrierTaskContext; other methods within BarrierTaskContext can only be
        called after it is initialized.
        """
        cls._port = port
        cls._secret = secret

    def barrier(self):
        """
        .. note:: Experimental

        Sets a global barrier and waits until all tasks in this stage hit this barrier.
        Similar to the `MPI_Barrier` function in MPI, this function blocks until all tasks
        in the same stage have reached this routine.

        .. warning:: In a barrier stage, each task must have the same number of `barrier()`
            calls, in all possible code branches.
            Otherwise, you may get the job hanging or a SparkException after timeout.
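
        A sketch of a safe pattern where every code path reaches the same single call
        (illustrative; must run inside a barrier task)::

            context = BarrierTaskContext.get()
            if context.partitionId() == 0:
                pass  # e.g. do some extra setup on one task only
            # every task, whichever branch it took, reaches exactly one barrier() call
            context.barrier()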

        .. versionadded:: 2.4.0
        """
        if self._port is None or self._secret is None:
            raise Exception("Not supported to call barrier() before initializing " +
                            "BarrierTaskContext.")
        else:
            _load_from_socket(self._port, self._secret, BARRIER_FUNCTION)

    def allGather(self, message=""):
        """
        .. note:: Experimental

        This function blocks until all tasks in the same stage have reached this routine.
        Each task passes in a message and returns with a list of all the messages passed in
        by each of those tasks.

        .. warning:: In a barrier stage, each task must have the same number of `allGather()`
            calls, in all possible code branches.
            Otherwise, you may get the job hanging or a SparkException after timeout.
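
        A sketch of exchanging each task's partition ID (illustrative; must run inside a
        barrier stage)::

            context = BarrierTaskContext.get()
            messages = context.allGather(str(context.partitionId()))
            # every task gets back the same list, e.g. ['0', '1', '2', '3']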

        .. versionadded:: 3.0.0
        """
        if not isinstance(message, str):
            raise ValueError("Argument `message` must be of type `str`")
        elif self._port is None or self._secret is None:
            raise Exception("Not supported to call allGather() before initializing " +
                            "BarrierTaskContext.")
        else:
            return _load_from_socket(self._port, self._secret, ALL_GATHER_FUNCTION, message)

    def getTaskInfos(self):
        """
        .. note:: Experimental

        Returns :class:`BarrierTaskInfo` for all tasks in this barrier stage,
        ordered by partition ID.
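
        A sketch of collecting the executor addresses of all tasks in the stage
        (illustrative; must run inside a barrier stage)::

            context = BarrierTaskContext.get()
            addresses = [info.address for info in context.getTaskInfos()]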

        .. versionadded:: 2.4.0
        """
        if self._port is None or self._secret is None:
            raise Exception("Not supported to call getTaskInfos() before initializing " +
                            "BarrierTaskContext.")
        else:
            addresses = self._localProperties.get("addresses", "")
            return [BarrierTaskInfo(h.strip()) for h in addresses.split(",")]


class BarrierTaskInfo(object):
    """
    .. note:: Experimental

    Carries all task infos of a barrier task.

    :var address: The IPv4 address (host:port) of the executor that the barrier task is running on

    .. versionadded:: 2.4.0
    """

    def __init__(self, address):
        self.address = address