Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 """
0019 Worker that receives input from Piped RDD.
0020 """
0021 from __future__ import print_function
0022 from __future__ import absolute_import
0023 import os
0024 import sys
0025 import time
0026 # 'resource' is a Unix specific module.
0027 has_resource_module = True
0028 try:
0029     import resource
0030 except ImportError:
0031     has_resource_module = False
0032 import traceback
0033 
0034 from pyspark.accumulators import _accumulatorRegistry
0035 from pyspark.broadcast import Broadcast, _broadcastRegistry
0036 from pyspark.java_gateway import local_connect_and_auth
0037 from pyspark.taskcontext import BarrierTaskContext, TaskContext
0038 from pyspark.files import SparkFiles
0039 from pyspark.resource import ResourceInformation
0040 from pyspark.rdd import PythonEvalType
0041 from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
0042     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
0043     BatchedSerializer
0044 from pyspark.sql.pandas.serializers import ArrowStreamPandasUDFSerializer, CogroupUDFSerializer
0045 from pyspark.sql.pandas.types import to_arrow_type
0046 from pyspark.sql.types import StructType
0047 from pyspark.util import _get_argspec, fail_on_stopiteration
0048 from pyspark import shuffle
0049 
# Python 2/3 compatibility: on Python 3 alias ``basestring`` to ``str``; on Python 2
# shadow the builtin ``map`` with the lazy ``itertools.imap`` so that ``map`` returns an
# iterator on both versions.
if sys.version >= '3':
    basestring = str
else:
    from itertools import imap as map  # use iterator map by default

# Module-level serializer instances used by ``main`` and the UDF readers below.
pickleSer = PickleSerializer()
utf8_deserializer = UTF8Deserializer()
0057 
0058 
def report_times(outfile, boot, init, finish):
    """Write the worker's timing data to ``outfile``.

    A ``SpecialLengths.TIMING_DATA`` marker is written first, followed by the boot,
    init and finish timestamps, each converted from seconds to whole milliseconds
    and written as a long.
    """
    write_int(SpecialLengths.TIMING_DATA, outfile)
    for seconds in (boot, init, finish):
        write_long(int(1000 * seconds), outfile)
0064 
0065 
def add_path(path):
    """Insert ``path`` at position 1 of ``sys.path`` unless it is already present.

    Index 1 (right after ``sys.path[0]``) makes files shipped with the job take
    precedence over installed system packages. The membership check keeps a reused
    worker from adding the same entry more than once.
    """
    if path in sys.path:
        return
    sys.path.insert(1, path)
0071 
0072 
def read_command(serializer, file):
    """Read one length-prefixed command from ``file`` using ``serializer``.

    If the object read is a ``Broadcast``, its ``value`` holds the serialized command,
    which is deserialized with ``serializer.loads`` and returned instead.
    """
    command = serializer._read_with_length(file)
    if not isinstance(command, Broadcast):
        return command
    return serializer.loads(command.value)
0078 
0079 
def chain(f, g):
    """Compose two functions: call ``f`` with all arguments, then ``g`` on its result."""
    def composed(*args):
        return g(f(*args))
    return composed
0083 
0084 
def wrap_udf(f, return_type):
    """Wrap a row-at-a-time UDF so its results match ``return_type``'s internal form.

    When ``return_type.needConversion()`` is true, each result of ``f`` is passed
    through ``return_type.toInternal``; otherwise results are returned unchanged.
    """
    if not return_type.needConversion():
        return lambda *a: f(*a)
    to_internal = return_type.toInternal
    return lambda *a: to_internal(f(*a))
0091 
0092 
def wrap_scalar_pandas_udf(f, return_type):
    """Wrap a scalar pandas UDF with result validation.

    The returned callable evaluates ``f`` on the input columns and checks two things.
    The result must be sized (have ``__len__``). It must also have as many entries as
    the first input column. It then returns ``(result, arrow_type)``.

    :raises TypeError: if the result has no ``__len__``.
    :raises RuntimeError: if the result length differs from the first input's length.
    """
    arrow_return_type = to_arrow_type(return_type)

    def checked(*columns):
        result = f(*columns)
        if not hasattr(result, "__len__"):
            if type(return_type) == StructType:
                pd_type = "Pandas.DataFrame"
            else:
                pd_type = "Pandas.Series"
            raise TypeError("Return type of the user-defined function should be "
                            "{}, but is {}".format(pd_type, type(result)))
        # The first argument's length is taken only after the type check, as before.
        expected = len(columns[0])
        if len(result) != expected:
            raise RuntimeError("Result vector from pandas_udf was not the required length: "
                               "expected %d, got %d" % (expected, len(result)))
        return result, arrow_return_type

    return checked
0111 
0112 
def wrap_pandas_iter_udf(f, return_type):
    """Wrap an iterator-style pandas UDF.

    The returned callable calls ``f`` immediately with its arguments. It returns a
    lazy iterator over ``f``'s results that checks each result has ``__len__`` and
    pairs it with the Arrow type for ``return_type``.

    :raises TypeError: (while iterating) for a result without ``__len__``.
    """
    arrow_return_type = to_arrow_type(return_type)

    def checked(result):
        if hasattr(result, "__len__"):
            return result, arrow_return_type
        if type(return_type) == StructType:
            pd_type = "Pandas.DataFrame"
        else:
            pd_type = "Pandas.Series"
        raise TypeError("Return type of the user-defined function should be "
                        "{}, but is {}".format(pd_type, type(result)))

    def wrapped(*iterator):
        # ``f(*iterator)`` is evaluated (and iterated via iter()) right away, exactly
        # like the eager first argument of ``map``; only the per-item checks are lazy.
        return (checked(result) for result in f(*iterator))

    return wrapped
0125 
0126 
def wrap_cogrouped_map_pandas_udf(f, return_type, argspec):
    """Wrap a cogrouped-map pandas UDF.

    The returned callable takes the key and value series of a left and a right group.
    Each side's value series are concatenated column-wise into a DataFrame. ``f`` is
    then called as ``f(left_df, right_df)`` or ``f(key, left_df, right_df)``,
    depending on how many positional arguments ``argspec`` declares. ``key`` is a tuple
    of ``s[0]`` for each left key series, or each right key series when the left
    DataFrame is empty. The validated result is returned as a one-element list of
    ``(DataFrame, arrow_type)``.

    :param f: the user function.
    :param return_type: declared result schema; ``len(return_type)`` is the expected
        number of result columns.
    :param argspec: argument spec of ``f``; the length of ``argspec.args`` selects the
        calling convention.
    :raises TypeError: if ``f`` does not return a ``pandas.DataFrame``.
    :raises RuntimeError: if the number of result columns differs from ``len(return_type)``.
    :raises ValueError: if ``argspec`` declares neither two nor three positional arguments.
    """
    # The Arrow type depends only on the declared schema, so compute it once here
    # instead of once per cogroup.
    arrow_return_type = to_arrow_type(return_type)
    num_args = len(argspec.args)

    def wrapped(left_key_series, left_value_series, right_key_series, right_value_series):
        import pandas as pd

        left_df = pd.concat(left_value_series, axis=1)
        right_df = pd.concat(right_value_series, axis=1)

        if num_args == 2:
            result = f(left_df, right_df)
        elif num_args == 3:
            # Take the key from whichever side actually has rows.
            key_series = left_key_series if not left_df.empty else right_key_series
            key = tuple(s[0] for s in key_series)
            result = f(key, left_df, right_df)
        else:
            # Without this branch ``result`` would be unbound below and the task would
            # fail with an unhelpful UnboundLocalError.
            raise ValueError(
                "Cogrouped map pandas UDF must take either two arguments (left, right) "
                "or three arguments (key, left, right), but takes {}".format(num_args))
        if not isinstance(result, pd.DataFrame):
            raise TypeError("Return type of the user-defined function should be "
                            "pandas.DataFrame, but is {}".format(type(result)))
        if not len(result.columns) == len(return_type):
            raise RuntimeError(
                "Number of columns of the returned pandas.DataFrame "
                "doesn't match specified schema. "
                "Expected: {} Actual: {}".format(len(return_type), len(result.columns)))
        return result

    return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), arrow_return_type)]
0152 
0153 
def wrap_grouped_map_pandas_udf(f, return_type, argspec):
    """Wrap a grouped-map pandas UDF.

    The returned callable takes the grouping-key series and the value series of one
    group. The value series are concatenated column-wise into a DataFrame, and ``f``
    is called as ``f(df)`` or ``f(key, df)`` depending on how many positional arguments
    ``argspec`` declares. ``key`` is a tuple of ``s[0]`` for each key series. The
    validated result is returned as a one-element list of ``(DataFrame, arrow_type)``.

    :param f: the user function.
    :param return_type: declared result schema; ``len(return_type)`` is the expected
        number of result columns.
    :param argspec: argument spec of ``f``; the length of ``argspec.args`` selects the
        calling convention.
    :raises TypeError: if ``f`` does not return a ``pandas.DataFrame``.
    :raises RuntimeError: if the number of result columns differs from ``len(return_type)``.
    :raises ValueError: if ``argspec`` declares neither one nor two positional arguments.
    """
    # The Arrow type depends only on the declared schema, so compute it once here
    # instead of once per group.
    arrow_return_type = to_arrow_type(return_type)
    num_args = len(argspec.args)

    def wrapped(key_series, value_series):
        import pandas as pd

        if num_args == 1:
            result = f(pd.concat(value_series, axis=1))
        elif num_args == 2:
            key = tuple(s[0] for s in key_series)
            result = f(key, pd.concat(value_series, axis=1))
        else:
            # Without this branch ``result`` would be unbound below and the task would
            # fail with an unhelpful UnboundLocalError.
            raise ValueError(
                "Grouped map pandas UDF must take either one argument (data) or two "
                "arguments (key, data), but takes {}".format(num_args))

        if not isinstance(result, pd.DataFrame):
            raise TypeError("Return type of the user-defined function should be "
                            "pandas.DataFrame, but is {}".format(type(result)))
        if not len(result.columns) == len(return_type):
            raise RuntimeError(
                "Number of columns of the returned pandas.DataFrame "
                "doesn't match specified schema. "
                "Expected: {} Actual: {}".format(len(return_type), len(result.columns)))
        return result

    return lambda k, v: [(wrapped(k, v), arrow_return_type)]
0176 
0177 
def wrap_grouped_agg_pandas_udf(f, return_type):
    """Wrap a grouped-aggregate pandas UDF.

    The returned callable calls ``f`` on the input series and wraps its result in a
    one-element ``pandas.Series``, returned as ``(series, arrow_type)``.
    """
    arrow_return_type = to_arrow_type(return_type)

    def aggregate(*series):
        import pandas as pd
        return pd.Series([f(*series)]), arrow_return_type

    return aggregate
0187 
0188 
def wrap_window_agg_pandas_udf(f, return_type, runner_conf, udf_index):
    """Pick the bounded or unbounded window-aggregate wrapper for one UDF.

    ``runner_conf['pandas_window_bound_types']`` is a comma-separated list. Its entry
    at ``udf_index`` is stripped and lower-cased and must be 'bounded' or 'unbounded'.

    :raises RuntimeError: if the selected entry is neither 'bounded' nor 'unbounded'.
    """
    bound_types = runner_conf.get('pandas_window_bound_types').split(',')
    bound_type = bound_types[udf_index].strip().lower()
    if bound_type not in ('bounded', 'unbounded'):
        raise RuntimeError("Invalid window bound type: {} ".format(bound_type))
    if bound_type == 'bounded':
        wrapper = wrap_bounded_window_agg_pandas_udf
    else:
        wrapper = wrap_unbounded_window_agg_pandas_udf
    return wrapper(f, return_type)
0198 
0199 
def wrap_unbounded_window_agg_pandas_udf(f, return_type):
    """Wrap a window-aggregate pandas UDF evaluated over an unbounded window.

    Like the grouped-aggregate wrapper, ``f`` is called once on the input series.
    Unlike it, the single result is repeated ``len(series[0])`` times so there is one
    value per input row. Returns ``(series, arrow_type)``.
    """
    arrow_return_type = to_arrow_type(return_type)

    def repeat_over_window(*series):
        import pandas as pd
        aggregate = f(*series)
        window_length = len(series[0])
        return pd.Series([aggregate]).repeat(window_length), arrow_return_type

    return repeat_over_window
0213 
0214 
def wrap_bounded_window_agg_pandas_udf(f, return_type):
    """Wrap a window-aggregate pandas UDF evaluated over bounded windows.

    The returned callable takes two series of per-row window boundaries followed by
    the input series. For each row ``i`` it calls ``f`` on
    ``s.iloc[begin[i]:end[i]]`` of every input series. The results are collected
    into a ``pandas.Series`` and returned as ``(series, arrow_type)``.
    """
    arrow_return_type = to_arrow_type(return_type)

    def evaluate_windows(begin_index, end_index, *series):
        import pandas as pd

        # Positional indexing is cheaper on the underlying numpy arrays than on pandas
        # Series, so take the boundary arrays once, up front.
        begins = begin_index.values
        ends = end_index.values

        window_results = []
        for row, begin in enumerate(begins):
            # Cost notes:
            # - Building a slice of every input series per window is expensive, but
            #   there is no easy way to avoid it.
            # - s.iloc[i:j] is about 30% faster than s[i:j]. The slices share memory
            #   with s, so the window function must not modify its input series. That
            #   is rarely needed, so it is a reasonable restriction.
            # - Calling reset_index on each slice would roughly double the slicing
            #   cost, so it is deliberately not done.
            window = [s.iloc[begin: ends[row]] for s in series]
            window_results.append(f(*window))
        return pd.Series(window_results), arrow_return_type

    return evaluate_windows
0247 
0248 
def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
    """Read one UDF, possibly a chain of functions, and wrap it for ``eval_type``.

    Reads, in order:
    - the number of argument offsets;
    - the offsets themselves;
    - the number of chained functions;
    - that many ``(function, return_type)`` commands.

    The functions are composed left to right with ``chain``, and the return type of
    the last one is used for the whole chain.

    :returns: ``(arg_offsets, wrapped_udf)``.
    :raises ValueError: if ``eval_type`` is not a supported UDF eval type.
    """
    offset_count = read_int(infile)
    arg_offsets = [read_int(infile) for _ in range(offset_count)]

    chained_func = None
    for _ in range(read_int(infile)):
        f, return_type = read_command(pickleSer, infile)
        chained_func = f if chained_func is None else chain(chained_func, f)

    # Scalar iterator UDFs are used as-is. Every other UDF is wrapped so that a
    # StopIteration raised in user code surfaces as a RuntimeError instead of being
    # swallowed by an enclosing for loop.
    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
        func = chained_func
    else:
        func = fail_on_stopiteration(chained_func)

    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
        wrapped = wrap_scalar_pandas_udf(func, return_type)
    elif eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                       PythonEvalType.SQL_MAP_PANDAS_ITER_UDF):
        wrapped = wrap_pandas_iter_udf(func, return_type)
    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        # Inspect the unwrapped chain: the stop-iteration wrapper loses the signature.
        argspec = _get_argspec(chained_func)
        wrapped = wrap_grouped_map_pandas_udf(func, return_type, argspec)
    elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
        argspec = _get_argspec(chained_func)
        wrapped = wrap_cogrouped_map_pandas_udf(func, return_type, argspec)
    elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
        wrapped = wrap_grouped_agg_pandas_udf(func, return_type)
    elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
        wrapped = wrap_window_agg_pandas_udf(func, return_type, runner_conf, udf_index)
    elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
        wrapped = wrap_udf(func, return_type)
    else:
        raise ValueError("Unknown eval type: {}".format(eval_type))
    return arg_offsets, wrapped
0288 
0289 
def read_udfs(pickleSer, infile, eval_type):
    """Read the UDF(s) for ``eval_type`` from ``infile`` and build the function to run.

    For the pandas/Arrow eval types in the tuple below, a runner configuration is
    read first: a count, then that many UTF-8 key/value pairs. It configures an
    Arrow-based serializer. Every other eval type uses
    ``BatchedSerializer(PickleSerializer(), 100)``. Next the number of UDFs is read,
    followed by the UDFs themselves (see ``read_single_udf``).

    :param pickleSer: serializer used to read the pickled UDF commands.
    :param infile: binary file-like object to read from.
    :param eval_type: a ``PythonEvalType`` value.
    :returns: ``(func, profiler, deserializer, serializer)``:
        - ``func`` takes ``(split_index, iterator)``;
        - ``profiler`` is always ``None``;
        - the same serializer object is returned for both directions.
    """
    runner_conf = {}

    if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                     PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
                     PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                     PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
                     PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
                     PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
                     PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF):

        # Load conf used for pandas_udf evaluation
        num_conf = read_int(infile)
        for i in range(num_conf):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            runner_conf[k] = v

        # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
        timezone = runner_conf.get("spark.sql.session.timeZone", None)
        safecheck = runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely",
                                    "false").lower() == 'true'
        # Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType
        assign_cols_by_name = runner_conf.get(
            "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\
            .lower() == "true"

        if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
            ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name)
        else:
            # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
            # pandas Series. See SPARK-27240.
            df_for_struct = (eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF or
                             eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF or
                             eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF)
            ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name,
                                                 df_for_struct)
    else:
        ser = BatchedSerializer(PickleSerializer(), 100)

    num_udfs = read_int(infile)

    is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
    is_map_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF

    # Iterator-style UDFs receive the whole stream of batches at once instead of being
    # mapped over individual inputs.
    if is_scalar_iter or is_map_iter:
        # NOTE(review): these checks use ``assert`` and are skipped under ``python -O``.
        if is_scalar_iter:
            assert num_udfs == 1, "One SCALAR_ITER UDF expected here."
        if is_map_iter:
            assert num_udfs == 1, "One MAP_ITER UDF expected here."

        arg_offsets, udf = read_single_udf(
            pickleSer, infile, eval_type, runner_conf, udf_index=0)

        def func(_, iterator):
            # Running total of input rows, kept in a one-element list so that the
            # nested map_batch can update it.
            num_input_rows = [0]  # TODO(SPARK-29909): Use nonlocal after we drop Python 2.

            # Select the UDF's argument columns from a batch and count its rows.
            def map_batch(batch):
                udf_args = [batch[offset] for offset in arg_offsets]
                num_input_rows[0] += len(udf_args[0])
                if len(udf_args) == 1:
                    return udf_args[0]
                else:
                    return tuple(udf_args)

            iterator = map(map_batch, iterator)
            result_iter = udf(iterator)

            num_output_rows = 0
            for result_batch, result_type in result_iter:
                num_output_rows += len(result_batch)
                # This assert is for Scalar Iterator UDF to fail fast.
                # The length of the entire input can only be explicitly known
                # by consuming the input iterator in user side. Therefore,
                # it's very unlikely the output length is higher than
                # input length.
                assert is_map_iter or num_output_rows <= num_input_rows[0], \
                    "Pandas SCALAR_ITER UDF outputted more rows than input rows."
                yield (result_batch, result_type)

            # A scalar iterator UDF must consume all of its input and produce exactly
            # one output row per input row.
            if is_scalar_iter:
                try:
                    next(iterator)
                except StopIteration:
                    pass
                else:
                    raise RuntimeError("pandas iterator UDF should exhaust the input "
                                       "iterator.")

                if num_output_rows != num_input_rows[0]:
                    raise RuntimeError(
                        "The length of output in Scalar iterator pandas UDF should be "
                        "the same with the input's; however, the length of output was %d and the "
                        "length of input was %d." % (num_output_rows, num_input_rows[0]))

        # profiling is not supported for UDF
        return func, None, ser, ser

    def extract_key_value_indexes(grouped_arg_offsets):
        """
        Helper function to extract the key and value indexes from arg_offsets for the grouped and
        cogrouped pandas udfs. See BasePandasGroupExec.resolveArgOffsets for equivalent scala code.

        :param grouped_arg_offsets:  List containing the key and value indexes of columns of the
            DataFrames to be passed to the udf. It consists of n consecutive groups where n is
            the number of DataFrames. Each group has the following format (slices are
            Python-style, end-exclusive):
                group[0]: number of entries that follow in this group
                group[1]: number of key indexes
                group[2: group[1] + 2]: key indexes
                group[group[1] + 2: group[0] + 1]: value indexes
        :return: a list with one ``[key_indexes, value_indexes]`` pair per group.
        """
        parsed = []
        idx = 0
        while idx < len(grouped_arg_offsets):
            offsets_len = grouped_arg_offsets[idx]
            idx += 1
            offsets = grouped_arg_offsets[idx: idx + offsets_len]
            split_index = offsets[0] + 1
            offset_keys = offsets[1: split_index]
            offset_values = offsets[split_index:]
            parsed.append([offset_keys, offset_values])
            idx += offsets_len
        return parsed

    if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        # We assume there is only one UDF here because grouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1

        # See FlatMapGroupsInPandasExec for how arg_offsets are used to
        # distinguish between grouping attributes and data attributes
        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)
        parsed_offsets = extract_key_value_indexes(arg_offsets)

        # Create function like this:
        #   mapper a: f([a[0]], [a[0], a[1]])
        def mapper(a):
            keys = [a[o] for o in parsed_offsets[0][0]]
            vals = [a[o] for o in parsed_offsets[0][1]]
            return f(keys, vals)
    elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
        # We assume there is only one UDF here because cogrouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1
        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)

        parsed_offsets = extract_key_value_indexes(arg_offsets)

        # ``a`` is a pair: a[0] holds the left group's columns, a[1] the right group's.
        def mapper(a):
            df1_keys = [a[0][o] for o in parsed_offsets[0][0]]
            df1_vals = [a[0][o] for o in parsed_offsets[0][1]]
            df2_keys = [a[1][o] for o in parsed_offsets[1][0]]
            df2_vals = [a[1][o] for o in parsed_offsets[1][1]]
            return f(df1_keys, df1_vals, df2_keys, df2_vals)
    else:
        udfs = []
        for i in range(num_udfs):
            udfs.append(read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=i))

        # Evaluate every UDF on its own argument columns of the same input.
        def mapper(a):
            result = tuple(f(*[a[o] for o in arg_offsets]) for (arg_offsets, f) in udfs)
            # In the special case of a single UDF this will return a single result rather
            # than a tuple of results; this is the format that the JVM side expects.
            if len(result) == 1:
                return result[0]
            else:
                return result

    func = lambda _, it: map(mapper, it)

    # profiling is not supported for UDF
    return func, None, ser, ser
0462 
0463 
def main(infile, outfile):
    """Run one task in this Python worker process.

    Reads the task setup from ``infile``:
    - split index and Python version;
    - barrier info and optional memory limit;
    - TaskContext fields, resources and local properties;
    - Spark files directory and Python includes;
    - broadcast variables;
    - the command or UDFs.

    It then runs the function over the input stream. Results, timing data, shuffle
    spill metrics and accumulator updates are written to ``outfile``. Any
    ``Exception`` raised during setup or processing is reported as
    ``SpecialLengths.PYTHON_EXCEPTION_THROWN`` plus the formatted traceback, after
    which the process exits with status -1. It also exits with -1 if the stream does
    not end with ``END_OF_STREAM``.

    :param infile: binary file-like object the task is read from.
    :param outfile: binary file-like object the results are written to.
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        version = utf8_deserializer.loads(infile)
        if version != "%d.%d" % sys.version_info[:2]:
            raise Exception(("Python in worker has different version %s than that in " +
                             "driver %s, PySpark cannot run with different minor versions. " +
                             "Please check environment variables PYSPARK_PYTHON and " +
                             "PYSPARK_DRIVER_PYTHON are correctly set.") %
                            ("%d.%d" % sys.version_info[:2], version))

        # read inputs only for a barrier task
        isBarrier = read_bool(infile)
        boundPort = read_int(infile)
        secret = UTF8Deserializer().loads(infile)

        # set up memory limits
        memory_limit_mb = int(os.environ.get('PYSPARK_EXECUTOR_MEMORY_MB', "-1"))
        if memory_limit_mb > 0 and has_resource_module:
            total_memory = resource.RLIMIT_AS
            try:
                (soft_limit, hard_limit) = resource.getrlimit(total_memory)
                msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit)
                print(msg, file=sys.stderr)

                # convert to bytes
                new_limit = memory_limit_mb * 1024 * 1024

                # Only tighten the limit: apply it when there is no soft limit yet or
                # when the requested limit is lower than the current one.
                if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit:
                    msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit)
                    print(msg, file=sys.stderr)
                    resource.setrlimit(total_memory, (new_limit, new_limit))

            except (resource.error, OSError, ValueError) as e:
                # not all systems support resource limits, so warn instead of failing
                print("WARN: Failed to set memory limit: {0}\n".format(e), file=sys.stderr)

        # initialize global state
        taskContext = None
        if isBarrier:
            taskContext = BarrierTaskContext._getOrCreate()
            BarrierTaskContext._initialize(boundPort, secret)
            # Set the task context instance here, so we can get it by TaskContext.get for
            # both TaskContext and BarrierTaskContext
            TaskContext._setTaskContext(taskContext)
        else:
            taskContext = TaskContext._getOrCreate()
        # read inputs for TaskContext info
        taskContext._stageId = read_int(infile)
        taskContext._partitionId = read_int(infile)
        taskContext._attemptNumber = read_int(infile)
        taskContext._taskAttemptId = read_long(infile)
        # Resources arrive as a count, then for each resource: its key, its name and a
        # counted list of addresses. The dict is created once, before the loop. Resetting
        # it inside the loop (as was done before) kept only the last resource.
        taskContext._resources = {}
        for _ in range(read_int(infile)):
            key = utf8_deserializer.loads(infile)
            name = utf8_deserializer.loads(infile)
            addresses = [utf8_deserializer.loads(infile) for _ in range(read_int(infile))]
            taskContext._resources[key] = ResourceInformation(name, addresses)

        taskContext._localProperties = dict()
        for _ in range(read_int(infile)):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            taskContext._localProperties[k] = v

        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))
        if sys.version > '3':
            import importlib
            importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        needs_broadcast_decryption_server = read_bool(infile)
        num_broadcast_variables = read_int(infile)
        if needs_broadcast_decryption_server:
            # read the decrypted data from a server in the jvm
            port = read_int(infile)
            auth_secret = utf8_deserializer.loads(infile)
            (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)

        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                if needs_broadcast_decryption_server:
                    read_bid = read_long(broadcast_sock_file)
                    assert(read_bid == bid)
                    _broadcastRegistry[bid] = \
                        Broadcast(sock_file=broadcast_sock_file)
                else:
                    path = utf8_deserializer.loads(infile)
                    _broadcastRegistry[bid] = Broadcast(path=path)

            else:
                # A negative id -(bid + 1) removes broadcast ``bid`` from the registry.
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        if needs_broadcast_decryption_server:
            broadcast_sock_file.write(b'1')
            broadcast_sock_file.close()

        _accumulatorRegistry.clear()
        eval_type = read_int(infile)
        if eval_type == PythonEvalType.NON_UDF:
            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
        else:
            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)

        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            out_iter = func(split_index, iterator)
            try:
                serializer.dump_stream(out_iter, outfile)
            finally:
                if hasattr(out_iter, 'close'):
                    out_iter.close()

        if profiler:
            profiler.profile(process)
        else:
            process()

        # Reset task context to None. This is a guard code to avoid residual context when worker
        # reuse.
        TaskContext._setTaskContext(None)
        BarrierTaskContext._setTaskContext(None)
    except Exception:
        try:
            exc_info = traceback.format_exc()
            if isinstance(exc_info, bytes):
                # exc_info may contains other encoding bytes, replace the invalid bytes and convert
                # it back to utf-8 again
                exc_info = exc_info.decode("utf-8", "replace").encode("utf-8")
            else:
                exc_info = exc_info.encode("utf-8")
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(exc_info, outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
0648 
0649 
if __name__ == '__main__':
    # Connection details for talking back to the JVM (port and auth secret) come from
    # environment variables; presumably set by the JVM side that launched this worker.
    # The same socket file serves as both the input and the output stream.
    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
    sock_file, _ = local_connect_and_auth(java_port, auth_secret)
    main(sock_file, sock_file)