Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 import atexit
0019 import os
0020 import sys
0021 import signal
0022 import shlex
0023 import shutil
0024 import socket
0025 import platform
0026 import tempfile
0027 import time
0028 from subprocess import Popen, PIPE
0029 
# Python 2/3 compatibility: provide ``xrange`` on Python 3, where it was renamed ``range``.
# Compare the structured version tuple rather than the free-form ``sys.version`` string,
# whose lexicographic comparison is fragile.
if sys.version_info[0] >= 3:
    xrange = range
0032 
0033 from py4j.java_gateway import java_import, JavaGateway, JavaObject, GatewayParameters
0034 from py4j.clientserver import ClientServer, JavaParameters, PythonParameters
0035 from pyspark.find_spark_home import _find_spark_home
0036 from pyspark.serializers import read_int, write_with_length, UTF8Deserializer
0037 from pyspark.util import _exception_message
0038 
0039 
def launch_gateway(conf=None, popen_kwargs=None):
    """
    Launch (or attach to) the py4j JVM gateway used by PySpark.

    If PYSPARK_GATEWAY_PORT is set in the environment, connect to the existing
    gateway described by PYSPARK_GATEWAY_PORT / PYSPARK_GATEWAY_SECRET. Otherwise
    start spark-submit, wait until it has written its port and auth secret to a
    temporary file, and connect to it.

    :param conf: spark configuration passed to spark-submit
    :param popen_kwargs: Dictionary of kwargs to pass to Popen when spawning
        the py4j JVM. This is a developer feature intended for use in
        customizing how pyspark interacts with the py4j JVM (e.g., capturing
        stdout/stderr). The dictionary is copied and never modified; its
        ``stdin`` and ``env`` entries (and ``preexec_fn`` outside Windows) are
        always overridden.
    :return: a py4j ``ClientServer`` when PYSPARK_PIN_THREAD is "true", otherwise
        a ``JavaGateway``. Its ``proc`` attribute holds the launched ``Popen``
        object, or None when attaching to an already-running gateway.
    :raises Exception: if the launched process exits before writing its
        connection information.
    """
    if "PYSPARK_GATEWAY_PORT" in os.environ:
        gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
        gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
        # Process already exists
        proc = None
    else:
        SPARK_HOME = _find_spark_home()
        # Launch the Py4j gateway using Spark's run command so that we pick up the
        # proper classpath and settings from spark-env.sh
        on_windows = platform.system() == "Windows"
        script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
        command = [os.path.join(SPARK_HOME, script)]
        if conf:
            for k, v in conf.getAll():
                command += ['--conf', '%s=%s' % (k, v)]
        submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
        if os.environ.get("SPARK_TESTING"):
            submit_args = ' '.join([
                "--conf spark.ui.enabled=false",
                submit_args
            ])
        command = command + shlex.split(submit_args)

        # Create a temporary directory where the gateway server should write the connection
        # information.
        conn_info_dir = tempfile.mkdtemp()
        try:
            # Only a unique path is needed here: the file is removed again, and its
            # reappearance (written by the gateway) signals that the info is ready.
            fd, conn_info_file = tempfile.mkstemp(dir=conn_info_dir)
            os.close(fd)
            os.unlink(conn_info_file)

            env = dict(os.environ)
            env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file

            # Copy the caller's kwargs so the overrides below never leak back to them.
            popen_kwargs = {} if popen_kwargs is None else dict(popen_kwargs)
            # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
            popen_kwargs['stdin'] = PIPE
            # We always set the necessary environment variables.
            popen_kwargs['env'] = env
            if not on_windows:
                # Don't send ctrl-c / SIGINT to the Java gateway
                # (preexec_fn is not supported on Windows).
                def preexec_func():
                    signal.signal(signal.SIGINT, signal.SIG_IGN)
                popen_kwargs['preexec_fn'] = preexec_func
            proc = Popen(command, **popen_kwargs)

            # Wait for the file to appear, or for the process to exit, whichever happens first.
            # poll() returns None while the process runs and its exit code afterwards; compare
            # with None explicitly, because a clean exit (code 0) is falsy and would otherwise
            # keep this loop waiting forever.
            while proc.poll() is None and not os.path.isfile(conn_info_file):
                time.sleep(0.1)

            if not os.path.isfile(conn_info_file):
                raise Exception("Java gateway process exited before sending its port number")

            with open(conn_info_file, "rb") as info:
                gateway_port = read_int(info)
                gateway_secret = UTF8Deserializer().loads(info)
        finally:
            shutil.rmtree(conn_info_dir)

        # In Windows, ensure the Java child processes do not linger after Python has exited.
        # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
        # the parent process' stdin sends an EOF). In Windows, however, this is not possible
        # because java.lang.Process reads directly from the parent process' stdin, contending
        # with any opportunity to read an EOF from the parent. Note that this is only best
        # effort and will not take effect if the python process is violently terminated.
        if on_windows:
            # In Windows, the child process here is "spark-submit.cmd", not the JVM itself
            # (because the UNIX "exec" command is not available). This means we cannot simply
            # call proc.kill(), which kills only the "spark-submit.cmd" process but not the
            # JVMs. Instead, we use "taskkill" with the tree-kill option "/t" to terminate all
            # child processes in the tree (http://technet.microsoft.com/en-us/library/bb491009.aspx)
            def killChild():
                Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)])
            atexit.register(killChild)

    # Connect to the gateway (or client server to pin the thread between JVM and Python)
    if os.environ.get("PYSPARK_PIN_THREAD", "false").lower() == "true":
        gateway = ClientServer(
            java_parameters=JavaParameters(
                port=gateway_port,
                auth_token=gateway_secret,
                auto_convert=True),
            python_parameters=PythonParameters(
                port=0,
                eager_load=False))
    else:
        gateway = JavaGateway(
            gateway_parameters=GatewayParameters(
                port=gateway_port,
                auth_token=gateway_secret,
                auto_convert=True))

    # Store a reference to the Popen object for use by the caller (e.g., in reading stdout/stderr)
    gateway.proc = proc

    # Import the classes used by PySpark
    for java_name in (
            "org.apache.spark.SparkConf",
            "org.apache.spark.api.java.*",
            "org.apache.spark.api.python.*",
            "org.apache.spark.ml.python.*",
            "org.apache.spark.mllib.api.python.*",
            # TODO(davies): move into sql
            "org.apache.spark.sql.*",
            "org.apache.spark.sql.api.python.*",
            "org.apache.spark.sql.hive.*",
            "scala.Tuple2"):
        java_import(gateway.jvm, java_name)

    return gateway
0162 
0163 
def _do_server_auth(conn, auth_secret):
    """
    Run the SocketAuthHelper authentication handshake over ``conn``.

    ``conn`` is a file-like object. The secret is UTF-8 encoded and sent with
    ``write_with_length``, the stream is flushed, and one UTF-8 string reply is
    read back. Any reply other than "ok" closes ``conn`` and raises.
    """
    payload = auth_secret.encode("utf-8")
    write_with_length(payload, conn)
    conn.flush()
    if UTF8Deserializer().loads(conn) == "ok":
        return
    conn.close()
    raise Exception("Unexpected reply from iterator server.")
0175 
0176 
def _close_quietly(*resources):
    # Best-effort close of every non-None resource, ignoring I/O errors raised while closing.
    for resource in resources:
        if resource is None:
            continue
        try:
            resource.close()
        except (IOError, OSError):
            pass


def local_connect_and_auth(port, auth_secret):
    """
    Connect to local host, authenticate with it, and return a (sockfile, sock) for that connection.

    Every address that ``getaddrinfo`` yields for the loopback address is tried in
    order. The first one that both connects and authenticates is returned. Sockets
    from failed attempts (and their file wrappers) are closed before moving on.

    :param port: port of the local server to connect to
    :param auth_secret: shared secret sent to the server via _do_server_auth
    :return: a tuple with (sockfile, sock)
    :raises Exception: if no address could be connected to; the message lists the
        socket error seen for each attempted address. Non-socket errors (e.g. an
        unexpected auth reply from _do_server_auth) propagate after the connection
        is closed.
    """
    errors = []
    # Loop-invariant: parse the buffer size once rather than per attempt.
    buffer_size = int(os.environ.get("SPARK_BUFFER_SIZE", 65536))
    # AF_UNSPEC lets getaddrinfo choose the address family. Note that "127.0.0.1" is an
    # IPv4 literal, so in practice only IPv4 addresses are expected here.
    for af, socktype, proto, _, sa in socket.getaddrinfo(
            "127.0.0.1", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
        sock = None
        sockfile = None
        connected = False
        try:
            sock = socket.socket(af, socktype, proto)
            sock.settimeout(15)
            sock.connect(sa)
            sockfile = sock.makefile("rwb", buffer_size)
            _do_server_auth(sockfile, auth_secret)
            connected = True
            return (sockfile, sock)
        except socket.error as e:
            emsg = _exception_message(e)
            errors.append("tried to connect to %s, but an error occurred: %s" % (sa, emsg))
        finally:
            if not connected:
                # Close the makefile() wrapper as well as the socket: the OS-level socket
                # is not released while a makefile() wrapper is still open. Either may be
                # None if the attempt failed before it was created.
                _close_quietly(sockfile, sock)
    raise Exception("could not open socket: %s" % errors)
0204 
0205 
def ensure_callback_server_started(gw):
    """
    Lazily start the Python callback server of gateway ``gw``.

    Returns immediately if ``gw`` already has a callback server. Otherwise one is
    started on port 0 and the port it actually bound is recorded on ``gw``. The
    JVM-side GatewayServer's callback client is then reset to that port, so that
    the Java driver process can call back into Python code in this process.
    """
    # Inspect the instance __dict__ directly: attribute lookups on the gateway fall
    # back to the JVM, so hasattr() cannot tell whether the attribute exists locally.
    if "_callback_server" in gw.__dict__ and gw._callback_server is not None:
        return

    params = gw.callback_server_parameters
    params.eager_load = True
    params.daemonize = True
    params.daemonize_connections = True
    # Request port 0; the port actually bound is read back from the server socket below.
    params.port = 0
    gw.start_callback_server(params)

    server = gw._callback_server
    bound_port = server.server_socket.getsockname()[1]
    server.port = bound_port
    # Record the real port on the gateway as well.
    gw._python_proxy_port = server.port

    # Look up the GatewayServer object in the JVM by its ID and point its
    # CallbackClient at the real Python-side port.
    gateway_server = JavaObject("GATEWAY_SERVER", gw._gateway_client)
    callback_address = gateway_server.getCallbackClient().getAddress()
    gateway_server.resetCallbackClient(callback_address, gw._python_proxy_port)
0217         gw.callback_server_parameters.port = 0
0218         gw.start_callback_server(gw.callback_server_parameters)
0219         cbport = gw._callback_server.server_socket.getsockname()[1]
0220         gw._callback_server.port = cbport
0221         # gateway with real port
0222         gw._python_proxy_port = gw._callback_server.port
0223         # get the GatewayServer object in JVM by ID
0224         jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
0225         # update the port of CallbackClient with real port
0226         jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port)