Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 from __future__ import print_function
0019 import json
0020 
0021 from pyspark.java_gateway import local_connect_and_auth
0022 from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer
0023 
0024 
class TaskContext(object):

    """
    Contextual information about a task which can be read or mutated during
    execution. To access the TaskContext for a running task, use:
    :meth:`TaskContext.get`.
    """

    # Singleton instance shared by every TaskContext() construction; None until the
    # first construction, _getOrCreate() or _setTaskContext() call.
    _taskContext = None

    # Per-task fields. They default to None here; nothing in this class assigns them,
    # so they are expected to be filled in by the code that creates the context.
    _attemptNumber = None
    _partitionId = None
    _stageId = None
    _taskAttemptId = None
    _localProperties = None
    _resources = None

    def __new__(cls):
        """Even if users construct TaskContext instead of using get, give them the singleton."""
        taskContext = cls._taskContext
        if taskContext is not None:
            return taskContext
        cls._taskContext = taskContext = object.__new__(cls)
        return taskContext

    @classmethod
    def _getOrCreate(cls):
        """Internal function to get or create global TaskContext."""
        if cls._taskContext is None:
            cls._taskContext = TaskContext()
        return cls._taskContext

    @classmethod
    def _setTaskContext(cls, taskContext):
        """Internal function to replace the global TaskContext with ``taskContext``."""
        cls._taskContext = taskContext

    @classmethod
    def get(cls):
        """
        Return the currently active TaskContext. This can be called inside of
        user functions to access contextual information about running tasks.

        .. note:: Must be called on the worker, not the driver. Returns None if not initialized.
        """
        return cls._taskContext

    def stageId(self):
        """The ID of the stage that this task belongs to."""
        return self._stageId

    def partitionId(self):
        """
        The ID of the RDD partition that is computed by this task.
        """
        return self._partitionId

    def attemptNumber(self):
        """
        How many times this task has been attempted.  The first task attempt will be assigned
        attemptNumber = 0, and subsequent attempts will have increasing attempt numbers.
        """
        return self._attemptNumber

    def taskAttemptId(self):
        """
        An ID that is unique to this task attempt (within the same SparkContext, no two task
        attempts will share the same attempt ID).  This is roughly equivalent to Hadoop's
        TaskAttemptID.
        """
        return self._taskAttemptId

    def getLocalProperty(self, key):
        """
        Get a local property set upstream in the driver, or None if it is missing.

        Also returns None when no local properties have been set on this context yet,
        instead of failing with an AttributeError on ``None``.
        """
        properties = self._localProperties
        if properties is None:
            return None
        return properties.get(key, None)

    def resources(self):
        """
        Resources allocated to the task. The key is the resource name and the value is information
        about the resource.
        """
        return self._resources
0108 
0109 
# Integer codes that _load_from_socket writes to the socket to select the operation.
BARRIER_FUNCTION, ALL_GATHER_FUNCTION = 1, 2
0112 
0113 
def _load_from_socket(port, auth_secret, function, all_gather_message=None):
    """
    Load data from a given socket, this is a blocking method thus only return when the socket
    connection has been closed.

    :param port: local port to connect to via ``local_connect_and_auth``.
    :param auth_secret: secret passed to ``local_connect_and_auth`` to authenticate.
    :param function: ``BARRIER_FUNCTION`` or ``ALL_GATHER_FUNCTION``.
    :param all_gather_message: string sent (UTF-8 encoded) for ``ALL_GATHER_FUNCTION``;
        ignored for ``BARRIER_FUNCTION``.
    :return: list of strings read back with ``UTF8Deserializer``.
    :raises ValueError: if ``function`` is not a recognized function type.
    """
    # Validate before connecting so an invalid call never opens (and leaks) a socket.
    if function not in (BARRIER_FUNCTION, ALL_GATHER_FUNCTION):
        raise ValueError("Unrecognized function type")

    (sockfile, sock) = local_connect_and_auth(port, auth_secret)
    try:
        # The call may block forever, so no timeout
        sock.settimeout(None)

        # Both functions start with the function code; all_gather() also sends its message.
        write_int(function, sockfile)
        if function == ALL_GATHER_FUNCTION:
            write_with_length(all_gather_message.encode("utf-8"), sockfile)
        sockfile.flush()

        # Collect result: a count, followed by that many UTF8Deserializer-encoded strings.
        num_results = read_int(sockfile)
        deserializer = UTF8Deserializer()
        return [deserializer.loads(sockfile) for _ in range(num_results)]
    finally:
        # Release resources even if the exchange fails part-way.
        sockfile.close()
        sock.close()
0146 
0147 
class BarrierTaskContext(TaskContext):

    """
    .. note:: Experimental

    A :class:`TaskContext` with extra contextual info and tooling for tasks in a barrier stage.
    Use :func:`BarrierTaskContext.get` to obtain the barrier context for a running barrier task.

    .. versionadded:: 2.4.0
    """

    # Connection details used by barrier() and allGather(); set by _initialize().
    _port = None
    _secret = None

    @classmethod
    def _getOrCreate(cls):
        """
        Internal function to get or create global BarrierTaskContext. We need to make sure
        BarrierTaskContext is returned from here because it is needed in python worker reuse
        scenario, see SPARK-25921 for more details.
        """
        if not isinstance(cls._taskContext, BarrierTaskContext):
            # Bypass the inherited singleton __new__, which would hand back an already
            # cached plain TaskContext instead of creating a BarrierTaskContext.
            cls._taskContext = object.__new__(cls)
        return cls._taskContext

    @classmethod
    def get(cls):
        """
        .. note:: Experimental

        Return the currently active :class:`BarrierTaskContext`.
        This can be called inside of user functions to access contextual information about
        running tasks.

        .. note:: Must be called on the worker, not the driver. An Exception is raised if it is
            not in a barrier stage, which includes the case where no task context has been
            initialized yet.
        """
        if not isinstance(cls._taskContext, BarrierTaskContext):
            raise Exception('It is not in a barrier stage')
        return cls._taskContext

    @classmethod
    def _initialize(cls, port, secret):
        """
        Initialize BarrierTaskContext, other methods within BarrierTaskContext can only be called
        after BarrierTaskContext is initialized.
        """
        cls._port = port
        cls._secret = secret

    def _checkInitialized(self, method):
        """
        Raise if _initialize() has not yet supplied the port and secret that ``method`` needs.

        :param method: name of the public method being called, used in the error message.
        """
        if self._port is None or self._secret is None:
            raise Exception("Not supported to call %s() before initialize BarrierTaskContext."
                            % method)

    def barrier(self):
        """
        .. note:: Experimental

        Sets a global barrier and waits until all tasks in this stage hit this barrier.
        Similar to `MPI_Barrier` function in MPI, this function blocks until all tasks
        in the same stage have reached this routine.

        .. warning:: In a barrier stage, each task must have the same number of `barrier()`
            calls, in all possible code branches.
            Otherwise, you may get the job hanging or a SparkException after timeout.

        .. versionadded:: 2.4.0
        """
        self._checkInitialized("barrier")
        _load_from_socket(self._port, self._secret, BARRIER_FUNCTION)

    def allGather(self, message=""):
        """
        .. note:: Experimental

        This function blocks until all tasks in the same stage have reached this routine.
        Each task passes in a message and returns with a list of all the messages passed in
        by each of those tasks.

        .. warning:: In a barrier stage, each task must have the same number of `allGather()`
            calls, in all possible code branches.
            Otherwise, you may get the job hanging or a SparkException after timeout.

        .. versionadded:: 3.0.0
        """
        if not isinstance(message, str):
            raise ValueError("Argument `message` must be of type `str`")
        self._checkInitialized("allGather")
        return _load_from_socket(self._port, self._secret, ALL_GATHER_FUNCTION, message)

    def getTaskInfos(self):
        """
        .. note:: Experimental

        Returns :class:`BarrierTaskInfo` for all tasks in this barrier stage,
        ordered by partition ID.

        Returns an empty list when the "addresses" local property is missing or empty.

        .. versionadded:: 2.4.0
        """
        self._checkInitialized("getTaskInfos")
        addresses = self._localProperties.get("addresses", "")
        if not addresses:
            # "".split(",") yields [""], which would produce a bogus empty-address entry.
            return []
        return [BarrierTaskInfo(h.strip()) for h in addresses.split(",")]
0255 
0256 
class BarrierTaskInfo(object):
    """
    .. note:: Experimental

    Per-task information for a single task of a barrier stage.

    :var address: The IPv4 address (host:port) of the executor that the barrier task is running on

    .. versionadded:: 2.4.0
    """

    def __init__(self, address):
        super(BarrierTaskInfo, self).__init__()
        # Stored verbatim; no parsing or validation happens here.
        self.address = address