[PYSPARK] Update py4j to version 0.10.7.

This commit is contained in: parent 94155d0395, commit cc613b552e
LICENSE

@@ -263,7 +263,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
      (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf)
      (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net)
      (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net)
-     (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.6 - http://py4j.sourceforge.net/)
+     (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.7 - http://py4j.sourceforge.net/)
      (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/)
      (BSD licence) sbt and sbt-launch-lib.bash
      (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE)

@@ -25,14 +25,14 @@ source "${SPARK_HOME}"/bin/load-spark-env.sh
 export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]"

 # In Spark 2.0, IPYTHON and IPYTHON_OPTS are removed and pyspark fails to launch if either option
-# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython
+# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython
 # to use IPython and set PYSPARK_DRIVER_PYTHON_OPTS to pass options when starting the Python driver
 # (e.g. PYSPARK_DRIVER_PYTHON_OPTS='notebook'). This supports full customization of the IPython
 # and executor Python executables.

 # Fail noisily if removed options are set
 if [[ -n "$IPYTHON" || -n "$IPYTHON_OPTS" ]]; then
-  echo "Error in pyspark startup:"
+  echo "Error in pyspark startup:"
   echo "IPYTHON and IPYTHON_OPTS are removed in Spark 2.0+. Remove these from the environment and set PYSPARK_DRIVER_PYTHON and PYSPARK_DRIVER_PYTHON_OPTS instead."
   exit 1
 fi

@@ -57,7 +57,7 @@ export PYSPARK_PYTHON

 # Add the PySpark classes to the Python path:
 export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH"
-export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.6-src.zip:$PYTHONPATH"
+export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:$PYTHONPATH"

 # Load the PySpark shell.py script when ./pyspark is used interactively:
 export OLD_PYTHONSTARTUP="$PYTHONSTARTUP"

@@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" (
 )

 set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH%
-set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.6-src.zip;%PYTHONPATH%
+set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.7-src.zip;%PYTHONPATH%

 set OLD_PYTHONSTARTUP=%PYTHONSTARTUP%
 set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py

@@ -350,7 +350,7 @@
     <dependency>
       <groupId>net.sf.py4j</groupId>
       <artifactId>py4j</artifactId>
-      <version>0.10.6</version>
+      <version>0.10.7</version>
     </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>

@@ -17,15 +17,10 @@

 package org.apache.spark

-import java.lang.{Byte => JByte}
 import java.net.{Authenticator, PasswordAuthentication}
 import java.nio.charset.StandardCharsets.UTF_8
-import java.security.{KeyStore, SecureRandom}
-import java.security.cert.X509Certificate
 import javax.net.ssl._

-import com.google.common.hash.HashCodes
-import com.google.common.io.Files
 import org.apache.hadoop.io.Text
 import org.apache.hadoop.security.{Credentials, UserGroupInformation}


@@ -365,13 +360,8 @@ private[spark] class SecurityManager(
       return
     }

-    val rnd = new SecureRandom()
-    val length = sparkConf.getInt("spark.authenticate.secretBitLength", 256) / JByte.SIZE
-    val secretBytes = new Array[Byte](length)
-    rnd.nextBytes(secretBytes)
-
+    secretKey = Utils.createSecret(sparkConf)
     val creds = new Credentials()
-    secretKey = HashCodes.fromBytes(secretBytes).toString()
     creds.addSecretKey(SECRET_LOOKUP_KEY, secretKey.getBytes(UTF_8))
     UserGroupInformation.getCurrentUser().addCredentials(creds)
   }

@@ -17,26 +17,39 @@

 package org.apache.spark.api.python

-import java.io.DataOutputStream
-import java.net.Socket
+import java.io.{DataOutputStream, File, FileOutputStream}
+import java.net.InetAddress
+import java.nio.charset.StandardCharsets.UTF_8
+import java.nio.file.Files

 import py4j.GatewayServer

+import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
 import org.apache.spark.util.Utils

 /**
- * Process that starts a Py4J GatewayServer on an ephemeral port and communicates the bound port
- * back to its caller via a callback port specified by the caller.
+ * Process that starts a Py4J GatewayServer on an ephemeral port.
  *
  * This process is launched (via SparkSubmit) by the PySpark driver (see java_gateway.py).
  */
 private[spark] object PythonGatewayServer extends Logging {
   initializeLogIfNecessary(true)

-  def main(args: Array[String]): Unit = Utils.tryOrExit {
-    // Start a GatewayServer on an ephemeral port
-    val gatewayServer: GatewayServer = new GatewayServer(null, 0)
+  def main(args: Array[String]): Unit = {
+    val secret = Utils.createSecret(new SparkConf())
+
+    // Start a GatewayServer on an ephemeral port. Make sure the callback client is configured
+    // with the same secret, in case the app needs callbacks from the JVM to the underlying
+    // python processes.
+    val localhost = InetAddress.getLoopbackAddress()
+    val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder()
+      .authToken(secret)
+      .javaPort(0)
+      .javaAddress(localhost)
+      .callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
+      .build()
+
     gatewayServer.start()
     val boundPort: Int = gatewayServer.getListeningPort
     if (boundPort == -1) {

@@ -46,15 +59,24 @@ private[spark] object PythonGatewayServer extends Logging {
       logDebug(s"Started PythonGatewayServer on port $boundPort")
     }

-    // Communicate the bound port back to the caller via the caller-specified callback port
-    val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST")
-    val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt
-    logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort")
-    val callbackSocket = new Socket(callbackHost, callbackPort)
-    val dos = new DataOutputStream(callbackSocket.getOutputStream)
+    // Communicate the connection information back to the python process by writing the
+    // information in the requested file. This needs to match the read side in java_gateway.py.
+    val connectionInfoPath = new File(sys.env("_PYSPARK_DRIVER_CONN_INFO_PATH"))
+    val tmpPath = Files.createTempFile(connectionInfoPath.getParentFile().toPath(),
+      "connection", ".info").toFile()
+
+    val dos = new DataOutputStream(new FileOutputStream(tmpPath))
     dos.writeInt(boundPort)
+
+    val secretBytes = secret.getBytes(UTF_8)
+    dos.writeInt(secretBytes.length)
+    dos.write(secretBytes, 0, secretBytes.length)
     dos.close()
-    callbackSocket.close()
+
+    if (!tmpPath.renameTo(connectionInfoPath)) {
+      logError(s"Unable to write connection information to $connectionInfoPath.")
+      System.exit(1)
+    }

     // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies:
     while (System.in.read() != -1) {

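For reference, the connection-info file written above is a plain DataOutputStream payload: a 4-byte big-endian port, a 4-byte big-endian length, then the UTF-8 secret bytes. A minimal Python sketch of a reader is shown below; the function name is illustrative and not part of the commit (PySpark's real read side is the java_gateway.py change further down, which uses read_int and UTF8Deserializer):

    import struct

    def read_connection_info(path):
        # Matches the write side above: writeInt(port), writeInt(len(secret)), write(secret bytes).
        with open(path, "rb") as f:
            (port,) = struct.unpack(">i", f.read(4))
            (length,) = struct.unpack(">i", f.read(4))
            secret = f.read(length).decode("utf-8")
        return port, secret
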
@@ -38,6 +38,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.input.PortableDataStream
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
+import org.apache.spark.security.SocketAuthHelper
 import org.apache.spark.util._



@@ -107,6 +108,12 @@ private[spark] object PythonRDD extends Logging {
   // remember the broadcasts sent to each worker
   private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()

+  // Authentication helper used when serving iterator data.
+  private lazy val authHelper = {
+    val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
+    new SocketAuthHelper(conf)
+  }
+
   def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = {
     synchronized {
       workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())

@@ -129,12 +136,13 @@ private[spark] object PythonRDD extends Logging {
    * (effectively a collect()), but allows you to run on a certain subset of partitions,
    * or to enable local execution.
    *
-   * @return the port number of a local socket which serves the data collected from this job.
+   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, and the secret for authentication.
    */
   def runJob(
       sc: SparkContext,
       rdd: JavaRDD[Array[Byte]],
-      partitions: JArrayList[Int]): Int = {
+      partitions: JArrayList[Int]): Array[Any] = {
     type ByteArray = Array[Byte]
     type UnrolledPartition = Array[ByteArray]
     val allPartitions: Array[UnrolledPartition] =

@@ -147,13 +155,14 @@ private[spark] object PythonRDD extends Logging {
   /**
    * A helper function to collect an RDD as an iterator, then serve it via socket.
    *
-   * @return the port number of a local socket which serves the data collected from this job.
+   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, and the secret for authentication.
    */
-  def collectAndServe[T](rdd: RDD[T]): Int = {
+  def collectAndServe[T](rdd: RDD[T]): Array[Any] = {
     serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
   }

-  def toLocalIteratorAndServe[T](rdd: RDD[T]): Int = {
+  def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
     serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
   }


@@ -384,8 +393,11 @@ private[spark] object PythonRDD extends Logging {
    * and send them into this connection.
    *
    * The thread will terminate after all the data are sent or any exceptions happen.
+   *
+   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, and the secret for authentication.
    */
-  def serveIterator[T](items: Iterator[T], threadName: String): Int = {
+  def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
     val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
     // Close the socket if no connection in 15 seconds
     serverSocket.setSoTimeout(15000)

@@ -395,11 +407,14 @@ private[spark] object PythonRDD extends Logging {
       override def run() {
         try {
           val sock = serverSocket.accept()
+          authHelper.authClient(sock)
+
           val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
           Utils.tryWithSafeFinally {
             writeIteratorToStream(items, out)
           } {
             out.close()
+            sock.close()
           }
         } catch {
           case NonFatal(e) =>

@@ -410,7 +425,7 @@ private[spark] object PythonRDD extends Logging {
       }
     }.start()

-    serverSocket.getLocalPort
+    Array(serverSocket.getLocalPort, authHelper.secret)
   }

   private def getMergedConf(confAsMap: java.util.HashMap[String, String],

@@ -32,7 +32,7 @@ private[spark] object PythonUtils {
     val pythonPath = new ArrayBuffer[String]
     for (sparkHome <- sys.env.get("SPARK_HOME")) {
       pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator)
-      pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.6-src.zip").mkString(File.separator)
+      pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.7-src.zip").mkString(File.separator)
     }
     pythonPath ++= SparkContext.jarOfObject(this)
     pythonPath.mkString(File.pathSeparator)

@@ -27,6 +27,7 @@ import scala.collection.mutable

 import org.apache.spark._
 import org.apache.spark.internal.Logging
+import org.apache.spark.security.SocketAuthHelper
 import org.apache.spark.util.{RedirectThread, Utils}

 private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])

@@ -67,6 +68,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
     value
   }.getOrElse("pyspark.worker")

+  private val authHelper = new SocketAuthHelper(SparkEnv.get.conf)
+
   var daemon: Process = null
   val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
   var daemonPort: Int = 0

@@ -108,6 +111,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
       if (pid < 0) {
         throw new IllegalStateException("Python daemon failed to launch worker with code " + pid)
       }
+
+      authHelper.authToServer(socket)
       daemonWorkers.put(socket, pid)
       socket
     }

@@ -145,25 +150,24 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
       workerEnv.put("PYTHONPATH", pythonPath)
       // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
       workerEnv.put("PYTHONUNBUFFERED", "YES")
+      workerEnv.put("PYTHON_WORKER_FACTORY_PORT", serverSocket.getLocalPort.toString)
+      workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
       val worker = pb.start()

       // Redirect worker stdout and stderr
       redirectStreamsToStderr(worker.getInputStream, worker.getErrorStream)

-      // Tell the worker our port
-      val out = new OutputStreamWriter(worker.getOutputStream, StandardCharsets.UTF_8)
-      out.write(serverSocket.getLocalPort + "\n")
-      out.flush()
-
-      // Wait for it to connect to our socket
+      // Wait for it to connect to our socket, and validate the auth secret.
       serverSocket.setSoTimeout(10000)
+
       try {
         val socket = serverSocket.accept()
+        authHelper.authClient(socket)
         simpleWorkers.put(socket, worker)
         return socket
       } catch {
         case e: Exception =>
-          throw new SparkException("Python worker did not connect back in time", e)
+          throw new SparkException("Python worker failed to connect back.", e)
       }
     } finally {
       if (serverSocket != null) {

@@ -187,6 +191,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
       val workerEnv = pb.environment()
       workerEnv.putAll(envVars.asJava)
       workerEnv.put("PYTHONPATH", pythonPath)
+      workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
       // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
       workerEnv.put("PYTHONUNBUFFERED", "YES")
       daemon = pb.start()

@@ -218,7 +223,6 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String

       // Redirect daemon stdout and stderr
       redirectStreamsToStderr(in, daemon.getErrorStream)
-
     } catch {
       case e: Exception =>


@@ -18,7 +18,7 @@
 package org.apache.spark.deploy

 import java.io.File
-import java.net.URI
+import java.net.{InetAddress, URI}

 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer

@@ -39,6 +39,7 @@ object PythonRunner {
     val pyFiles = args(1)
     val otherArgs = args.slice(2, args.length)
     val sparkConf = new SparkConf()
+    val secret = Utils.createSecret(sparkConf)
     val pythonExec = sparkConf.get(PYSPARK_DRIVER_PYTHON)
       .orElse(sparkConf.get(PYSPARK_PYTHON))
       .orElse(sys.env.get("PYSPARK_DRIVER_PYTHON"))

@@ -51,7 +52,13 @@ object PythonRunner {

     // Launch a Py4J gateway server for the process to connect to; this will let it see our
     // Java system properties and such
-    val gatewayServer = new py4j.GatewayServer(null, 0)
+    val localhost = InetAddress.getLoopbackAddress()
+    val gatewayServer = new py4j.GatewayServer.GatewayServerBuilder()
+      .authToken(secret)
+      .javaPort(0)
+      .javaAddress(localhost)
+      .callbackClient(py4j.GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
+      .build()
     val thread = new Thread(new Runnable() {
       override def run(): Unit = Utils.logUncaughtExceptions {
         gatewayServer.start()

@@ -82,6 +89,7 @@ object PythonRunner {
     // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
     env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
     env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
+    env.put("PYSPARK_GATEWAY_SECRET", secret)
     // pass conf spark.pyspark.python to python process, the only way to pass info to
     // python process is through environment variable.
     sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _))

@@ -352,6 +352,11 @@ package object config {
       .regexConf
       .createOptional

+  private[spark] val AUTH_SECRET_BIT_LENGTH =
+    ConfigBuilder("spark.authenticate.secretBitLength")
+      .intConf
+      .createWithDefault(256)
+
   private[spark] val NETWORK_AUTH_ENABLED =
     ConfigBuilder("spark.authenticate")
       .booleanConf

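The new AUTH_SECRET_BIT_LENGTH entry just gives typed access to the existing spark.authenticate.secretBitLength setting (default 256 bits), which Utils.createSecret reads further down. As a usage illustration only, an application could request longer secrets via its configuration:

    from pyspark import SparkConf

    # Hypothetical example: ask for 512-bit auth secrets instead of the 256-bit default.
    conf = SparkConf().set("spark.authenticate.secretBitLength", "512")
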
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.security
+
+import java.io.{DataInputStream, DataOutputStream, InputStream}
+import java.net.Socket
+import java.nio.charset.StandardCharsets.UTF_8
+
+import org.apache.spark.SparkConf
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.util.Utils
+
+/**
+ * A class that can be used to add a simple authentication protocol to socket-based communication.
+ *
+ * The protocol is simple: an auth secret is written to the socket, and the other side checks the
+ * secret and writes either "ok" or "err" to the output. If authentication fails, the socket is
+ * not expected to be valid anymore.
+ *
+ * There's no secrecy, so this relies on the sockets being either local or somehow encrypted.
+ */
+private[spark] class SocketAuthHelper(conf: SparkConf) {
+
+  val secret = Utils.createSecret(conf)
+
+  /**
+   * Read the auth secret from the socket and compare to the expected value. Write the reply back
+   * to the socket.
+   *
+   * If authentication fails, this method will close the socket.
+   *
+   * @param s The client socket.
+   * @throws IllegalArgumentException If authentication fails.
+   */
+  def authClient(s: Socket): Unit = {
+    // Set the socket timeout while checking the auth secret. Reset it before returning.
+    val currentTimeout = s.getSoTimeout()
+    try {
+      s.setSoTimeout(10000)
+      val clientSecret = readUtf8(s)
+      if (secret == clientSecret) {
+        writeUtf8("ok", s)
+      } else {
+        writeUtf8("err", s)
+        JavaUtils.closeQuietly(s)
+      }
+    } finally {
+      s.setSoTimeout(currentTimeout)
+    }
+  }
+
+  /**
+   * Authenticate with a server by writing the auth secret and checking the server's reply.
+   *
+   * If authentication fails, this method will close the socket.
+   *
+   * @param s The socket connected to the server.
+   * @throws IllegalArgumentException If authentication fails.
+   */
+  def authToServer(s: Socket): Unit = {
+    writeUtf8(secret, s)
+
+    val reply = readUtf8(s)
+    if (reply != "ok") {
+      JavaUtils.closeQuietly(s)
+      throw new IllegalArgumentException("Authentication failed.")
+    }
+  }
+
+  protected def readUtf8(s: Socket): String = {
+    val din = new DataInputStream(s.getInputStream())
+    val len = din.readInt()
+    val bytes = new Array[Byte](len)
+    din.readFully(bytes)
+    new String(bytes, UTF_8)
+  }
+
+  protected def writeUtf8(str: String, s: Socket): Unit = {
+    val bytes = str.getBytes(UTF_8)
+    val dout = new DataOutputStream(s.getOutputStream())
+    dout.writeInt(bytes.length)
+    dout.write(bytes, 0, bytes.length)
+    dout.flush()
+  }
+
+}

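The handshake above is just two length-prefixed UTF-8 strings exchanged over the socket (a big-endian int length from DataOutputStream.writeInt, then the bytes), answered with "ok" or "err". A minimal Python sketch of the client side, mirroring authToServer, is shown below; the helper names are illustrative and not part of the commit (PySpark's real counterpart is the do_server_auth function added to java_gateway.py further down):

    import struct

    def _send_utf8(sock, text):
        # Length-prefixed UTF-8, matching SocketAuthHelper.writeUtf8 (big-endian int length).
        data = text.encode("utf-8")
        sock.sendall(struct.pack(">i", len(data)) + data)

    def _recv_exactly(sock, n):
        buf = b""
        while len(buf) < n:
            chunk = sock.recv(n - len(buf))
            if not chunk:
                raise EOFError("connection closed during auth handshake")
            buf += chunk
        return buf

    def _recv_utf8(sock):
        # Matching SocketAuthHelper.readUtf8.
        (length,) = struct.unpack(">i", _recv_exactly(sock, 4))
        return _recv_exactly(sock, length).decode("utf-8")

    def auth_to_server(sock, secret):
        # Send the shared secret; any reply other than "ok" means authentication failed.
        _send_utf8(sock, secret)
        if _recv_utf8(sock) != "ok":
            sock.close()
            raise ValueError("Authentication failed.")
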
@@ -18,6 +18,7 @@
 package org.apache.spark.util

 import java.io._
+import java.lang.{Byte => JByte}
 import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo}
 import java.lang.reflect.InvocationTargetException
 import java.math.{MathContext, RoundingMode}

@@ -26,11 +27,11 @@ import java.nio.ByteBuffer
 import java.nio.channels.{Channels, FileChannel}
 import java.nio.charset.StandardCharsets
 import java.nio.file.Files
+import java.security.SecureRandom
 import java.util.{Locale, Properties, Random, UUID}
 import java.util.concurrent._
 import java.util.concurrent.atomic.AtomicBoolean
 import java.util.zip.GZIPInputStream
 import javax.net.ssl.HttpsURLConnection

 import scala.annotation.tailrec
 import scala.collection.JavaConverters._

@@ -44,6 +45,7 @@ import scala.util.matching.Regex

 import _root_.io.netty.channel.unix.Errors.NativeIoException
 import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
+import com.google.common.hash.HashCodes
 import com.google.common.io.{ByteStreams, Files => GFiles}
 import com.google.common.net.InetAddresses
 import org.apache.commons.lang3.SystemUtils

@@ -2704,6 +2706,15 @@ private[spark] object Utils extends Logging {
   def substituteAppId(opt: String, appId: String): String = {
     opt.replace("{{APP_ID}}", appId)
   }
+
+  def createSecret(conf: SparkConf): String = {
+    val bits = conf.get(AUTH_SECRET_BIT_LENGTH)
+    val rnd = new SecureRandom()
+    val secretBytes = new Array[Byte](bits / JByte.SIZE)
+    rnd.nextBytes(secretBytes)
+    HashCodes.fromBytes(secretBytes).toString()
+  }
+
 }

 private[util] object CallerContext extends Logging {

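For intuition, the new Utils.createSecret boils down to "draw spark.authenticate.secretBitLength / 8 random bytes and hex-encode them" (Guava's HashCodes.fromBytes(...).toString() produces lowercase hex). A rough Python equivalent, for illustration only:

    import secrets

    def create_secret(bits=256):
        # Roughly what Utils.createSecret does: random bytes, lowercase hex encoding.
        return secrets.token_hex(bits // 8)
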
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.security
+
+import java.io.Closeable
+import java.net._
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.internal.config._
+import org.apache.spark.util.Utils
+
+class SocketAuthHelperSuite extends SparkFunSuite {
+
+  private val conf = new SparkConf()
+  private val authHelper = new SocketAuthHelper(conf)
+
+  test("successful auth") {
+    Utils.tryWithResource(new ServerThread()) { server =>
+      Utils.tryWithResource(server.createClient()) { client =>
+        authHelper.authToServer(client)
+        server.close()
+        server.join()
+        assert(server.error == null)
+        assert(server.authenticated)
+      }
+    }
+  }
+
+  test("failed auth") {
+    Utils.tryWithResource(new ServerThread()) { server =>
+      Utils.tryWithResource(server.createClient()) { client =>
+        val badHelper = new SocketAuthHelper(new SparkConf().set(AUTH_SECRET_BIT_LENGTH, 128))
+        intercept[IllegalArgumentException] {
+          badHelper.authToServer(client)
+        }
+        server.close()
+        server.join()
+        assert(server.error != null)
+        assert(!server.authenticated)
+      }
+    }
+  }
+
+  private class ServerThread extends Thread with Closeable {
+
+    private val ss = new ServerSocket()
+    ss.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0))
+
+    @volatile var error: Exception = _
+    @volatile var authenticated = false
+
+    setDaemon(true)
+    start()
+
+    def createClient(): Socket = {
+      new Socket(InetAddress.getLoopbackAddress(), ss.getLocalPort())
+    }
+
+    override def run(): Unit = {
+      var clientConn: Socket = null
+      try {
+        clientConn = ss.accept()
+        authHelper.authClient(clientConn)
+        authenticated = true
+      } catch {
+        case e: Exception =>
+          error = e
+      } finally {
+        Option(clientConn).foreach(_.close())
+      }
+    }
+
+    override def close(): Unit = {
+      try {
+        ss.close()
+      } finally {
+        interrupt()
+      }
+    }
+
+  }
+
+}

@@ -170,7 +170,7 @@ parquet-hadoop-1.10.0.jar
 parquet-hadoop-bundle-1.6.0.jar
 parquet-jackson-1.10.0.jar
 protobuf-java-2.5.0.jar
-py4j-0.10.6.jar
+py4j-0.10.7.jar
 pyrolite-4.13.jar
 scala-compiler-2.11.8.jar
 scala-library-2.11.8.jar

@@ -171,7 +171,7 @@ parquet-hadoop-1.10.0.jar
 parquet-hadoop-bundle-1.6.0.jar
 parquet-jackson-1.10.0.jar
 protobuf-java-2.5.0.jar
-py4j-0.10.6.jar
+py4j-0.10.7.jar
 pyrolite-4.13.jar
 scala-compiler-2.11.8.jar
 scala-library-2.11.8.jar

@@ -189,7 +189,7 @@ parquet-hadoop-1.10.0.jar
 parquet-hadoop-bundle-1.6.0.jar
 parquet-jackson-1.10.0.jar
 protobuf-java-2.5.0.jar
-py4j-0.10.6.jar
+py4j-0.10.7.jar
 pyrolite-4.13.jar
 re2j-1.1.jar
 scala-compiler-2.11.8.jar

@@ -89,7 +89,7 @@ for python in "${PYTHON_EXECS[@]}"; do
     source "$VIRTUALENV_PATH"/bin/activate
   fi
   # Upgrade pip & friends if using virutal env
-  if [ ! -n "USE_CONDA" ]; then
+  if [ ! -n "$USE_CONDA" ]; then
     pip install --upgrade pip pypandoc wheel numpy
   fi


@@ -29,4 +29,4 @@ The Python packaging for Spark is not intended to replace all of the other use c

 ## Python Requirements

-At its core PySpark depends on Py4J (currently version 0.10.6), but some additional sub-packages have their own extra requirements for some features (including numpy, pandas, and pyarrow).
+At its core PySpark depends on Py4J (currently version 0.10.7), but some additional sub-packages have their own extra requirements for some features (including numpy, pandas, and pyarrow).

@@ -7,7 +7,7 @@ SPHINXBUILD ?= sphinx-build
 PAPER        ?=
 BUILDDIR     ?= _build

-export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.6-src.zip)
+export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.7-src.zip)

 # User-friendly check for sphinx-build
 ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)

Binary file not shown.
BIN python/lib/py4j-0.10.7-src.zip (new file): binary file not shown.

@@ -998,8 +998,8 @@ class SparkContext(object):
         # by runJob() in order to avoid having to pass a Python lambda into
         # SparkContext#runJob.
         mappedRDD = rdd.mapPartitions(partitionFunc)
-        port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions)
-        return list(_load_from_socket(port, mappedRDD._jrdd_deserializer))
+        sock_info = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions)
+        return list(_load_from_socket(sock_info, mappedRDD._jrdd_deserializer))

     def show_profiles(self):
         """ Print the profile stats to stdout """

@@ -29,7 +29,7 @@ from socket import AF_INET, SOCK_STREAM, SOMAXCONN
 from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT

 from pyspark.worker import main as worker_main
-from pyspark.serializers import read_int, write_int
+from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer


 def compute_real_exit_code(exit_code):

@@ -40,7 +40,7 @@ def compute_real_exit_code(exit_code):
         return 1


-def worker(sock):
+def worker(sock, authenticated):
     """
     Called by a worker process after the fork().
     """

@@ -56,6 +56,18 @@ def worker(sock):
     # otherwise writes also cause a seek that makes us miss data on the read side.
     infile = os.fdopen(os.dup(sock.fileno()), "rb", 65536)
     outfile = os.fdopen(os.dup(sock.fileno()), "wb", 65536)
+
+    if not authenticated:
+        client_secret = UTF8Deserializer().loads(infile)
+        if os.environ["PYTHON_WORKER_FACTORY_SECRET"] == client_secret:
+            write_with_length("ok".encode("utf-8"), outfile)
+            outfile.flush()
+        else:
+            write_with_length("err".encode("utf-8"), outfile)
+            outfile.flush()
+            sock.close()
+            return 1
+
     exit_code = 0
     try:
         worker_main(infile, outfile)

@@ -153,8 +165,11 @@ def manager():
                 write_int(os.getpid(), outfile)
                 outfile.flush()
                 outfile.close()
+                authenticated = False
                 while True:
-                    code = worker(sock)
+                    code = worker(sock, authenticated)
+                    if code == 0:
+                        authenticated = True
                     if not reuse or code:
                         # wait for closing
                         try:

@@ -21,16 +21,19 @@ import sys
 import select
 import signal
 import shlex
+import shutil
 import socket
 import platform
+import tempfile
+import time
 from subprocess import Popen, PIPE

 if sys.version >= '3':
     xrange = range

-from py4j.java_gateway import java_import, JavaGateway, GatewayClient
+from py4j.java_gateway import java_import, JavaGateway, GatewayParameters
 from pyspark.find_spark_home import _find_spark_home
-from pyspark.serializers import read_int
+from pyspark.serializers import read_int, write_with_length, UTF8Deserializer


 def launch_gateway(conf=None):

@@ -41,6 +44,7 @@ def launch_gateway(conf=None):
     """
     if "PYSPARK_GATEWAY_PORT" in os.environ:
         gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
+        gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
     else:
         SPARK_HOME = _find_spark_home()
         # Launch the Py4j gateway using Spark's run command so that we pick up the

@@ -59,40 +63,40 @@ def launch_gateway(conf=None):
         ])
         command = command + shlex.split(submit_args)

-        # Start a socket that will be used by PythonGatewayServer to communicate its port to us
-        callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        callback_socket.bind(('127.0.0.1', 0))
-        callback_socket.listen(1)
-        callback_host, callback_port = callback_socket.getsockname()
-        env = dict(os.environ)
-        env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host
-        env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port)
-
-        # Launch the Java gateway.
-        # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
-        if not on_windows:
-            # Don't send ctrl-c / SIGINT to the Java gateway:
-            def preexec_func():
-                signal.signal(signal.SIGINT, signal.SIG_IGN)
-            proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
-        else:
-            # preexec_fn not supported on Windows
-            proc = Popen(command, stdin=PIPE, env=env)
-
-        gateway_port = None
-        # We use select() here in order to avoid blocking indefinitely if the subprocess dies
-        # before connecting
-        while gateway_port is None and proc.poll() is None:
-            timeout = 1  # (seconds)
-            readable, _, _ = select.select([callback_socket], [], [], timeout)
-            if callback_socket in readable:
-                gateway_connection = callback_socket.accept()[0]
-                # Determine which ephemeral port the server started on:
-                gateway_port = read_int(gateway_connection.makefile(mode="rb"))
-                gateway_connection.close()
-                callback_socket.close()
-        if gateway_port is None:
-            raise Exception("Java gateway process exited before sending the driver its port number")
+        # Create a temporary directory where the gateway server should write the connection
+        # information.
+        conn_info_dir = tempfile.mkdtemp()
+        try:
+            fd, conn_info_file = tempfile.mkstemp(dir=conn_info_dir)
+            os.close(fd)
+            os.unlink(conn_info_file)
+
+            env = dict(os.environ)
+            env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file
+
+            # Launch the Java gateway.
+            # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
+            if not on_windows:
+                # Don't send ctrl-c / SIGINT to the Java gateway:
+                def preexec_func():
+                    signal.signal(signal.SIGINT, signal.SIG_IGN)
+                proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
+            else:
+                # preexec_fn not supported on Windows
+                proc = Popen(command, stdin=PIPE, env=env)
+
+            # Wait for the file to appear, or for the process to exit, whichever happens first.
+            while not proc.poll() and not os.path.isfile(conn_info_file):
+                time.sleep(0.1)
+
+            if not os.path.isfile(conn_info_file):
+                raise Exception("Java gateway process exited before sending its port number")
+
+            with open(conn_info_file, "rb") as info:
+                gateway_port = read_int(info)
+                gateway_secret = UTF8Deserializer().loads(info)
+        finally:
+            shutil.rmtree(conn_info_dir)

         # In Windows, ensure the Java child processes do not linger after Python has exited.
         # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when

@@ -111,7 +115,9 @@ def launch_gateway(conf=None):
         atexit.register(killChild)

     # Connect to the gateway
-    gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True)
+    gateway = JavaGateway(
+        gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret,
+                                             auto_convert=True))

     # Import the classes used by PySpark
     java_import(gateway.jvm, "org.apache.spark.SparkConf")

@@ -126,3 +132,16 @@ def launch_gateway(conf=None):
     java_import(gateway.jvm, "scala.Tuple2")

     return gateway
+
+
+def do_server_auth(conn, auth_secret):
+    """
+    Performs the authentication protocol defined by the SocketAuthHelper class on the given
+    file-like object 'conn'.
+    """
+    write_with_length(auth_secret.encode("utf-8"), conn)
+    conn.flush()
+    reply = UTF8Deserializer().loads(conn)
+    if reply != "ok":
+        conn.close()
+        raise Exception("Unexpected reply from iterator server.")

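As a usage illustration (not part of the commit), a Python process that already has PYSPARK_GATEWAY_PORT and PYSPARK_GATEWAY_SECRET in its environment, the two variables PythonRunner now exports, would attach to the running gateway the same way launch_gateway does:

    import os
    from py4j.java_gateway import JavaGateway, GatewayParameters

    # Illustrative only: connect to an already-running gateway using the port and
    # secret exported by the launcher.
    gateway = JavaGateway(
        gateway_parameters=GatewayParameters(
            port=int(os.environ["PYSPARK_GATEWAY_PORT"]),
            auth_token=os.environ["PYSPARK_GATEWAY_SECRET"],
            auto_convert=True))
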
@@ -39,9 +39,11 @@ if sys.version > '3':
 else:
     from itertools import imap as map, ifilter as filter

+from pyspark.java_gateway import do_server_auth
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
-    PickleSerializer, pack_long, AutoBatchedSerializer
+    PickleSerializer, pack_long, AutoBatchedSerializer, write_with_length, \
+    UTF8Deserializer
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_full_outer_join, python_cogroup
 from pyspark.statcounter import StatCounter

@@ -136,7 +138,8 @@ def _parse_memory(s):
     return int(float(s[:-1]) * units[s[-1].lower()])


-def _load_from_socket(port, serializer):
+def _load_from_socket(sock_info, serializer):
+    port, auth_secret = sock_info
     sock = None
     # Support for both IPv4 and IPv6.
     # On most of IPv6-ready systems, IPv6 will take precedence.

@@ -156,8 +159,12 @@ def _load_from_socket(port, serializer):
     # The RDD materialization time is unpredicable, if we set a timeout for socket reading
     # operation, it will very possibly fail. See SPARK-18281.
     sock.settimeout(None)
+
+    sockfile = sock.makefile("rwb", 65536)
+    do_server_auth(sockfile, auth_secret)
+
     # The socket will be automatically closed when garbage-collected.
-    return serializer.load_stream(sock.makefile("rb", 65536))
+    return serializer.load_stream(sockfile)


 def ignore_unicode_prefix(f):

@@ -822,8 +829,8 @@ class RDD(object):
         to be small, as all the data is loaded into the driver's memory.
         """
         with SCCallSiteSync(self.context) as css:
-            port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
-        return list(_load_from_socket(port, self._jrdd_deserializer))
+            sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
+        return list(_load_from_socket(sock_info, self._jrdd_deserializer))

     def reduce(self, f):
         """

@@ -2380,8 +2387,8 @@ class RDD(object):
         [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
         """
         with SCCallSiteSync(self.context) as css:
-            port = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
-        return _load_from_socket(port, self._jrdd_deserializer)
+            sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
+        return _load_from_socket(sock_info, self._jrdd_deserializer)


 def _prepare_for_python_RDD(sc, command):

@@ -463,8 +463,8 @@ class DataFrame(object):
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         """
         with SCCallSiteSync(self._sc) as css:
-            port = self._jdf.collectToPython()
-        return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
+            sock_info = self._jdf.collectToPython()
+        return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer())))

     @ignore_unicode_prefix
     @since(2.0)

@@ -477,8 +477,8 @@ class DataFrame(object):
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         """
         with SCCallSiteSync(self._sc) as css:
-            port = self._jdf.toPythonIterator()
-        return _load_from_socket(port, BatchedSerializer(PickleSerializer()))
+            sock_info = self._jdf.toPythonIterator()
+        return _load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))

     @ignore_unicode_prefix
     @since(1.3)

@@ -2087,8 +2087,8 @@ class DataFrame(object):
         .. note:: Experimental.
         """
         with SCCallSiteSync(self._sc) as css:
-            port = self._jdf.collectAsArrowToPython()
-        return list(_load_from_socket(port, ArrowSerializer()))
+            sock_info = self._jdf.collectAsArrowToPython()
+        return list(_load_from_socket(sock_info, ArrowSerializer()))

     ##########################################################################################
     # Pandas compatibility

@@ -27,6 +27,7 @@ import traceback

 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
+from pyspark.java_gateway import do_server_auth
 from pyspark.taskcontext import TaskContext
 from pyspark.files import SparkFiles
 from pyspark.rdd import PythonEvalType

@@ -301,9 +302,11 @@ def main(infile, outfile):


 if __name__ == '__main__':
-    # Read a local port to connect to from stdin
-    java_port = int(sys.stdin.readline())
+    # Read information about how to connect back to the JVM from the environment.
+    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
+    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     sock.connect(("127.0.0.1", java_port))
     sock_file = sock.makefile("rwb", 65536)
+    do_server_auth(sock_file, auth_secret)
     main(sock_file, sock_file)

@@ -201,7 +201,7 @@ try:
             'pyspark.examples.src.main.python': ['*.py', '*/*.py']},
         scripts=scripts,
         license='http://www.apache.org/licenses/LICENSE-2.0',
-        install_requires=['py4j==0.10.6'],
+        install_requires=['py4j==0.10.7'],
         setup_requires=['pypandoc'],
         extras_require={
             'ml': ['numpy>=1.7'],

@@ -1152,7 +1152,7 @@ private[spark] class Client(
         val pyArchivesFile = new File(pyLibPath, "pyspark.zip")
         require(pyArchivesFile.exists(),
           s"$pyArchivesFile not found; cannot run pyspark application in YARN mode.")
-        val py4jFile = new File(pyLibPath, "py4j-0.10.6-src.zip")
+        val py4jFile = new File(pyLibPath, "py4j-0.10.7-src.zip")
         require(py4jFile.exists(),
           s"$py4jFile not found; cannot run pyspark application in YARN mode.")
         Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath())

@@ -265,7 +265,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
     // needed locations.
     val sparkHome = sys.props("spark.test.home")
     val pythonPath = Seq(
-        s"$sparkHome/python/lib/py4j-0.10.6-src.zip",
+        s"$sparkHome/python/lib/py4j-0.10.7-src.zip",
         s"$sparkHome/python")
     val extraEnvVars = Map(
       "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator),

@@ -28,6 +28,6 @@ export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}/conf"}"
 # Add the PySpark classes to the PYTHONPATH:
 if [ -z "${PYSPARK_PYTHONPATH_SET}" ]; then
   export PYTHONPATH="${SPARK_HOME}/python:${PYTHONPATH}"
-  export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.6-src.zip:${PYTHONPATH}"
+  export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:${PYTHONPATH}"
   export PYSPARK_PYTHONPATH_SET=1
 fi

@@ -3187,7 +3187,7 @@ class Dataset[T] private[sql](
     EvaluatePython.javaToPython(rdd)
   }

-  private[sql] def collectToPython(): Int = {
+  private[sql] def collectToPython(): Array[Any] = {
     EvaluatePython.registerPicklers()
     withAction("collectToPython", queryExecution) { plan =>
       val toJava: (Any) => Any = EvaluatePython.toJava(_, schema)

@@ -3200,7 +3200,7 @@ class Dataset[T] private[sql](
   /**
    * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
    */
-  private[sql] def collectAsArrowToPython(): Int = {
+  private[sql] def collectAsArrowToPython(): Array[Any] = {
     withAction("collectAsArrowToPython", queryExecution) { plan =>
       val iter: Iterator[Array[Byte]] =
         toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)

@@ -3208,7 +3208,7 @@ class Dataset[T] private[sql](
     }
   }

-  private[sql] def toPythonIterator(): Int = {
+  private[sql] def toPythonIterator(): Array[Any] = {
     withNewExecutionId {
       PythonRDD.toLocalIteratorAndServe(javaToPython.rdd)
     }