0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import atexit
0019 import os
0020 import sys
0021 import signal
0022 import shlex
0023 import shutil
0024 import socket
0025 import platform
0026 import tempfile
0027 import time
0028 from subprocess import Popen, PIPE
0029
# Alias the Python 2 name ``xrange`` to ``range`` on Python 3. Compare the numeric
# major version rather than the ``sys.version`` string: a lexicographic string
# comparison would misorder a hypothetical two-digit major version ("10" < "3").
if sys.version_info[0] >= 3:
    xrange = range
0032
0033 from py4j.java_gateway import java_import, JavaGateway, JavaObject, GatewayParameters
0034 from py4j.clientserver import ClientServer, JavaParameters, PythonParameters
0035 from pyspark.find_spark_home import _find_spark_home
0036 from pyspark.serializers import read_int, write_with_length, UTF8Deserializer
0037 from pyspark.util import _exception_message
0038
0039
def launch_gateway(conf=None, popen_kwargs=None):
    """
    Launch the JVM gateway (or connect to an already-running one) and return a py4j gateway.

    If ``PYSPARK_GATEWAY_PORT`` is set in the environment, no process is spawned: the
    gateway connects to that port using ``PYSPARK_GATEWAY_SECRET`` as the auth token.
    Otherwise ``bin/spark-submit`` (``bin/spark-submit.cmd`` on Windows) under SPARK_HOME
    is spawned. This function then waits until the child writes its port and secret to a
    temporary connection-info file, or until the child exits.

    :param conf: spark configuration passed to spark-submit as ``--conf k=v`` pairs
    :param popen_kwargs: Dictionary of kwargs to pass to Popen when spawning
        the py4j JVM. This is a developer feature intended for use in
        customizing how pyspark interacts with the py4j JVM (e.g., capturing
        stdout/stderr). The dictionary is copied, not modified. ``stdin`` and
        ``env`` (and ``preexec_fn`` on non-Windows) are always overridden.
    :return: a py4j ``ClientServer`` when ``PYSPARK_PIN_THREAD`` is "true", otherwise a
        ``JavaGateway``. The spawned ``Popen`` object (or None if no process was
        spawned) is attached as ``gateway.proc``.
    :raises Exception: if the spawned process exits before writing its connection info.
    """
    if "PYSPARK_GATEWAY_PORT" in os.environ:
        gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
        gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
        # A gateway process already exists; we did not spawn one.
        proc = None
    else:
        SPARK_HOME = _find_spark_home()
        # Launch through spark-submit so the JVM is started by Spark's own launcher script.
        on_windows = platform.system() == "Windows"
        script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
        command = [os.path.join(SPARK_HOME, script)]
        if conf:
            for k, v in conf.getAll():
                command += ['--conf', '%s=%s' % (k, v)]
        submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
        if os.environ.get("SPARK_TESTING"):
            submit_args = ' '.join([
                "--conf spark.ui.enabled=false",
                submit_args
            ])
        command = command + shlex.split(submit_args)

        # Reserve a unique path inside a private temp dir, then delete the file.
        # The child receives the path via _PYSPARK_DRIVER_CONN_INFO_PATH, and the
        # file reappearing is the signal that it has written its port and secret.
        conn_info_dir = tempfile.mkdtemp()
        try:
            fd, conn_info_file = tempfile.mkstemp(dir=conn_info_dir)
            os.close(fd)
            os.unlink(conn_info_file)

            env = dict(os.environ)
            env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file

            # Copy so that the overrides below never leak into the caller's dictionary.
            popen_kwargs = {} if popen_kwargs is None else dict(popen_kwargs)
            # Give the child a stdin pipe held by this process. Presumably this lets the
            # JVM notice when the parent goes away (EOF on stdin); TODO confirm.
            popen_kwargs['stdin'] = PIPE
            popen_kwargs['env'] = env
            if not on_windows:
                # The child ignores SIGINT, so Ctrl-C in the Python process does not kill the JVM.
                def preexec_func():
                    signal.signal(signal.SIGINT, signal.SIG_IGN)
                popen_kwargs['preexec_fn'] = preexec_func
            proc = Popen(command, **popen_kwargs)

            # Wait until the child writes the file or exits. poll() returns None while the
            # child runs and its return code afterwards. That code may be 0, so it must be
            # compared with None; testing its truthiness would spin forever on a clean exit.
            while proc.poll() is None and not os.path.isfile(conn_info_file):
                time.sleep(0.1)

            if not os.path.isfile(conn_info_file):
                raise Exception("Java gateway process exited before sending its port number")

            with open(conn_info_file, "rb") as info:
                gateway_port = read_int(info)
                gateway_secret = UTF8Deserializer().loads(info)
        finally:
            shutil.rmtree(conn_info_dir)

        # On Windows, register a best-effort taskkill of the child process tree at
        # interpreter exit. This lives inside the spawn branch, where both
        # `on_windows` and a non-None `proc` are guaranteed to exist.
        if on_windows:
            def killChild():
                Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)])
            atexit.register(killChild)

    if os.environ.get("PYSPARK_PIN_THREAD", "false").lower() == "true":
        gateway = ClientServer(
            java_parameters=JavaParameters(
                port=gateway_port,
                auth_token=gateway_secret,
                auto_convert=True),
            python_parameters=PythonParameters(
                port=0,
                eager_load=False))
    else:
        gateway = JavaGateway(
            gateway_parameters=GatewayParameters(
                port=gateway_port,
                auth_token=gateway_secret,
                auto_convert=True))

    # Expose the spawned process (or None) to callers.
    gateway.proc = proc

    # Make commonly used JVM classes addressable by short name via gateway.jvm.
    java_import(gateway.jvm, "org.apache.spark.SparkConf")
    java_import(gateway.jvm, "org.apache.spark.api.java.*")
    java_import(gateway.jvm, "org.apache.spark.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.ml.python.*")
    java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.sql.*")
    java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
    java_import(gateway.jvm, "scala.Tuple2")

    return gateway
0162
0163
def _do_server_auth(conn, auth_secret):
    """
    Run the SocketAuthHelper authentication handshake over the file-like object ``conn``.

    The secret is UTF-8 encoded, written with a length prefix and flushed. A UTF-8
    string reply is then read back. Any reply other than ``"ok"`` closes ``conn``
    and raises.

    :param conn: file-like object supporting write, flush, read and close
    :param auth_secret: the shared secret as a ``str``
    :raises Exception: if the server's reply is not ``"ok"``
    """
    payload = auth_secret.encode("utf-8")
    write_with_length(payload, conn)
    conn.flush()

    if UTF8Deserializer().loads(conn) == "ok":
        return

    conn.close()
    raise Exception("Unexpected reply from iterator server.")
0175
0176
def local_connect_and_auth(port, auth_secret):
    """
    Connect to local host, authenticate with it, and return a (sockfile, sock) for that connection.

    Every address ``socket.getaddrinfo`` reports for 127.0.0.1 (any address family)
    is tried in turn, each with a 15-second socket timeout. On a socket-level failure
    the error is recorded, the partially opened socket is closed, and the next address
    is tried. Any other failure (e.g. the server rejecting the auth secret) closes the
    socket and propagates unchanged.

    :param port: port of the local server
    :param auth_secret: shared secret (``str``) sent during authentication
    :return: a tuple with (sockfile, sock)
    :raises Exception: if no address could be connected and authenticated; the message
        lists the error seen for each address.
    """
    errors = []
    # Buffer size of the socket file; read once instead of on every attempt.
    buffer_size = int(os.environ.get("SPARK_BUFFER_SIZE", 65536))

    for af, socktype, proto, _, sa in socket.getaddrinfo(
            "127.0.0.1", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
        # Reset per attempt so cleanup never touches a previous attempt's objects.
        # `sock` stays None if socket.socket() itself fails.
        sock = None
        sockfile = None
        try:
            sock = socket.socket(af, socktype, proto)
            sock.settimeout(15)
            sock.connect(sa)
            sockfile = sock.makefile("rwb", buffer_size)
            _do_server_auth(sockfile, auth_secret)
            return (sockfile, sock)
        except socket.error as e:
            emsg = _exception_message(e)
            errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg))
            _close_socket_quietly(sockfile, sock)
        except BaseException:
            # Not a connection failure: release resources, then propagate as before.
            _close_socket_quietly(sockfile, sock)
            raise
    raise Exception("could not open socket: %s" % errors)


def _close_socket_quietly(sockfile, sock):
    """Best-effort close of a socket file and its socket (either may be None) on an error path."""
    for resource in (sockfile, sock):
        if resource is not None:
            try:
                resource.close()
            except socket.error:
                # The connection is already being abandoned; a failed close
                # (e.g. flushing into a broken socket) must not mask the original error.
                pass
0204
0205
def ensure_callback_server_started(gw):
    """
    Start the callback server on gateway ``gw`` unless one is already running.

    The callback server is needed when the Java driver process must call back into the
    Python driver process to execute Python code. The server is started on an
    OS-assigned port (port 0). The actual bound port is stored on the server and on
    ``gw._python_proxy_port``, and the JVM-side callback client is re-pointed at it.

    :param gw: the py4j gateway object
    """
    # Checks the instance __dict__ rather than using hasattr(). Presumably this is
    # because attribute lookup on a py4j gateway may be forwarded to the JVM;
    # TODO confirm.
    if "_callback_server" in gw.__dict__ and gw._callback_server is not None:
        return

    params = gw.callback_server_parameters
    params.eager_load = True
    params.daemonize = True
    params.daemonize_connections = True
    params.port = 0
    gw.start_callback_server(params)

    # Replace the requested port 0 with the port the server socket actually bound.
    server = gw._callback_server
    server.port = server.server_socket.getsockname()[1]
    gw._python_proxy_port = server.port

    # Get the JVM "GATEWAY_SERVER" object and reset its callback client to the same
    # address but the real port.
    jvm_gateway_server = JavaObject("GATEWAY_SERVER", gw._gateway_client)
    jvm_gateway_server.resetCallbackClient(
        jvm_gateway_server.getCallbackClient().getAddress(), gw._python_proxy_port)