[SPARK-4897] [PySpark] Python 3 support

This PR updates PySpark to support Python 3 (tested with 3.4).

Known issue: unpickling arrays from Pyrolite is broken in Python 3, so those tests are skipped.
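For reference, the skip presumably looks something like the following sketch (hypothetical class and test names, not the exact code in this PR):

import array
import sys
import unittest


class SerializationTests(unittest.TestCase):

    @unittest.skipIf(sys.version_info[0] >= 3,
                     "Pyrolite cannot unpickle array.array under Python 3 yet")
    def test_unpickle_array(self):
        # Placeholder body; the round trip through the JVM via Pyrolite is what
        # actually breaks, the skipIf guard is the pattern of interest here.
        data = array.array('d', [1.0, 2.0, 3.0])
        self.assertEqual(list(data), [1.0, 2.0, 3.0])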

TODO: ec2/spark_ec2.py is not fully tested with Python 3.
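The script does guard its urllib imports on the interpreter version, so the same file imports cleanly under both 2 and 3 (condensed from the spark_ec2.py hunk below):

import sys

if sys.version < "3":
    from urllib2 import urlopen, Request, HTTPError
else:
    from urllib.request import urlopen, Request
    from urllib.error import HTTPError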

Author: Davies Liu <davies@databricks.com>
Author: twneale <twneale@gmail.com>
Author: Josh Rosen <joshrosen@databricks.com>

Closes #5173 from davies/python3 and squashes the following commits:

d7d6323 [Davies Liu] fix tests
6c52a98 [Davies Liu] fix mllib test
99e334f [Davies Liu] update timeout
b716610 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3
cafd5ec [Davies Liu] adddress comments from @mengxr
bf225d7 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3
179fc8d [Davies Liu] tuning flaky tests
8c8b957 [Davies Liu] fix ResourceWarning in Python 3
5c57c95 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3
4006829 [Davies Liu] fix test
2fc0066 [Davies Liu] add python3 path
71535e9 [Davies Liu] fix xrange and divide
5a55ab4 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3
125f12c [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3
ed498c8 [Davies Liu] fix compatibility with python 3
820e649 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3
e8ce8c9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3
ad7c374 [Davies Liu] fix mllib test and warning
ef1fc2f [Davies Liu] fix tests
4eee14a [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3
20112ff [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3
59bb492 [Davies Liu] fix tests
1da268c [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3
ca0fdd3 [Davies Liu] fix code style
9563a15 [Davies Liu] add imap back for python 2
0b1ec04 [Davies Liu] make python examples work with Python 3
d2fd566 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3
a716d34 [Davies Liu] test with python 3.4
f1700e8 [Davies Liu] fix test in python3
671b1db [Davies Liu] fix test in python3
692ff47 [Davies Liu] fix flaky test
7b9699f [Davies Liu] invalidate import cache for Python 3.3+
9c58497 [Davies Liu] fix kill worker
309bfbf [Davies Liu] keep compatibility
5707476 [Davies Liu] cleanup, fix hash of string in 3.3+
8662d5b [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3
f53e1f0 [Davies Liu] fix tests
70b6b73 [Davies Liu] compile ec2/spark_ec2.py in python 3
a39167e [Davies Liu] support customize class in __main__
814c77b [Davies Liu] run unittests with python 3
7f4476e [Davies Liu] mllib tests passed
d737924 [Davies Liu] pass ml tests
375ea17 [Davies Liu] SQL tests pass
6cc42a9 [Davies Liu] rename
431a8de [Davies Liu] streaming tests pass
78901a7 [Davies Liu] fix hash of serializer in Python 3
24b2f2e [Davies Liu] pass all RDD tests
35f48fe [Davies Liu] run future again
1eebac2 [Davies Liu] fix conflict in ec2/spark_ec2.py
6e3c21d [Davies Liu] make cloudpickle work with Python3
2fb2db3 [Josh Rosen] Guard more changes behind sys.version; still doesn't run
1aa5e8f [twneale] Turned out `pickle.DictionaryType is dict` == True, so swapped it out
7354371 [twneale] buffer --> memoryview  I'm not super sure if this a valid change, but the 2.7 docs recommend using memoryview over buffer where possible, so hoping it'll work.
b69ccdf [twneale] Uses the pure python pickle._Pickler instead of c-extension _pickle.Pickler. It appears pyspark 2.7 uses the pure python pickler as well, so this shouldn't degrade pickling performance (?).
f40d925 [twneale] xrange --> range
e104215 [twneale] Replaces 2.7 types.InstsanceType with 3.4 `object`....could be horribly wrong depending on how types.InstanceType is used elsewhere in the package--see http://bugs.python.org/issue8206
79de9d0 [twneale] Replaces python2.7 `file` with 3.4 _io.TextIOWrapper
2adb42d [Josh Rosen] Fix up some import differences between Python 2 and 3
854be27 [Josh Rosen] Run `futurize` on Python code:
7c5b4ce [Josh Rosen] Remove Python 3 check in shell.py.
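The buffer-to-memoryview swap noted in twneale's commit above works because memoryview exists in both 2.7 and 3.x and provides the same zero-copy view over a bytes object; a small standalone illustration (not code from this patch):

data = b"pyspark"

# Python 2's buffer() is gone in Python 3; memoryview() is available in both
# 2.7 and 3.x and likewise avoids copying the underlying bytes.
view = memoryview(data)
print(view[2:6].tobytes())  # b'spar'
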
Davies Liu 2015-04-16 16:20:57 -07:00 committed by Josh Rosen
parent 55f553a979
commit 04e44b37cc
91 changed files with 1401 additions and 1399 deletions

View file

@ -89,6 +89,7 @@ export PYTHONSTARTUP="$SPARK_HOME/python/pyspark/shell.py"
if [[ -n "$SPARK_TESTING" ]]; then
unset YARN_CONF_DIR
unset HADOOP_CONF_DIR
export PYTHONHASHSEED=0
if [[ -n "$PYSPARK_DOC_TEST" ]]; then
exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1
else
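
Pinning PYTHONHASHSEED here (and in the other launch scripts below) matters because Python 3.3+ randomizes the hash of str per interpreter process, so hash-partitioning of string keys would otherwise differ between the driver, the workers, and test runs. A standalone way to observe the effect (demo only, not part of the diff):

import os
import subprocess
import sys

code = "print(hash('spark'))"
for seed in (None, "0"):
    env = dict(os.environ)
    env.pop("PYTHONHASHSEED", None)
    if seed is not None:
        env["PYTHONHASHSEED"] = seed
    runs = [subprocess.check_output([sys.executable, "-c", code], env=env)
            for _ in range(2)]
    # Without a fixed seed the two processes disagree on 3.3+; with
    # PYTHONHASHSEED=0 the hashes are reproducible.
    print("seed=%s reproducible=%s" % (seed, runs[0] == runs[1]))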

View file

@ -19,6 +19,9 @@
SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)"
# disable randomized hash for string in Python 3.3+
export PYTHONHASHSEED=0
# Only define a usage function if an upstream script hasn't done so.
if ! type -t usage >/dev/null 2>&1; then
usage() {

View file

@ -20,6 +20,9 @@ rem
rem This is the entry point for running Spark submit. To avoid polluting the
rem environment, it just launches a new cmd to do the real work.
rem disable randomized hash for string in Python 3.3+
set PYTHONHASHSEED=0
set CLASS=org.apache.spark.deploy.SparkSubmit
call %~dp0spark-class2.cmd %CLASS% %*
set SPARK_ERROR_LEVEL=%ERRORLEVEL%

View file

@ -235,6 +235,8 @@ echo "========================================================================="
CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS
# add path for python 3 in jenkins
export PATH="${PATH}:/home/anaconda/envs/py3k/bin"
./python/run-tests
echo ""

View file

@ -47,7 +47,7 @@ COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}"
# GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :(
SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}"
TESTS_TIMEOUT="120m" # format: http://linux.die.net/man/1/timeout
TESTS_TIMEOUT="150m" # format: http://linux.die.net/man/1/timeout
# Array to capture all tests to run on the pull request. These tests are held under the
#+ dev/tests/ directory.

View file

@ -19,7 +19,7 @@
# limitations under the License.
#
from __future__ import with_statement
from __future__ import with_statement, print_function
import hashlib
import itertools
@ -37,12 +37,17 @@ import tarfile
import tempfile
import textwrap
import time
import urllib2
import warnings
from datetime import datetime
from optparse import OptionParser
from sys import stderr
if sys.version < "3":
from urllib2 import urlopen, Request, HTTPError
else:
from urllib.request import urlopen, Request
from urllib.error import HTTPError
SPARK_EC2_VERSION = "1.2.1"
SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__))
@ -88,10 +93,10 @@ def setup_external_libs(libs):
SPARK_EC2_LIB_DIR = os.path.join(SPARK_EC2_DIR, "lib")
if not os.path.exists(SPARK_EC2_LIB_DIR):
print "Downloading external libraries that spark-ec2 needs from PyPI to {path}...".format(
print("Downloading external libraries that spark-ec2 needs from PyPI to {path}...".format(
path=SPARK_EC2_LIB_DIR
)
print "This should be a one-time operation."
))
print("This should be a one-time operation.")
os.mkdir(SPARK_EC2_LIB_DIR)
for lib in libs:
@ -100,8 +105,8 @@ def setup_external_libs(libs):
if not os.path.isdir(lib_dir):
tgz_file_path = os.path.join(SPARK_EC2_LIB_DIR, versioned_lib_name + ".tar.gz")
print " - Downloading {lib}...".format(lib=lib["name"])
download_stream = urllib2.urlopen(
print(" - Downloading {lib}...".format(lib=lib["name"]))
download_stream = urlopen(
"{prefix}/{first_letter}/{lib_name}/{lib_name}-{lib_version}.tar.gz".format(
prefix=PYPI_URL_PREFIX,
first_letter=lib["name"][:1],
@ -113,13 +118,13 @@ def setup_external_libs(libs):
tgz_file.write(download_stream.read())
with open(tgz_file_path) as tar:
if hashlib.md5(tar.read()).hexdigest() != lib["md5"]:
print >> stderr, "ERROR: Got wrong md5sum for {lib}.".format(lib=lib["name"])
print("ERROR: Got wrong md5sum for {lib}.".format(lib=lib["name"]), file=stderr)
sys.exit(1)
tar = tarfile.open(tgz_file_path)
tar.extractall(path=SPARK_EC2_LIB_DIR)
tar.close()
os.remove(tgz_file_path)
print " - Finished downloading {lib}.".format(lib=lib["name"])
print(" - Finished downloading {lib}.".format(lib=lib["name"]))
sys.path.insert(1, lib_dir)
@ -299,12 +304,12 @@ def parse_args():
if home_dir is None or not os.path.isfile(home_dir + '/.boto'):
if not os.path.isfile('/etc/boto.cfg'):
if os.getenv('AWS_ACCESS_KEY_ID') is None:
print >> stderr, ("ERROR: The environment variable AWS_ACCESS_KEY_ID " +
"must be set")
print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set",
file=stderr)
sys.exit(1)
if os.getenv('AWS_SECRET_ACCESS_KEY') is None:
print >> stderr, ("ERROR: The environment variable AWS_SECRET_ACCESS_KEY " +
"must be set")
print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set",
file=stderr)
sys.exit(1)
return (opts, action, cluster_name)
@ -316,7 +321,7 @@ def get_or_make_group(conn, name, vpc_id):
if len(group) > 0:
return group[0]
else:
print "Creating security group " + name
print("Creating security group " + name)
return conn.create_security_group(name, "Spark EC2 group", vpc_id)
@ -324,18 +329,19 @@ def get_validate_spark_version(version, repo):
if "." in version:
version = version.replace("v", "")
if version not in VALID_SPARK_VERSIONS:
print >> stderr, "Don't know about Spark version: {v}".format(v=version)
print("Don't know about Spark version: {v}".format(v=version), file=stderr)
sys.exit(1)
return version
else:
github_commit_url = "{repo}/commit/{commit_hash}".format(repo=repo, commit_hash=version)
request = urllib2.Request(github_commit_url)
request = Request(github_commit_url)
request.get_method = lambda: 'HEAD'
try:
response = urllib2.urlopen(request)
except urllib2.HTTPError, e:
print >> stderr, "Couldn't validate Spark commit: {url}".format(url=github_commit_url)
print >> stderr, "Received HTTP response code of {code}.".format(code=e.code)
response = urlopen(request)
except HTTPError as e:
print("Couldn't validate Spark commit: {url}".format(url=github_commit_url),
file=stderr)
print("Received HTTP response code of {code}.".format(code=e.code), file=stderr)
sys.exit(1)
return version
@ -394,8 +400,7 @@ def get_spark_ami(opts):
instance_type = EC2_INSTANCE_TYPES[opts.instance_type]
else:
instance_type = "pvm"
print >> stderr,\
"Don't recognize %s, assuming type is pvm" % opts.instance_type
print("Don't recognize %s, assuming type is pvm" % opts.instance_type, file=stderr)
# URL prefix from which to fetch AMI information
ami_prefix = "{r}/{b}/ami-list".format(
@ -404,10 +409,10 @@ def get_spark_ami(opts):
ami_path = "%s/%s/%s" % (ami_prefix, opts.region, instance_type)
try:
ami = urllib2.urlopen(ami_path).read().strip()
print "Spark AMI: " + ami
ami = urlopen(ami_path).read().strip()
print("Spark AMI: " + ami)
except:
print >> stderr, "Could not resolve AMI at: " + ami_path
print("Could not resolve AMI at: " + ami_path, file=stderr)
sys.exit(1)
return ami
@ -419,11 +424,11 @@ def get_spark_ami(opts):
# Fails if there already instances running in the cluster's groups.
def launch_cluster(conn, opts, cluster_name):
if opts.identity_file is None:
print >> stderr, "ERROR: Must provide an identity file (-i) for ssh connections."
print("ERROR: Must provide an identity file (-i) for ssh connections.", file=stderr)
sys.exit(1)
if opts.key_pair is None:
print >> stderr, "ERROR: Must provide a key pair name (-k) to use on instances."
print("ERROR: Must provide a key pair name (-k) to use on instances.", file=stderr)
sys.exit(1)
user_data_content = None
@ -431,7 +436,7 @@ def launch_cluster(conn, opts, cluster_name):
with open(opts.user_data) as user_data_file:
user_data_content = user_data_file.read()
print "Setting up security groups..."
print("Setting up security groups...")
master_group = get_or_make_group(conn, cluster_name + "-master", opts.vpc_id)
slave_group = get_or_make_group(conn, cluster_name + "-slaves", opts.vpc_id)
authorized_address = opts.authorized_address
@ -497,8 +502,8 @@ def launch_cluster(conn, opts, cluster_name):
existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name,
die_on_error=False)
if existing_slaves or (existing_masters and not opts.use_existing_master):
print >> stderr, ("ERROR: There are already instances running in " +
"group %s or %s" % (master_group.name, slave_group.name))
print("ERROR: There are already instances running in group %s or %s" %
(master_group.name, slave_group.name), file=stderr)
sys.exit(1)
# Figure out Spark AMI
@ -511,12 +516,12 @@ def launch_cluster(conn, opts, cluster_name):
additional_group_ids = [sg.id
for sg in conn.get_all_security_groups()
if opts.additional_security_group in (sg.name, sg.id)]
print "Launching instances..."
print("Launching instances...")
try:
image = conn.get_all_images(image_ids=[opts.ami])[0]
except:
print >> stderr, "Could not find AMI " + opts.ami
print("Could not find AMI " + opts.ami, file=stderr)
sys.exit(1)
# Create block device mapping so that we can add EBS volumes if asked to.
@ -542,8 +547,8 @@ def launch_cluster(conn, opts, cluster_name):
# Launch slaves
if opts.spot_price is not None:
# Launch spot instances with the requested price
print ("Requesting %d slaves as spot instances with price $%.3f" %
(opts.slaves, opts.spot_price))
print("Requesting %d slaves as spot instances with price $%.3f" %
(opts.slaves, opts.spot_price))
zones = get_zones(conn, opts)
num_zones = len(zones)
i = 0
@ -566,7 +571,7 @@ def launch_cluster(conn, opts, cluster_name):
my_req_ids += [req.id for req in slave_reqs]
i += 1
print "Waiting for spot instances to be granted..."
print("Waiting for spot instances to be granted...")
try:
while True:
time.sleep(10)
@ -579,24 +584,24 @@ def launch_cluster(conn, opts, cluster_name):
if i in id_to_req and id_to_req[i].state == "active":
active_instance_ids.append(id_to_req[i].instance_id)
if len(active_instance_ids) == opts.slaves:
print "All %d slaves granted" % opts.slaves
print("All %d slaves granted" % opts.slaves)
reservations = conn.get_all_reservations(active_instance_ids)
slave_nodes = []
for r in reservations:
slave_nodes += r.instances
break
else:
print "%d of %d slaves granted, waiting longer" % (
len(active_instance_ids), opts.slaves)
print("%d of %d slaves granted, waiting longer" % (
len(active_instance_ids), opts.slaves))
except:
print "Canceling spot instance requests"
print("Canceling spot instance requests")
conn.cancel_spot_instance_requests(my_req_ids)
# Log a warning if any of these requests actually launched instances:
(master_nodes, slave_nodes) = get_existing_cluster(
conn, opts, cluster_name, die_on_error=False)
running = len(master_nodes) + len(slave_nodes)
if running:
print >> stderr, ("WARNING: %d instances are still running" % running)
print(("WARNING: %d instances are still running" % running), file=stderr)
sys.exit(0)
else:
# Launch non-spot instances
@ -618,16 +623,16 @@ def launch_cluster(conn, opts, cluster_name):
placement_group=opts.placement_group,
user_data=user_data_content)
slave_nodes += slave_res.instances
print "Launched {s} slave{plural_s} in {z}, regid = {r}".format(
s=num_slaves_this_zone,
plural_s=('' if num_slaves_this_zone == 1 else 's'),
z=zone,
r=slave_res.id)
print("Launched {s} slave{plural_s} in {z}, regid = {r}".format(
s=num_slaves_this_zone,
plural_s=('' if num_slaves_this_zone == 1 else 's'),
z=zone,
r=slave_res.id))
i += 1
# Launch or resume masters
if existing_masters:
print "Starting master..."
print("Starting master...")
for inst in existing_masters:
if inst.state not in ["shutting-down", "terminated"]:
inst.start()
@ -650,10 +655,10 @@ def launch_cluster(conn, opts, cluster_name):
user_data=user_data_content)
master_nodes = master_res.instances
print "Launched master in %s, regid = %s" % (zone, master_res.id)
print("Launched master in %s, regid = %s" % (zone, master_res.id))
# This wait time corresponds to SPARK-4983
print "Waiting for AWS to propagate instance metadata..."
print("Waiting for AWS to propagate instance metadata...")
time.sleep(5)
# Give the instances descriptive names
for master in master_nodes:
@ -674,8 +679,8 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
Get the EC2 instances in an existing cluster if available.
Returns a tuple of lists of EC2 instance objects for the masters and slaves.
"""
print "Searching for existing cluster {c} in region {r}...".format(
c=cluster_name, r=opts.region)
print("Searching for existing cluster {c} in region {r}...".format(
c=cluster_name, r=opts.region))
def get_instances(group_names):
"""
@ -693,16 +698,15 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
slave_instances = get_instances([cluster_name + "-slaves"])
if any((master_instances, slave_instances)):
print "Found {m} master{plural_m}, {s} slave{plural_s}.".format(
m=len(master_instances),
plural_m=('' if len(master_instances) == 1 else 's'),
s=len(slave_instances),
plural_s=('' if len(slave_instances) == 1 else 's'))
print("Found {m} master{plural_m}, {s} slave{plural_s}.".format(
m=len(master_instances),
plural_m=('' if len(master_instances) == 1 else 's'),
s=len(slave_instances),
plural_s=('' if len(slave_instances) == 1 else 's')))
if not master_instances and die_on_error:
print >> sys.stderr, \
"ERROR: Could not find a master for cluster {c} in region {r}.".format(
c=cluster_name, r=opts.region)
print("ERROR: Could not find a master for cluster {c} in region {r}.".format(
c=cluster_name, r=opts.region), file=sys.stderr)
sys.exit(1)
return (master_instances, slave_instances)
@ -713,7 +717,7 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
master = get_dns_name(master_nodes[0], opts.private_ips)
if deploy_ssh_key:
print "Generating cluster's SSH key on master..."
print("Generating cluster's SSH key on master...")
key_setup = """
[ -f ~/.ssh/id_rsa ] ||
(ssh-keygen -q -t rsa -N '' -f ~/.ssh/id_rsa &&
@ -721,10 +725,10 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
"""
ssh(master, opts, key_setup)
dot_ssh_tar = ssh_read(master, opts, ['tar', 'c', '.ssh'])
print "Transferring cluster's SSH key to slaves..."
print("Transferring cluster's SSH key to slaves...")
for slave in slave_nodes:
slave_address = get_dns_name(slave, opts.private_ips)
print slave_address
print(slave_address)
ssh_write(slave_address, opts, ['tar', 'x'], dot_ssh_tar)
modules = ['spark', 'ephemeral-hdfs', 'persistent-hdfs',
@ -738,8 +742,8 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
# NOTE: We should clone the repository before running deploy_files to
# prevent ec2-variables.sh from being overwritten
print "Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format(
r=opts.spark_ec2_git_repo, b=opts.spark_ec2_git_branch)
print("Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format(
r=opts.spark_ec2_git_repo, b=opts.spark_ec2_git_branch))
ssh(
host=master,
opts=opts,
@ -749,7 +753,7 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
b=opts.spark_ec2_git_branch)
)
print "Deploying files to master..."
print("Deploying files to master...")
deploy_files(
conn=conn,
root_dir=SPARK_EC2_DIR + "/" + "deploy.generic",
@ -760,25 +764,25 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
)
if opts.deploy_root_dir is not None:
print "Deploying {s} to master...".format(s=opts.deploy_root_dir)
print("Deploying {s} to master...".format(s=opts.deploy_root_dir))
deploy_user_files(
root_dir=opts.deploy_root_dir,
opts=opts,
master_nodes=master_nodes
)
print "Running setup on master..."
print("Running setup on master...")
setup_spark_cluster(master, opts)
print "Done!"
print("Done!")
def setup_spark_cluster(master, opts):
ssh(master, opts, "chmod u+x spark-ec2/setup.sh")
ssh(master, opts, "spark-ec2/setup.sh")
print "Spark standalone cluster started at http://%s:8080" % master
print("Spark standalone cluster started at http://%s:8080" % master)
if opts.ganglia:
print "Ganglia started at http://%s:5080/ganglia" % master
print("Ganglia started at http://%s:5080/ganglia" % master)
def is_ssh_available(host, opts, print_ssh_output=True):
@ -795,7 +799,7 @@ def is_ssh_available(host, opts, print_ssh_output=True):
if s.returncode != 0 and print_ssh_output:
# extra leading newline is for spacing in wait_for_cluster_state()
print textwrap.dedent("""\n
print(textwrap.dedent("""\n
Warning: SSH connection error. (This could be temporary.)
Host: {h}
SSH return code: {r}
@ -804,7 +808,7 @@ def is_ssh_available(host, opts, print_ssh_output=True):
h=host,
r=s.returncode,
o=cmd_output.strip()
)
))
return s.returncode == 0
@ -865,10 +869,10 @@ def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state):
sys.stdout.write("\n")
end_time = datetime.now()
print "Cluster is now in '{s}' state. Waited {t} seconds.".format(
print("Cluster is now in '{s}' state. Waited {t} seconds.".format(
s=cluster_state,
t=(end_time - start_time).seconds
)
))
# Get number of local disks available for a given EC2 instance type.
@ -916,8 +920,8 @@ def get_num_disks(instance_type):
if instance_type in disks_by_instance:
return disks_by_instance[instance_type]
else:
print >> stderr, ("WARNING: Don't know number of disks on instance type %s; assuming 1"
% instance_type)
print("WARNING: Don't know number of disks on instance type %s; assuming 1"
% instance_type, file=stderr)
return 1
@ -951,7 +955,7 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules):
# Spark-only custom deploy
spark_v = "%s|%s" % (opts.spark_git_repo, opts.spark_version)
tachyon_v = ""
print "Deploying Spark via git hash; Tachyon won't be set up"
print("Deploying Spark via git hash; Tachyon won't be set up")
modules = filter(lambda x: x != "tachyon", modules)
master_addresses = [get_dns_name(i, opts.private_ips) for i in master_nodes]
@ -1067,8 +1071,8 @@ def ssh(host, opts, command):
"--key-pair parameters and try again.".format(host))
else:
raise e
print >> stderr, \
"Error executing remote command, retrying after 30 seconds: {0}".format(e)
print("Error executing remote command, retrying after 30 seconds: {0}".format(e),
file=stderr)
time.sleep(30)
tries = tries + 1
@ -1107,8 +1111,8 @@ def ssh_write(host, opts, command, arguments):
elif tries > 5:
raise RuntimeError("ssh_write failed with error %s" % proc.returncode)
else:
print >> stderr, \
"Error {0} while executing remote command, retrying after 30 seconds".format(status)
print("Error {0} while executing remote command, retrying after 30 seconds".
format(status), file=stderr)
time.sleep(30)
tries = tries + 1
@ -1162,42 +1166,41 @@ def real_main():
if opts.identity_file is not None:
if not os.path.exists(opts.identity_file):
print >> stderr,\
"ERROR: The identity file '{f}' doesn't exist.".format(f=opts.identity_file)
print("ERROR: The identity file '{f}' doesn't exist.".format(f=opts.identity_file),
file=stderr)
sys.exit(1)
file_mode = os.stat(opts.identity_file).st_mode
if not (file_mode & S_IRUSR) or not oct(file_mode)[-2:] == '00':
print >> stderr, "ERROR: The identity file must be accessible only by you."
print >> stderr, 'You can fix this with: chmod 400 "{f}"'.format(f=opts.identity_file)
print("ERROR: The identity file must be accessible only by you.", file=stderr)
print('You can fix this with: chmod 400 "{f}"'.format(f=opts.identity_file),
file=stderr)
sys.exit(1)
if opts.instance_type not in EC2_INSTANCE_TYPES:
print >> stderr, "Warning: Unrecognized EC2 instance type for instance-type: {t}".format(
t=opts.instance_type)
print("Warning: Unrecognized EC2 instance type for instance-type: {t}".format(
t=opts.instance_type), file=stderr)
if opts.master_instance_type != "":
if opts.master_instance_type not in EC2_INSTANCE_TYPES:
print >> stderr, \
"Warning: Unrecognized EC2 instance type for master-instance-type: {t}".format(
t=opts.master_instance_type)
print("Warning: Unrecognized EC2 instance type for master-instance-type: {t}".format(
t=opts.master_instance_type), file=stderr)
# Since we try instance types even if we can't resolve them, we check if they resolve first
# and, if they do, see if they resolve to the same virtualization type.
if opts.instance_type in EC2_INSTANCE_TYPES and \
opts.master_instance_type in EC2_INSTANCE_TYPES:
if EC2_INSTANCE_TYPES[opts.instance_type] != \
EC2_INSTANCE_TYPES[opts.master_instance_type]:
print >> stderr, \
"Error: spark-ec2 currently does not support having a master and slaves " + \
"with different AMI virtualization types."
print >> stderr, "master instance virtualization type: {t}".format(
t=EC2_INSTANCE_TYPES[opts.master_instance_type])
print >> stderr, "slave instance virtualization type: {t}".format(
t=EC2_INSTANCE_TYPES[opts.instance_type])
print("Error: spark-ec2 currently does not support having a master and slaves "
"with different AMI virtualization types.", file=stderr)
print("master instance virtualization type: {t}".format(
t=EC2_INSTANCE_TYPES[opts.master_instance_type]), file=stderr)
print("slave instance virtualization type: {t}".format(
t=EC2_INSTANCE_TYPES[opts.instance_type]), file=stderr)
sys.exit(1)
if opts.ebs_vol_num > 8:
print >> stderr, "ebs-vol-num cannot be greater than 8"
print("ebs-vol-num cannot be greater than 8", file=stderr)
sys.exit(1)
# Prevent breaking ami_prefix (/, .git and startswith checks)
@ -1206,23 +1209,22 @@ def real_main():
opts.spark_ec2_git_repo.endswith(".git") or \
not opts.spark_ec2_git_repo.startswith("https://github.com") or \
not opts.spark_ec2_git_repo.endswith("spark-ec2"):
print >> stderr, "spark-ec2-git-repo must be a github repo and it must not have a " \
"trailing / or .git. " \
"Furthermore, we currently only support forks named spark-ec2."
print("spark-ec2-git-repo must be a github repo and it must not have a trailing / or .git. "
"Furthermore, we currently only support forks named spark-ec2.", file=stderr)
sys.exit(1)
if not (opts.deploy_root_dir is None or
(os.path.isabs(opts.deploy_root_dir) and
os.path.isdir(opts.deploy_root_dir) and
os.path.exists(opts.deploy_root_dir))):
print >> stderr, "--deploy-root-dir must be an absolute path to a directory that exists " \
"on the local file system"
print("--deploy-root-dir must be an absolute path to a directory that exists "
"on the local file system", file=stderr)
sys.exit(1)
try:
conn = ec2.connect_to_region(opts.region)
except Exception as e:
print >> stderr, (e)
print((e), file=stderr)
sys.exit(1)
# Select an AZ at random if it was not specified.
@ -1231,7 +1233,7 @@ def real_main():
if action == "launch":
if opts.slaves <= 0:
print >> sys.stderr, "ERROR: You have to start at least 1 slave"
print("ERROR: You have to start at least 1 slave", file=sys.stderr)
sys.exit(1)
if opts.resume:
(master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
@ -1250,18 +1252,18 @@ def real_main():
conn, opts, cluster_name, die_on_error=False)
if any(master_nodes + slave_nodes):
print "The following instances will be terminated:"
print("The following instances will be terminated:")
for inst in master_nodes + slave_nodes:
print "> %s" % get_dns_name(inst, opts.private_ips)
print "ALL DATA ON ALL NODES WILL BE LOST!!"
print("> %s" % get_dns_name(inst, opts.private_ips))
print("ALL DATA ON ALL NODES WILL BE LOST!!")
msg = "Are you sure you want to destroy the cluster {c}? (y/N) ".format(c=cluster_name)
response = raw_input(msg)
if response == "y":
print "Terminating master..."
print("Terminating master...")
for inst in master_nodes:
inst.terminate()
print "Terminating slaves..."
print("Terminating slaves...")
for inst in slave_nodes:
inst.terminate()
@ -1274,16 +1276,16 @@ def real_main():
cluster_instances=(master_nodes + slave_nodes),
cluster_state='terminated'
)
print "Deleting security groups (this will take some time)..."
print("Deleting security groups (this will take some time)...")
attempt = 1
while attempt <= 3:
print "Attempt %d" % attempt
print("Attempt %d" % attempt)
groups = [g for g in conn.get_all_security_groups() if g.name in group_names]
success = True
# Delete individual rules in all groups before deleting groups to
# remove dependencies between them
for group in groups:
print "Deleting rules in security group " + group.name
print("Deleting rules in security group " + group.name)
for rule in group.rules:
for grant in rule.grants:
success &= group.revoke(ip_protocol=rule.ip_protocol,
@ -1298,10 +1300,10 @@ def real_main():
try:
# It is needed to use group_id to make it work with VPC
conn.delete_security_group(group_id=group.id)
print "Deleted security group %s" % group.name
print("Deleted security group %s" % group.name)
except boto.exception.EC2ResponseError:
success = False
print "Failed to delete security group %s" % group.name
print("Failed to delete security group %s" % group.name)
# Unfortunately, group.revoke() returns True even if a rule was not
# deleted, so this needs to be rerun if something fails
@ -1311,17 +1313,16 @@ def real_main():
attempt += 1
if not success:
print "Failed to delete all security groups after 3 tries."
print "Try re-running in a few minutes."
print("Failed to delete all security groups after 3 tries.")
print("Try re-running in a few minutes.")
elif action == "login":
(master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
if not master_nodes[0].public_dns_name and not opts.private_ips:
print "Master has no public DNS name. Maybe you meant to specify " \
"--private-ips?"
print("Master has no public DNS name. Maybe you meant to specify --private-ips?")
else:
master = get_dns_name(master_nodes[0], opts.private_ips)
print "Logging into master " + master + "..."
print("Logging into master " + master + "...")
proxy_opt = []
if opts.proxy_port is not None:
proxy_opt = ['-D', opts.proxy_port]
@ -1336,19 +1337,18 @@ def real_main():
if response == "y":
(master_nodes, slave_nodes) = get_existing_cluster(
conn, opts, cluster_name, die_on_error=False)
print "Rebooting slaves..."
print("Rebooting slaves...")
for inst in slave_nodes:
if inst.state not in ["shutting-down", "terminated"]:
print "Rebooting " + inst.id
print("Rebooting " + inst.id)
inst.reboot()
elif action == "get-master":
(master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
if not master_nodes[0].public_dns_name and not opts.private_ips:
print "Master has no public DNS name. Maybe you meant to specify " \
"--private-ips?"
print("Master has no public DNS name. Maybe you meant to specify --private-ips?")
else:
print get_dns_name(master_nodes[0], opts.private_ips)
print(get_dns_name(master_nodes[0], opts.private_ips))
elif action == "stop":
response = raw_input(
@ -1361,11 +1361,11 @@ def real_main():
if response == "y":
(master_nodes, slave_nodes) = get_existing_cluster(
conn, opts, cluster_name, die_on_error=False)
print "Stopping master..."
print("Stopping master...")
for inst in master_nodes:
if inst.state not in ["shutting-down", "terminated"]:
inst.stop()
print "Stopping slaves..."
print("Stopping slaves...")
for inst in slave_nodes:
if inst.state not in ["shutting-down", "terminated"]:
if inst.spot_instance_request_id:
@ -1375,11 +1375,11 @@ def real_main():
elif action == "start":
(master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
print "Starting slaves..."
print("Starting slaves...")
for inst in slave_nodes:
if inst.state not in ["shutting-down", "terminated"]:
inst.start()
print "Starting master..."
print("Starting master...")
for inst in master_nodes:
if inst.state not in ["shutting-down", "terminated"]:
inst.start()
@ -1403,15 +1403,15 @@ def real_main():
setup_cluster(conn, master_nodes, slave_nodes, opts, False)
else:
print >> stderr, "Invalid action: %s" % action
print("Invalid action: %s" % action, file=stderr)
sys.exit(1)
def main():
try:
real_main()
except UsageError, e:
print >> stderr, "\nError:\n", e
except UsageError as e:
print("\nError:\n", e, file=stderr)
sys.exit(1)
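
Most of the churn in this file is mechanical: every Python 2 print statement, including the print >> stderr form, becomes a call to the print function enabled by the __future__ import at the top of the file. The two spellings side by side (standalone example, message text borrowed from the hunk above):

from __future__ import print_function  # no-op on Python 3, enables print() on 2.6+

import sys

message = "The environment variable AWS_ACCESS_KEY_ID must be set"
# Python 2 statement form (a SyntaxError once print_function is in effect):
#     print >> sys.stderr, "ERROR: " + message
# Portable function form used throughout the patch:
print("ERROR: " + message, file=sys.stderr)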

View file

@ -21,7 +21,8 @@ ALS in pyspark.mllib.recommendation for more conventional use.
This example requires numpy (http://www.numpy.org/)
"""
from os.path import realpath
from __future__ import print_function
import sys
import numpy as np
@ -57,9 +58,9 @@ if __name__ == "__main__":
Usage: als [M] [U] [F] [iterations] [partitions]"
"""
print >> sys.stderr, """WARN: This is a naive implementation of ALS and is given as an
print("""WARN: This is a naive implementation of ALS and is given as an
example. Please use the ALS method found in pyspark.mllib.recommendation for more
conventional use."""
conventional use.""", file=sys.stderr)
sc = SparkContext(appName="PythonALS")
M = int(sys.argv[1]) if len(sys.argv) > 1 else 100
@ -68,8 +69,8 @@ if __name__ == "__main__":
ITERATIONS = int(sys.argv[4]) if len(sys.argv) > 4 else 5
partitions = int(sys.argv[5]) if len(sys.argv) > 5 else 2
print "Running ALS with M=%d, U=%d, F=%d, iters=%d, partitions=%d\n" % \
(M, U, F, ITERATIONS, partitions)
print("Running ALS with M=%d, U=%d, F=%d, iters=%d, partitions=%d\n" %
(M, U, F, ITERATIONS, partitions))
R = matrix(rand(M, F)) * matrix(rand(U, F).T)
ms = matrix(rand(M, F))
@ -95,7 +96,7 @@ if __name__ == "__main__":
usb = sc.broadcast(us)
error = rmse(R, ms, us)
print "Iteration %d:" % i
print "\nRMSE: %5.4f\n" % error
print("Iteration %d:" % i)
print("\nRMSE: %5.4f\n" % error)
sc.stop()

View file

@ -15,9 +15,12 @@
# limitations under the License.
#
from __future__ import print_function
import sys
from pyspark import SparkContext
from functools import reduce
"""
Read data file users.avro in local Spark distro:
@ -49,7 +52,7 @@ $ ./bin/spark-submit --driver-class-path /path/to/example/jar \
"""
if __name__ == "__main__":
if len(sys.argv) != 2 and len(sys.argv) != 3:
print >> sys.stderr, """
print("""
Usage: avro_inputformat <data_file> [reader_schema_file]
Run with example jar:
@ -57,7 +60,7 @@ if __name__ == "__main__":
/path/to/examples/avro_inputformat.py <data_file> [reader_schema_file]
Assumes you have Avro data stored in <data_file>. Reader schema can be optionally specified
in [reader_schema_file].
"""
""", file=sys.stderr)
exit(-1)
path = sys.argv[1]
@ -77,6 +80,6 @@ if __name__ == "__main__":
conf=conf)
output = avro_rdd.map(lambda x: x[0]).collect()
for k in output:
print k
print(k)
sc.stop()

View file

@ -15,6 +15,8 @@
# limitations under the License.
#
from __future__ import print_function
import sys
from pyspark import SparkContext
@ -47,14 +49,14 @@ cqlsh:test> SELECT * FROM users;
"""
if __name__ == "__main__":
if len(sys.argv) != 4:
print >> sys.stderr, """
print("""
Usage: cassandra_inputformat <host> <keyspace> <cf>
Run with example jar:
./bin/spark-submit --driver-class-path /path/to/example/jar \
/path/to/examples/cassandra_inputformat.py <host> <keyspace> <cf>
Assumes you have some data in Cassandra already, running on <host>, in <keyspace> and <cf>
"""
""", file=sys.stderr)
exit(-1)
host = sys.argv[1]
@ -77,6 +79,6 @@ if __name__ == "__main__":
conf=conf)
output = cass_rdd.collect()
for (k, v) in output:
print (k, v)
print((k, v))
sc.stop()

View file

@ -15,6 +15,8 @@
# limitations under the License.
#
from __future__ import print_function
import sys
from pyspark import SparkContext
@ -46,7 +48,7 @@ cqlsh:test> SELECT * FROM users;
"""
if __name__ == "__main__":
if len(sys.argv) != 7:
print >> sys.stderr, """
print("""
Usage: cassandra_outputformat <host> <keyspace> <cf> <user_id> <fname> <lname>
Run with example jar:
@ -60,7 +62,7 @@ if __name__ == "__main__":
... fname text,
... lname text
... );
"""
""", file=sys.stderr)
exit(-1)
host = sys.argv[1]

View file

@ -15,6 +15,8 @@
# limitations under the License.
#
from __future__ import print_function
import sys
from pyspark import SparkContext
@ -47,14 +49,14 @@ ROW COLUMN+CELL
"""
if __name__ == "__main__":
if len(sys.argv) != 3:
print >> sys.stderr, """
print("""
Usage: hbase_inputformat <host> <table>
Run with example jar:
./bin/spark-submit --driver-class-path /path/to/example/jar \
/path/to/examples/hbase_inputformat.py <host> <table>
Assumes you have some data in HBase already, running on <host>, in <table>
"""
""", file=sys.stderr)
exit(-1)
host = sys.argv[1]
@ -74,6 +76,6 @@ if __name__ == "__main__":
conf=conf)
output = hbase_rdd.collect()
for (k, v) in output:
print (k, v)
print((k, v))
sc.stop()

View file

@ -15,6 +15,8 @@
# limitations under the License.
#
from __future__ import print_function
import sys
from pyspark import SparkContext
@ -40,7 +42,7 @@ ROW COLUMN+CELL
"""
if __name__ == "__main__":
if len(sys.argv) != 7:
print >> sys.stderr, """
print("""
Usage: hbase_outputformat <host> <table> <row> <family> <qualifier> <value>
Run with example jar:
@ -48,7 +50,7 @@ if __name__ == "__main__":
/path/to/examples/hbase_outputformat.py <args>
Assumes you have created <table> with column family <family> in HBase
running on <host> already
"""
""", file=sys.stderr)
exit(-1)
host = sys.argv[1]

View file

@ -22,6 +22,7 @@ examples/src/main/python/mllib/kmeans.py.
This example requires NumPy (http://www.numpy.org/).
"""
from __future__ import print_function
import sys
@ -47,12 +48,12 @@ def closestPoint(p, centers):
if __name__ == "__main__":
if len(sys.argv) != 4:
print >> sys.stderr, "Usage: kmeans <file> <k> <convergeDist>"
print("Usage: kmeans <file> <k> <convergeDist>", file=sys.stderr)
exit(-1)
print >> sys.stderr, """WARN: This is a naive implementation of KMeans Clustering and is given
print("""WARN: This is a naive implementation of KMeans Clustering and is given
as an example! Please refer to examples/src/main/python/mllib/kmeans.py for an example on
how to use MLlib's KMeans implementation."""
how to use MLlib's KMeans implementation.""", file=sys.stderr)
sc = SparkContext(appName="PythonKMeans")
lines = sc.textFile(sys.argv[1])
@ -69,13 +70,13 @@ if __name__ == "__main__":
pointStats = closest.reduceByKey(
lambda (x1, y1), (x2, y2): (x1 + x2, y1 + y2))
newPoints = pointStats.map(
lambda (x, (y, z)): (x, y / z)).collect()
lambda xy: (xy[0], xy[1][0] / xy[1][1])).collect()
tempDist = sum(np.sum((kPoints[x] - y) ** 2) for (x, y) in newPoints)
for (x, y) in newPoints:
kPoints[x] = y
print "Final centers: " + str(kPoints)
print("Final centers: " + str(kPoints))
sc.stop()
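
The lambda rewrite above is forced by PEP 3113: Python 3 removed tuple parameter unpacking, so a signature like lambda (x, (y, z)) is a syntax error and the single argument has to be indexed explicitly. A standalone illustration of the equivalent forms (sample data made up):

# Python 2 only (SyntaxError on Python 3, per PEP 3113):
#     pointStats.map(lambda (x, (y, z)): (x, y / z))
# Portable form used in the patch: take one argument and index into it.
point_stats = [(0, (6.0, 2.0)), (1, (9.0, 3.0))]
new_points = [(lambda xy: (xy[0], xy[1][0] / xy[1][1]))(p) for p in point_stats]
print(new_points)  # [(0, 3.0), (1, 3.0)]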

View file

@ -22,10 +22,8 @@ to act on batches of input data using efficient matrix operations.
In practice, one may prefer to use the LogisticRegression algorithm in
MLlib, as shown in examples/src/main/python/mllib/logistic_regression.py.
"""
from __future__ import print_function
from collections import namedtuple
from math import exp
from os.path import realpath
import sys
import numpy as np
@ -42,19 +40,19 @@ D = 10 # Number of dimensions
def readPointBatch(iterator):
strs = list(iterator)
matrix = np.zeros((len(strs), D + 1))
for i in xrange(len(strs)):
matrix[i] = np.fromstring(strs[i].replace(',', ' '), dtype=np.float32, sep=' ')
for i, s in enumerate(strs):
matrix[i] = np.fromstring(s.replace(',', ' '), dtype=np.float32, sep=' ')
return [matrix]
if __name__ == "__main__":
if len(sys.argv) != 3:
print >> sys.stderr, "Usage: logistic_regression <file> <iterations>"
print("Usage: logistic_regression <file> <iterations>", file=sys.stderr)
exit(-1)
print >> sys.stderr, """WARN: This is a naive implementation of Logistic Regression and is
print("""WARN: This is a naive implementation of Logistic Regression and is
given as an example! Please refer to examples/src/main/python/mllib/logistic_regression.py
to see how MLlib's implementation is used."""
to see how MLlib's implementation is used.""", file=sys.stderr)
sc = SparkContext(appName="PythonLR")
points = sc.textFile(sys.argv[1]).mapPartitions(readPointBatch).cache()
@ -62,7 +60,7 @@ if __name__ == "__main__":
# Initialize w to a random value
w = 2 * np.random.ranf(size=D) - 1
print "Initial w: " + str(w)
print("Initial w: " + str(w))
# Compute logistic regression gradient for a matrix of data points
def gradient(matrix, w):
@ -76,9 +74,9 @@ if __name__ == "__main__":
return x
for i in range(iterations):
print "On iteration %i" % (i + 1)
print("On iteration %i" % (i + 1))
w -= points.map(lambda m: gradient(m, w)).reduce(add)
print "Final w: " + str(w)
print("Final w: " + str(w))
sc.stop()

View file

@ -15,6 +15,8 @@
# limitations under the License.
#
from __future__ import print_function
from pyspark import SparkContext
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
@ -37,10 +39,10 @@ if __name__ == "__main__":
# Prepare training documents, which are labeled.
LabeledDocument = Row("id", "text", "label")
training = sc.parallelize([(0L, "a b c d e spark", 1.0),
(1L, "b d", 0.0),
(2L, "spark f g h", 1.0),
(3L, "hadoop mapreduce", 0.0)]) \
training = sc.parallelize([(0, "a b c d e spark", 1.0),
(1, "b d", 0.0),
(2, "spark f g h", 1.0),
(3, "hadoop mapreduce", 0.0)]) \
.map(lambda x: LabeledDocument(*x)).toDF()
# Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr.
@ -54,16 +56,16 @@ if __name__ == "__main__":
# Prepare test documents, which are unlabeled.
Document = Row("id", "text")
test = sc.parallelize([(4L, "spark i j k"),
(5L, "l m n"),
(6L, "mapreduce spark"),
(7L, "apache hadoop")]) \
test = sc.parallelize([(4, "spark i j k"),
(5, "l m n"),
(6, "mapreduce spark"),
(7, "apache hadoop")]) \
.map(lambda x: Document(*x)).toDF()
# Make predictions on test documents and print columns of interest.
prediction = model.transform(test)
selected = prediction.select("id", "text", "prediction")
for row in selected.collect():
print row
print(row)
sc.stop()
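
The 0L, 1L, ... literals above become plain integers because Python 3 unifies int and long into one arbitrary-precision int type and rejects the L suffix as a syntax error:

# Python 2: 0L forced a long. Python 3: plain ints are already arbitrary
# precision, so the example rows simply use 0, 1, 2, 3.
training = [(0, "a b c d e spark", 1.0),
            (1, "b d", 0.0)]
print(type(training[0][0]))  # <class 'int'> on Python 3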

View file

@ -18,6 +18,7 @@
"""
Correlations using MLlib.
"""
from __future__ import print_function
import sys
@ -29,7 +30,7 @@ from pyspark.mllib.util import MLUtils
if __name__ == "__main__":
if len(sys.argv) not in [1, 2]:
print >> sys.stderr, "Usage: correlations (<file>)"
print("Usage: correlations (<file>)", file=sys.stderr)
exit(-1)
sc = SparkContext(appName="PythonCorrelations")
if len(sys.argv) == 2:
@ -41,20 +42,20 @@ if __name__ == "__main__":
points = MLUtils.loadLibSVMFile(sc, filepath)\
.map(lambda lp: LabeledPoint(lp.label, lp.features.toArray()))
print
print 'Summary of data file: ' + filepath
print '%d data points' % points.count()
print()
print('Summary of data file: ' + filepath)
print('%d data points' % points.count())
# Statistics (correlations)
print
print 'Correlation (%s) between label and each feature' % corrType
print 'Feature\tCorrelation'
print()
print('Correlation (%s) between label and each feature' % corrType)
print('Feature\tCorrelation')
numFeatures = points.take(1)[0].features.size
labelRDD = points.map(lambda lp: lp.label)
for i in range(numFeatures):
featureRDD = points.map(lambda lp: lp.features[i])
corr = Statistics.corr(labelRDD, featureRDD, corrType)
print '%d\t%g' % (i, corr)
print
print('%d\t%g' % (i, corr))
print()
sc.stop()

View file

@ -19,6 +19,7 @@
An example of how to use DataFrame as a dataset for ML. Run with::
bin/spark-submit examples/src/main/python/mllib/dataset_example.py
"""
from __future__ import print_function
import os
import sys
@ -32,16 +33,16 @@ from pyspark.mllib.stat import Statistics
def summarize(dataset):
print "schema: %s" % dataset.schema().json()
print("schema: %s" % dataset.schema().json())
labels = dataset.map(lambda r: r.label)
print "label average: %f" % labels.mean()
print("label average: %f" % labels.mean())
features = dataset.map(lambda r: r.features)
summary = Statistics.colStats(features)
print "features average: %r" % summary.mean()
print("features average: %r" % summary.mean())
if __name__ == "__main__":
if len(sys.argv) > 2:
print >> sys.stderr, "Usage: dataset_example.py <libsvm file>"
print("Usage: dataset_example.py <libsvm file>", file=sys.stderr)
exit(-1)
sc = SparkContext(appName="DatasetExample")
sqlContext = SQLContext(sc)
@ -54,9 +55,9 @@ if __name__ == "__main__":
summarize(dataset0)
tempdir = tempfile.NamedTemporaryFile(delete=False).name
os.unlink(tempdir)
print "Save dataset as a Parquet file to %s." % tempdir
print("Save dataset as a Parquet file to %s." % tempdir)
dataset0.saveAsParquetFile(tempdir)
print "Load it back and summarize it again."
print("Load it back and summarize it again.")
dataset1 = sqlContext.parquetFile(tempdir).setName("dataset1").cache()
summarize(dataset1)
shutil.rmtree(tempdir)

View file

@ -20,6 +20,7 @@ Decision tree classification and regression using MLlib.
This example requires NumPy (http://www.numpy.org/).
"""
from __future__ import print_function
import numpy
import os
@ -83,18 +84,17 @@ def reindexClassLabels(data):
numClasses = len(classCounts)
# origToNewLabels: class --> index in 0,...,numClasses-1
if (numClasses < 2):
print >> sys.stderr, \
"Dataset for classification should have at least 2 classes." + \
" The given dataset had only %d classes." % numClasses
print("Dataset for classification should have at least 2 classes."
" The given dataset had only %d classes." % numClasses, file=sys.stderr)
exit(1)
origToNewLabels = dict([(sortedClasses[i], i) for i in range(0, numClasses)])
print "numClasses = %d" % numClasses
print "Per-class example fractions, counts:"
print "Class\tFrac\tCount"
print("numClasses = %d" % numClasses)
print("Per-class example fractions, counts:")
print("Class\tFrac\tCount")
for c in sortedClasses:
frac = classCounts[c] / (numExamples + 0.0)
print "%g\t%g\t%d" % (c, frac, classCounts[c])
print("%g\t%g\t%d" % (c, frac, classCounts[c]))
if (sortedClasses[0] == 0 and sortedClasses[-1] == numClasses - 1):
return (data, origToNewLabels)
@ -105,8 +105,7 @@ def reindexClassLabels(data):
def usage():
print >> sys.stderr, \
"Usage: decision_tree_runner [libsvm format data filepath]"
print("Usage: decision_tree_runner [libsvm format data filepath]", file=sys.stderr)
exit(1)
@ -133,13 +132,13 @@ if __name__ == "__main__":
model = DecisionTree.trainClassifier(reindexedData, numClasses=numClasses,
categoricalFeaturesInfo=categoricalFeaturesInfo)
# Print learned tree and stats.
print "Trained DecisionTree for classification:"
print " Model numNodes: %d" % model.numNodes()
print " Model depth: %d" % model.depth()
print " Training accuracy: %g" % getAccuracy(model, reindexedData)
print("Trained DecisionTree for classification:")
print(" Model numNodes: %d" % model.numNodes())
print(" Model depth: %d" % model.depth())
print(" Training accuracy: %g" % getAccuracy(model, reindexedData))
if model.numNodes() < 20:
print model.toDebugString()
print(model.toDebugString())
else:
print model
print(model)
sc.stop()

View file

@ -18,7 +18,8 @@
"""
A Gaussian Mixture Model clustering program using MLlib.
"""
import sys
from __future__ import print_function
import random
import argparse
import numpy as np
@ -59,7 +60,7 @@ if __name__ == "__main__":
model = GaussianMixture.train(data, args.k, args.convergenceTol,
args.maxIterations, args.seed)
for i in range(args.k):
print ("weight = ", model.weights[i], "mu = ", model.gaussians[i].mu,
"sigma = ", model.gaussians[i].sigma.toArray())
print ("Cluster labels (first 100): ", model.predict(data).take(100))
print(("weight = ", model.weights[i], "mu = ", model.gaussians[i].mu,
"sigma = ", model.gaussians[i].sigma.toArray()))
print(("Cluster labels (first 100): ", model.predict(data).take(100)))
sc.stop()

View file

@ -18,6 +18,7 @@
"""
Gradient boosted Trees classification and regression using MLlib.
"""
from __future__ import print_function
import sys
@ -34,7 +35,7 @@ def testClassification(trainingData, testData):
# Evaluate model on test instances and compute test error
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() \
testErr = labelsAndPredictions.filter(lambda v_p: v_p[0] != v_p[1]).count() \
/ float(testData.count())
print('Test Error = ' + str(testErr))
print('Learned classification ensemble model:')
@ -49,7 +50,7 @@ def testRegression(trainingData, testData):
# Evaluate model on test instances and compute test error
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() \
testMSE = labelsAndPredictions.map(lambda vp: (vp[0] - vp[1]) * (vp[0] - vp[1])).sum() \
/ float(testData.count())
print('Test Mean Squared Error = ' + str(testMSE))
print('Learned regression ensemble model:')
@ -58,7 +59,7 @@ def testRegression(trainingData, testData):
if __name__ == "__main__":
if len(sys.argv) > 1:
print >> sys.stderr, "Usage: gradient_boosted_trees"
print("Usage: gradient_boosted_trees", file=sys.stderr)
exit(1)
sc = SparkContext(appName="PythonGradientBoostedTrees")

View file

@ -20,6 +20,7 @@ A K-means clustering program using MLlib.
This example requires NumPy (http://www.numpy.org/).
"""
from __future__ import print_function
import sys
@ -34,12 +35,12 @@ def parseVector(line):
if __name__ == "__main__":
if len(sys.argv) != 3:
print >> sys.stderr, "Usage: kmeans <file> <k>"
print("Usage: kmeans <file> <k>", file=sys.stderr)
exit(-1)
sc = SparkContext(appName="KMeans")
lines = sc.textFile(sys.argv[1])
data = lines.map(parseVector)
k = int(sys.argv[2])
model = KMeans.train(data, k)
print "Final centers: " + str(model.clusterCenters)
print("Final centers: " + str(model.clusterCenters))
sc.stop()

View file

@ -20,11 +20,10 @@ Logistic regression using MLlib.
This example requires NumPy (http://www.numpy.org/).
"""
from __future__ import print_function
from math import exp
import sys
import numpy as np
from pyspark import SparkContext
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.classification import LogisticRegressionWithSGD
@ -42,12 +41,12 @@ def parsePoint(line):
if __name__ == "__main__":
if len(sys.argv) != 3:
print >> sys.stderr, "Usage: logistic_regression <file> <iterations>"
print("Usage: logistic_regression <file> <iterations>", file=sys.stderr)
exit(-1)
sc = SparkContext(appName="PythonLR")
points = sc.textFile(sys.argv[1]).map(parsePoint)
iterations = int(sys.argv[2])
model = LogisticRegressionWithSGD.train(points, iterations)
print "Final weights: " + str(model.weights)
print "Final intercept: " + str(model.intercept)
print("Final weights: " + str(model.weights))
print("Final intercept: " + str(model.intercept))
sc.stop()

View file

@ -22,6 +22,7 @@ Note: This example illustrates binary classification.
For information on multiclass classification, please refer to the decision_tree_runner.py
example.
"""
from __future__ import print_function
import sys
@ -43,7 +44,7 @@ def testClassification(trainingData, testData):
# Evaluate model on test instances and compute test error
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count()\
testErr = labelsAndPredictions.filter(lambda v_p: v_p[0] != v_p[1]).count()\
/ float(testData.count())
print('Test Error = ' + str(testErr))
print('Learned classification forest model:')
@ -62,8 +63,8 @@ def testRegression(trainingData, testData):
# Evaluate model on test instances and compute test error
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum()\
/ float(testData.count())
testMSE = labelsAndPredictions.map(lambda v_p1: (v_p1[0] - v_p1[1]) * (v_p1[0] - v_p1[1]))\
.sum() / float(testData.count())
print('Test Mean Squared Error = ' + str(testMSE))
print('Learned regression forest model:')
print(model.toDebugString())
@ -71,7 +72,7 @@ def testRegression(trainingData, testData):
if __name__ == "__main__":
if len(sys.argv) > 1:
print >> sys.stderr, "Usage: random_forest_example"
print("Usage: random_forest_example", file=sys.stderr)
exit(1)
sc = SparkContext(appName="PythonRandomForestExample")

View file

@ -18,6 +18,7 @@
"""
Randomly generated RDDs.
"""
from __future__ import print_function
import sys
@ -27,7 +28,7 @@ from pyspark.mllib.random import RandomRDDs
if __name__ == "__main__":
if len(sys.argv) not in [1, 2]:
print >> sys.stderr, "Usage: random_rdd_generation"
print("Usage: random_rdd_generation", file=sys.stderr)
exit(-1)
sc = SparkContext(appName="PythonRandomRDDGeneration")
@ -37,19 +38,19 @@ if __name__ == "__main__":
# Example: RandomRDDs.normalRDD
normalRDD = RandomRDDs.normalRDD(sc, numExamples)
print 'Generated RDD of %d examples sampled from the standard normal distribution'\
% normalRDD.count()
print ' First 5 samples:'
print('Generated RDD of %d examples sampled from the standard normal distribution'
% normalRDD.count())
print(' First 5 samples:')
for sample in normalRDD.take(5):
print ' ' + str(sample)
print
print(' ' + str(sample))
print()
# Example: RandomRDDs.normalVectorRDD
normalVectorRDD = RandomRDDs.normalVectorRDD(sc, numRows=numExamples, numCols=2)
print 'Generated RDD of %d examples of length-2 vectors.' % normalVectorRDD.count()
print ' First 5 samples:'
print('Generated RDD of %d examples of length-2 vectors.' % normalVectorRDD.count())
print(' First 5 samples:')
for sample in normalVectorRDD.take(5):
print ' ' + str(sample)
print
print(' ' + str(sample))
print()
sc.stop()

View file

@ -18,6 +18,7 @@
"""
Randomly sampled RDDs.
"""
from __future__ import print_function
import sys
@ -27,7 +28,7 @@ from pyspark.mllib.util import MLUtils
if __name__ == "__main__":
if len(sys.argv) not in [1, 2]:
print >> sys.stderr, "Usage: sampled_rdds <libsvm data file>"
print("Usage: sampled_rdds <libsvm data file>", file=sys.stderr)
exit(-1)
if len(sys.argv) == 2:
datapath = sys.argv[1]
@ -41,24 +42,24 @@ if __name__ == "__main__":
examples = MLUtils.loadLibSVMFile(sc, datapath)
numExamples = examples.count()
if numExamples == 0:
print >> sys.stderr, "Error: Data file had no samples to load."
print("Error: Data file had no samples to load.", file=sys.stderr)
exit(1)
print 'Loaded data with %d examples from file: %s' % (numExamples, datapath)
print('Loaded data with %d examples from file: %s' % (numExamples, datapath))
# Example: RDD.sample() and RDD.takeSample()
expectedSampleSize = int(numExamples * fraction)
print 'Sampling RDD using fraction %g. Expected sample size = %d.' \
% (fraction, expectedSampleSize)
print('Sampling RDD using fraction %g. Expected sample size = %d.'
% (fraction, expectedSampleSize))
sampledRDD = examples.sample(withReplacement=True, fraction=fraction)
print ' RDD.sample(): sample has %d examples' % sampledRDD.count()
print(' RDD.sample(): sample has %d examples' % sampledRDD.count())
sampledArray = examples.takeSample(withReplacement=True, num=expectedSampleSize)
print ' RDD.takeSample(): sample has %d examples' % len(sampledArray)
print(' RDD.takeSample(): sample has %d examples' % len(sampledArray))
print
print()
# Example: RDD.sampleByKey()
keyedRDD = examples.map(lambda lp: (int(lp.label), lp.features))
print ' Keyed data using label (Int) as key ==> Orig'
print(' Keyed data using label (Int) as key ==> Orig')
# Count examples per label in original data.
keyCountsA = keyedRDD.countByKey()
@ -69,18 +70,18 @@ if __name__ == "__main__":
sampledByKeyRDD = keyedRDD.sampleByKey(withReplacement=True, fractions=fractions)
keyCountsB = sampledByKeyRDD.countByKey()
sizeB = sum(keyCountsB.values())
print ' Sampled %d examples using approximate stratified sampling (by label). ==> Sample' \
% sizeB
print(' Sampled %d examples using approximate stratified sampling (by label). ==> Sample'
% sizeB)
# Compare samples
print ' \tFractions of examples with key'
print 'Key\tOrig\tSample'
print(' \tFractions of examples with key')
print('Key\tOrig\tSample')
for k in sorted(keyCountsA.keys()):
fracA = keyCountsA[k] / float(numExamples)
if sizeB != 0:
fracB = keyCountsB.get(k, 0) / float(sizeB)
else:
fracB = 0
print '%d\t%g\t%g' % (k, fracA, fracB)
print('%d\t%g\t%g' % (k, fracA, fracB))
sc.stop()

View file

@ -23,6 +23,7 @@
# grep -o -E '\w+(\W+\w+){0,15}' text8 > text8_lines
# This was done so that the example can be run in local mode
from __future__ import print_function
import sys
@ -34,7 +35,7 @@ USAGE = ("bin/spark-submit --driver-memory 4g "
if __name__ == "__main__":
if len(sys.argv) < 2:
print USAGE
print(USAGE)
sys.exit("Argument for file not provided")
file_path = sys.argv[1]
sc = SparkContext(appName='Word2Vec')
@ -46,5 +47,5 @@ if __name__ == "__main__":
synonyms = model.findSynonyms('china', 40)
for word, cosine_distance in synonyms:
print "{}: {}".format(word, cosine_distance)
print("{}: {}".format(word, cosine_distance))
sc.stop()

View file

@ -19,6 +19,7 @@
This is an example implementation of PageRank. For more conventional use,
Please refer to PageRank implementation provided by graphx
"""
from __future__ import print_function
import re
import sys
@ -42,11 +43,12 @@ def parseNeighbors(urls):
if __name__ == "__main__":
if len(sys.argv) != 3:
print >> sys.stderr, "Usage: pagerank <file> <iterations>"
print("Usage: pagerank <file> <iterations>", file=sys.stderr)
exit(-1)
print >> sys.stderr, """WARN: This is a naive implementation of PageRank and is
given as an example! Please refer to PageRank implementation provided by graphx"""
print("""WARN: This is a naive implementation of PageRank and is
given as an example! Please refer to PageRank implementation provided by graphx""",
file=sys.stderr)
# Initialize the spark context.
sc = SparkContext(appName="PythonPageRank")
@ -62,19 +64,19 @@ if __name__ == "__main__":
links = lines.map(lambda urls: parseNeighbors(urls)).distinct().groupByKey().cache()
# Loads all URLs with other URL(s) link to from input file and initialize ranks of them to one.
ranks = links.map(lambda (url, neighbors): (url, 1.0))
ranks = links.map(lambda url_neighbors: (url_neighbors[0], 1.0))
# Calculates and updates URL ranks continuously using PageRank algorithm.
for iteration in xrange(int(sys.argv[2])):
for iteration in range(int(sys.argv[2])):
# Calculates URL contributions to the rank of other URLs.
contribs = links.join(ranks).flatMap(
lambda (url, (urls, rank)): computeContribs(urls, rank))
lambda url_urls_rank: computeContribs(url_urls_rank[1][0], url_urls_rank[1][1]))
# Re-calculates URL ranks based on neighbor contributions.
ranks = contribs.reduceByKey(add).mapValues(lambda rank: rank * 0.85 + 0.15)
# Collects all URL ranks and dump them to console.
for (link, rank) in ranks.collect():
print "%s has rank: %s." % (link, rank)
print("%s has rank: %s." % (link, rank))
sc.stop()
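Aside on the lambda rewrites in this file: Python 3 removed tuple parameter unpacking (PEP 3113), so `lambda (url, neighbors): ...` is now a syntax error and has to become a single-argument lambda that indexes into the tuple. A minimal sketch of the equivalent pattern (the sample data here is made up):

    # Python 2 only:  lambda (url, neighbors): (url, 1.0)
    # Python 2 and 3: take one argument and index into it
    pairs = [("a.com", ["b.com"]), ("b.com", ["a.com"])]
    ranks = list(map(lambda url_neighbors: (url_neighbors[0], 1.0), pairs))
    print(ranks)  # [('a.com', 1.0), ('b.com', 1.0)]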


@ -1,3 +1,4 @@
from __future__ import print_function
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
@ -35,14 +36,14 @@ $ ./bin/spark-submit --driver-class-path /path/to/example/jar \\
"""
if __name__ == "__main__":
if len(sys.argv) != 2:
print >> sys.stderr, """
print("""
Usage: parquet_inputformat.py <data_file>
Run with example jar:
./bin/spark-submit --driver-class-path /path/to/example/jar \\
/path/to/examples/parquet_inputformat.py <data_file>
Assumes you have Parquet data stored in <data_file>.
"""
""", file=sys.stderr)
exit(-1)
path = sys.argv[1]
@ -56,6 +57,6 @@ if __name__ == "__main__":
valueConverter='org.apache.spark.examples.pythonconverters.IndexedRecordToJavaConverter')
output = parquet_rdd.map(lambda x: x[1]).collect()
for k in output:
print k
print(k)
sc.stop()


@ -1,3 +1,4 @@
from __future__ import print_function
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
@ -35,7 +36,7 @@ if __name__ == "__main__":
y = random() * 2 - 1
return 1 if x ** 2 + y ** 2 < 1 else 0
count = sc.parallelize(xrange(1, n + 1), partitions).map(f).reduce(add)
print "Pi is roughly %f" % (4.0 * count / n)
count = sc.parallelize(range(1, n + 1), partitions).map(f).reduce(add)
print("Pi is roughly %f" % (4.0 * count / n))
sc.stop()


@ -15,6 +15,8 @@
# limitations under the License.
#
from __future__ import print_function
import sys
from pyspark import SparkContext
@ -22,7 +24,7 @@ from pyspark import SparkContext
if __name__ == "__main__":
if len(sys.argv) != 2:
print >> sys.stderr, "Usage: sort <file>"
print("Usage: sort <file>", file=sys.stderr)
exit(-1)
sc = SparkContext(appName="PythonSort")
lines = sc.textFile(sys.argv[1], 1)
@ -33,6 +35,6 @@ if __name__ == "__main__":
# In reality, we wouldn't want to collect all the data to the driver node.
output = sortedCount.collect()
for (num, unitcount) in output:
print num
print(num)
sc.stop()


@ -15,6 +15,8 @@
# limitations under the License.
#
from __future__ import print_function
import os
from pyspark import SparkContext
@ -68,6 +70,6 @@ if __name__ == "__main__":
teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
for each in teenagers.collect():
print each[0]
print(each[0])
sc.stop()


@ -15,6 +15,8 @@
# limitations under the License.
#
from __future__ import print_function
import time
import threading
import Queue
@ -52,15 +54,15 @@ def main():
ids = status.getJobIdsForGroup()
for id in ids:
job = status.getJobInfo(id)
print "Job", id, "status: ", job.status
print("Job", id, "status: ", job.status)
for sid in job.stageIds:
info = status.getStageInfo(sid)
if info:
print "Stage %d: %d tasks total (%d active, %d complete)" % \
(sid, info.numTasks, info.numActiveTasks, info.numCompletedTasks)
print("Stage %d: %d tasks total (%d active, %d complete)" %
(sid, info.numTasks, info.numActiveTasks, info.numCompletedTasks))
time.sleep(1)
print "Job results are:", result.get()
print("Job results are:", result.get())
sc.stop()
if __name__ == "__main__":


@ -25,6 +25,7 @@
Then create a text file in `localdir` and the words in the file will get counted.
"""
from __future__ import print_function
import sys
@ -33,7 +34,7 @@ from pyspark.streaming import StreamingContext
if __name__ == "__main__":
if len(sys.argv) != 2:
print >> sys.stderr, "Usage: hdfs_wordcount.py <directory>"
print("Usage: hdfs_wordcount.py <directory>", file=sys.stderr)
exit(-1)
sc = SparkContext(appName="PythonStreamingHDFSWordCount")


@ -27,6 +27,7 @@
spark-streaming-kafka-assembly-*.jar examples/src/main/python/streaming/kafka_wordcount.py \
localhost:2181 test`
"""
from __future__ import print_function
import sys
@ -36,7 +37,7 @@ from pyspark.streaming.kafka import KafkaUtils
if __name__ == "__main__":
if len(sys.argv) != 3:
print >> sys.stderr, "Usage: kafka_wordcount.py <zk> <topic>"
print("Usage: kafka_wordcount.py <zk> <topic>", file=sys.stderr)
exit(-1)
sc = SparkContext(appName="PythonStreamingKafkaWordCount")


@ -25,6 +25,7 @@
and then run the example
`$ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999`
"""
from __future__ import print_function
import sys
@ -33,7 +34,7 @@ from pyspark.streaming import StreamingContext
if __name__ == "__main__":
if len(sys.argv) != 3:
print >> sys.stderr, "Usage: network_wordcount.py <hostname> <port>"
print("Usage: network_wordcount.py <hostname> <port>", file=sys.stderr)
exit(-1)
sc = SparkContext(appName="PythonStreamingNetworkWordCount")
ssc = StreamingContext(sc, 1)


@ -35,6 +35,7 @@
checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from
the checkpoint data.
"""
from __future__ import print_function
import os
import sys
@ -46,7 +47,7 @@ from pyspark.streaming import StreamingContext
def createContext(host, port, outputPath):
# If you do not see this printed, that means the StreamingContext has been loaded
# from the new checkpoint
print "Creating new context"
print("Creating new context")
if os.path.exists(outputPath):
os.remove(outputPath)
sc = SparkContext(appName="PythonStreamingRecoverableNetworkWordCount")
@ -60,8 +61,8 @@ def createContext(host, port, outputPath):
def echo(time, rdd):
counts = "Counts at time %s %s" % (time, rdd.collect())
print counts
print "Appending to " + os.path.abspath(outputPath)
print(counts)
print("Appending to " + os.path.abspath(outputPath))
with open(outputPath, 'a') as f:
f.write(counts + "\n")
@ -70,8 +71,8 @@ def createContext(host, port, outputPath):
if __name__ == "__main__":
if len(sys.argv) != 5:
print >> sys.stderr, "Usage: recoverable_network_wordcount.py <hostname> <port> "\
"<checkpoint-directory> <output-file>"
print("Usage: recoverable_network_wordcount.py <hostname> <port> "
"<checkpoint-directory> <output-file>", file=sys.stderr)
exit(-1)
host, port, checkpoint, output = sys.argv[1:]
ssc = StreamingContext.getOrCreate(checkpoint,


@ -27,6 +27,7 @@
and then run the example
`$ bin/spark-submit examples/src/main/python/streaming/sql_network_wordcount.py localhost 9999`
"""
from __future__ import print_function
import os
import sys
@ -44,7 +45,7 @@ def getSqlContextInstance(sparkContext):
if __name__ == "__main__":
if len(sys.argv) != 3:
print >> sys.stderr, "Usage: sql_network_wordcount.py <hostname> <port> "
print("Usage: sql_network_wordcount.py <hostname> <port> ", file=sys.stderr)
exit(-1)
host, port = sys.argv[1:]
sc = SparkContext(appName="PythonSqlNetworkWordCount")
@ -57,7 +58,7 @@ if __name__ == "__main__":
# Convert RDDs of the words DStream to DataFrame and run SQL query
def process(time, rdd):
print "========= %s =========" % str(time)
print("========= %s =========" % str(time))
try:
# Get the singleton instance of SQLContext


@ -29,6 +29,7 @@
`$ bin/spark-submit examples/src/main/python/streaming/stateful_network_wordcount.py \
localhost 9999`
"""
from __future__ import print_function
import sys
@ -37,7 +38,7 @@ from pyspark.streaming import StreamingContext
if __name__ == "__main__":
if len(sys.argv) != 3:
print >> sys.stderr, "Usage: stateful_network_wordcount.py <hostname> <port>"
print("Usage: stateful_network_wordcount.py <hostname> <port>", file=sys.stderr)
exit(-1)
sc = SparkContext(appName="PythonStreamingStatefulNetworkWordCount")
ssc = StreamingContext(sc, 1)


@ -15,6 +15,8 @@
# limitations under the License.
#
from __future__ import print_function
import sys
from random import Random
@ -49,20 +51,20 @@ if __name__ == "__main__":
# the graph to obtain the path (x, z).
# Because join() joins on keys, the edges are stored in reversed order.
edges = tc.map(lambda (x, y): (y, x))
edges = tc.map(lambda x_y: (x_y[1], x_y[0]))
oldCount = 0L
oldCount = 0
nextCount = tc.count()
while True:
oldCount = nextCount
# Perform the join, obtaining an RDD of (y, (z, x)) pairs,
# then project the result to obtain the new (x, z) paths.
new_edges = tc.join(edges).map(lambda (_, (a, b)): (b, a))
new_edges = tc.join(edges).map(lambda __a_b: (__a_b[1][1], __a_b[1][0]))
tc = tc.union(new_edges).distinct().cache()
nextCount = tc.count()
if nextCount == oldCount:
break
print "TC has %i edges" % tc.count()
print("TC has %i edges" % tc.count())
sc.stop()


@ -15,6 +15,8 @@
# limitations under the License.
#
from __future__ import print_function
import sys
from operator import add
@ -23,7 +25,7 @@ from pyspark import SparkContext
if __name__ == "__main__":
if len(sys.argv) != 2:
print >> sys.stderr, "Usage: wordcount <file>"
print("Usage: wordcount <file>", file=sys.stderr)
exit(-1)
sc = SparkContext(appName="PythonWordCount")
lines = sc.textFile(sys.argv[1], 1)
@ -32,6 +34,6 @@ if __name__ == "__main__":
.reduceByKey(add)
output = counts.collect()
for (word, count) in output:
print "%s: %i" % (word, count)
print("%s: %i" % (word, count))
sc.stop()


@ -18,6 +18,7 @@
package org.apache.spark.mllib.api.python
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.recommendation.{MatrixFactorizationModel, Rating}
import org.apache.spark.rdd.RDD
@ -31,10 +32,14 @@ private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorization
predict(SerDe.asTupleRDD(userAndProducts.rdd))
def getUserFeatures: RDD[Array[Any]] = {
SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]])
SerDe.fromTuple2RDD(userFeatures.map {
case (user, feature) => (user, Vectors.dense(feature))
}.asInstanceOf[RDD[(Any, Any)]])
}
def getProductFeatures: RDD[Array[Any]] = {
SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]])
SerDe.fromTuple2RDD(productFeatures.map {
case (product, feature) => (product, Vectors.dense(feature))
}.asInstanceOf[RDD[(Any, Any)]])
}
}


@ -28,7 +28,6 @@ import scala.reflect.ClassTag
import net.razorvine.pickle._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.mllib.classification._
@ -40,15 +39,15 @@ import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.random.{RandomRDDs => RG}
import org.apache.spark.mllib.recommendation._
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
import org.apache.spark.mllib.stat.correlation.CorrelationNames
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.stat.test.ChiSqTestResult
import org.apache.spark.mllib.tree.{GradientBoostedTrees, RandomForest, DecisionTree}
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo, Strategy}
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy}
import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.loss.Losses
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel, RandomForestModel, DecisionTreeModel}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel, RandomForestModel}
import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@ -279,7 +278,7 @@ private[python] class PythonMLLibAPI extends Serializable {
data: JavaRDD[LabeledPoint],
lambda: Double): JList[Object] = {
val model = NaiveBayes.train(data.rdd, lambda)
List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta).
List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta.map(Vectors.dense)).
map(_.asInstanceOf[Object]).asJava
}
@ -335,7 +334,7 @@ private[python] class PythonMLLibAPI extends Serializable {
mu += model.gaussians(i).mu
sigma += model.gaussians(i).sigma
}
List(wt.toArray, mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
List(Vectors.dense(wt.toArray), mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
} finally {
data.rdd.unpersist(blocking = false)
}
@ -346,20 +345,20 @@ private[python] class PythonMLLibAPI extends Serializable {
*/
def predictSoftGMM(
data: JavaRDD[Vector],
wt: Object,
wt: Vector,
mu: Array[Object],
si: Array[Object]): RDD[Array[Double]] = {
si: Array[Object]): RDD[Vector] = {
val weight = wt.asInstanceOf[Array[Double]]
val weight = wt.toArray
val mean = mu.map(_.asInstanceOf[DenseVector])
val sigma = si.map(_.asInstanceOf[DenseMatrix])
val gaussians = Array.tabulate(weight.length){
i => new MultivariateGaussian(mean(i), sigma(i))
}
val model = new GaussianMixtureModel(weight, gaussians)
model.predictSoft(data)
model.predictSoft(data).map(Vectors.dense)
}
/**
* Java stub for Python mllib ALS.train(). This stub returns a handle
* to the Java object instead of the content of the Java object. Extra care
@ -936,6 +935,14 @@ private[spark] object SerDe extends Serializable {
out.write(code)
}
protected def getBytes(obj: Object): Array[Byte] = {
if (obj.getClass.isArray) {
obj.asInstanceOf[Array[Byte]]
} else {
obj.asInstanceOf[String].getBytes(LATIN1)
}
}
private[python] def saveState(obj: Object, out: OutputStream, pickler: Pickler)
}
@ -961,7 +968,7 @@ private[spark] object SerDe extends Serializable {
if (args.length != 1) {
throw new PickleException("should be 1")
}
val bytes = args(0).asInstanceOf[String].getBytes(LATIN1)
val bytes = getBytes(args(0))
val bb = ByteBuffer.wrap(bytes, 0, bytes.length)
bb.order(ByteOrder.nativeOrder())
val db = bb.asDoubleBuffer()
@ -994,7 +1001,7 @@ private[spark] object SerDe extends Serializable {
if (args.length != 3) {
throw new PickleException("should be 3")
}
val bytes = args(2).asInstanceOf[String].getBytes(LATIN1)
val bytes = getBytes(args(2))
val n = bytes.length / 8
val values = new Array[Double](n)
val order = ByteOrder.nativeOrder()
@ -1031,8 +1038,8 @@ private[spark] object SerDe extends Serializable {
throw new PickleException("should be 3")
}
val size = args(0).asInstanceOf[Int]
val indiceBytes = args(1).asInstanceOf[String].getBytes(LATIN1)
val valueBytes = args(2).asInstanceOf[String].getBytes(LATIN1)
val indiceBytes = getBytes(args(1))
val valueBytes = getBytes(args(2))
val n = indiceBytes.length / 4
val indices = new Array[Int](n)
val values = new Array[Double](n)


@ -54,7 +54,7 @@
... def zero(self, value):
... return [0.0] * len(value)
... def addInPlace(self, val1, val2):
... for i in xrange(len(val1)):
... for i in range(len(val1)):
... val1[i] += val2[i]
... return val1
>>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam())
@ -86,9 +86,13 @@ Traceback (most recent call last):
Exception:...
"""
import sys
import select
import struct
import SocketServer
if sys.version < '3':
import SocketServer
else:
import socketserver as SocketServer
import threading
from pyspark.cloudpickle import CloudPickler
from pyspark.serializers import read_int, PickleSerializer
@ -247,6 +251,7 @@ class AccumulatorServer(SocketServer.TCPServer):
def shutdown(self):
self.server_shutdown = True
SocketServer.TCPServer.shutdown(self)
self.server_close()
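The added server_close() call matters on Python 3, where leaving the listening socket open produces a ResourceWarning once the server object is garbage-collected (see the "fix ResourceWarning in Python 3" commit). The shutdown/close pairing in isolation looks like this (a sketch against stdlib socketserver, not the AccumulatorServer itself):

    import socketserver
    import threading

    class Handler(socketserver.BaseRequestHandler):
        def handle(self):
            pass

    server = socketserver.TCPServer(("127.0.0.1", 0), Handler)
    threading.Thread(target=server.serve_forever).start()

    server.shutdown()       # stop the serve_forever() loop
    server.server_close()   # actually close the listening socket, so no ResourceWarning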
def _start_update_server():


@ -16,10 +16,15 @@
#
import os
import cPickle
import sys
import gc
from tempfile import NamedTemporaryFile
if sys.version < '3':
import cPickle as pickle
else:
import pickle
unicode = str
__all__ = ['Broadcast']
@ -70,33 +75,19 @@ class Broadcast(object):
self._path = path
def dump(self, value, f):
if isinstance(value, basestring):
if isinstance(value, unicode):
f.write('U')
value = value.encode('utf8')
else:
f.write('S')
f.write(value)
else:
f.write('P')
cPickle.dump(value, f, 2)
pickle.dump(value, f, 2)
f.close()
return f.name
def load(self, path):
with open(path, 'rb', 1 << 20) as f:
flag = f.read(1)
data = f.read()
if flag == 'P':
# cPickle.loads() may create lots of objects, disable GC
# temporary for better performance
gc.disable()
try:
return cPickle.loads(data)
finally:
gc.enable()
else:
return data.decode('utf8') if flag == 'U' else data
# pickle.load() may create lots of objects, disable GC
# temporary for better performance
gc.disable()
try:
return pickle.load(f)
finally:
gc.enable()
@property
def value(self):
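The version guard at the top of broadcast.py is the compatibility idiom this PR repeats in several modules: use the C pickler on Python 2, fall back to the stdlib pickle on Python 3, and alias names that no longer exist. A standalone sketch of the same idea (not taken verbatim from the diff):

    import sys

    if sys.version < '3':
        import cPickle as pickle   # C implementation, Python 2 only
    else:
        import pickle              # Python 3's pickle already has the C accelerator
        unicode = str              # the unicode type is gone; alias it for old call sites

    data = pickle.dumps({"key": u"value"}, 2)              # protocol 2 works on both
    print(isinstance(pickle.loads(data)["key"], unicode))  # True on 2 and 3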


@ -40,164 +40,126 @@ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
from __future__ import print_function
import operator
import os
import io
import pickle
import struct
import sys
import types
from functools import partial
import itertools
from copy_reg import _extension_registry, _inverted_registry, _extension_cache
import new
import dis
import traceback
import platform
PyImp = platform.python_implementation()
import logging
cloudLog = logging.getLogger("Cloud.Transport")
if sys.version < '3':
from pickle import Pickler
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO
PY3 = False
else:
types.ClassType = type
from pickle import _Pickler as Pickler
from io import BytesIO as StringIO
PY3 = True
#relevant opcodes
STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL'))
DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL'))
LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL'))
STORE_GLOBAL = dis.opname.index('STORE_GLOBAL')
DELETE_GLOBAL = dis.opname.index('DELETE_GLOBAL')
LOAD_GLOBAL = dis.opname.index('LOAD_GLOBAL')
GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL]
HAVE_ARGUMENT = dis.HAVE_ARGUMENT
EXTENDED_ARG = dis.EXTENDED_ARG
HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT)
EXTENDED_ARG = chr(dis.EXTENDED_ARG)
if PyImp == "PyPy":
# register builtin type in `new`
new.method = types.MethodType
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO
# These helper functions were copied from PiCloud's util module.
def islambda(func):
return getattr(func,'func_name') == '<lambda>'
def xrange_params(xrangeobj):
"""Returns a 3 element tuple describing the xrange start, step, and len
respectively
Note: Only guarentees that elements of xrange are the same. parameters may
be different.
e.g. xrange(1,1) is interpretted as xrange(0,0); both behave the same
though w/ iteration
"""
xrange_len = len(xrangeobj)
if not xrange_len: #empty
return (0,1,0)
start = xrangeobj[0]
if xrange_len == 1: #one element
return start, 1, 1
return (start, xrangeobj[1] - xrangeobj[0], xrange_len)
#debug variables intended for developer use:
printSerialization = False
printMemoization = False
useForcedImports = True #Should I use forced imports for tracking?
return getattr(func,'__name__') == '<lambda>'
_BUILTIN_TYPE_NAMES = {}
for k, v in types.__dict__.items():
if type(v) is type:
_BUILTIN_TYPE_NAMES[v] = k
class CloudPickler(pickle.Pickler):
dispatch = pickle.Pickler.dispatch.copy()
savedForceImports = False
savedDjangoEnv = False #hack tro transport django environment
def _builtin_type(name):
return getattr(types, name)
def __init__(self, file, protocol=None, min_size_to_save= 0):
pickle.Pickler.__init__(self,file,protocol)
self.modules = set() #set of modules needed to depickle
self.globals_ref = {} # map ids to dictionary. used to ensure that functions can share global env
class CloudPickler(Pickler):
dispatch = Pickler.dispatch.copy()
def __init__(self, file, protocol=None):
Pickler.__init__(self, file, protocol)
# set of modules to unpickle
self.modules = set()
# map ids to dictionary. used to ensure that functions can share global env
self.globals_ref = {}
def dump(self, obj):
# note: not thread safe
# minimal side-effects, so not fixing
recurse_limit = 3000
base_recurse = sys.getrecursionlimit()
if base_recurse < recurse_limit:
sys.setrecursionlimit(recurse_limit)
self.inject_addons()
try:
return pickle.Pickler.dump(self, obj)
except RuntimeError, e:
return Pickler.dump(self, obj)
except RuntimeError as e:
if 'recursion' in e.args[0]:
msg = """Could not pickle object as excessively deep recursion required.
Try _fast_serialization=2 or contact PiCloud support"""
msg = """Could not pickle object as excessively deep recursion required."""
raise pickle.PicklingError(msg)
finally:
new_recurse = sys.getrecursionlimit()
if new_recurse == recurse_limit:
sys.setrecursionlimit(base_recurse)
def save_memoryview(self, obj):
"""Fallback to save_string"""
Pickler.save_string(self, str(obj))
def save_buffer(self, obj):
"""Fallback to save_string"""
pickle.Pickler.save_string(self,str(obj))
dispatch[buffer] = save_buffer
Pickler.save_string(self,str(obj))
if PY3:
dispatch[memoryview] = save_memoryview
else:
dispatch[buffer] = save_buffer
#block broken objects
def save_unsupported(self, obj, pack=None):
def save_unsupported(self, obj):
raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj))
dispatch[types.GeneratorType] = save_unsupported
#python2.6+ supports slice pickling. some py2.5 extensions might as well. We just test it
try:
slice(0,1).__reduce__()
except TypeError: #can't pickle -
dispatch[slice] = save_unsupported
#itertools objects do not pickle!
# itertools objects do not pickle!
for v in itertools.__dict__.values():
if type(v) is type:
dispatch[v] = save_unsupported
def save_dict(self, obj):
"""hack fix
If the dict is a global, deal with it in a special way
"""
#print 'saving', obj
if obj is __builtins__:
self.save_reduce(_get_module_builtins, (), obj=obj)
else:
pickle.Pickler.save_dict(self, obj)
dispatch[pickle.DictionaryType] = save_dict
def save_module(self, obj, pack=struct.pack):
def save_module(self, obj):
"""
Save a module as an import
"""
#print 'try save import', obj.__name__
self.modules.add(obj)
self.save_reduce(subimport,(obj.__name__,), obj=obj)
dispatch[types.ModuleType] = save_module #new type
self.save_reduce(subimport, (obj.__name__,), obj=obj)
dispatch[types.ModuleType] = save_module
def save_codeobject(self, obj, pack=struct.pack):
def save_codeobject(self, obj):
"""
Save a code object
"""
#print 'try to save codeobj: ', obj
args = (
obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code,
obj.co_consts, obj.co_names, obj.co_varnames, obj.co_filename, obj.co_name,
obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars
)
if PY3:
args = (
obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize,
obj.co_flags, obj.co_code, obj.co_consts, obj.co_names, obj.co_varnames,
obj.co_filename, obj.co_name, obj.co_firstlineno, obj.co_lnotab, obj.co_freevars,
obj.co_cellvars
)
else:
args = (
obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code,
obj.co_consts, obj.co_names, obj.co_varnames, obj.co_filename, obj.co_name,
obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars
)
self.save_reduce(types.CodeType, args, obj=obj)
dispatch[types.CodeType] = save_codeobject #new type
dispatch[types.CodeType] = save_codeobject
def save_function(self, obj, name=None, pack=struct.pack):
def save_function(self, obj, name=None):
""" Registered with the dispatch to handle all function types.
Determines what kind of function obj is (e.g. lambda, defined at
@ -205,12 +167,14 @@ class CloudPickler(pickle.Pickler):
"""
write = self.write
name = obj.__name__
if name is None:
name = obj.__name__
modname = pickle.whichmodule(obj, name)
#print 'which gives %s %s %s' % (modname, obj, name)
# print('which gives %s %s %s' % (modname, obj, name))
try:
themodule = sys.modules[modname]
except KeyError: # eval'd items such as namedtuple give invalid items for their function __module__
except KeyError:
# eval'd items such as namedtuple give invalid items for their function __module__
modname = '__main__'
if modname == '__main__':
@ -221,37 +185,18 @@ class CloudPickler(pickle.Pickler):
if getattr(themodule, name, None) is obj:
return self.save_global(obj, name)
if not self.savedDjangoEnv:
#hack for django - if we detect the settings module, we transport it
django_settings = os.environ.get('DJANGO_SETTINGS_MODULE', '')
if django_settings:
django_mod = sys.modules.get(django_settings)
if django_mod:
cloudLog.debug('Transporting django settings %s during save of %s', django_mod, name)
self.savedDjangoEnv = True
self.modules.add(django_mod)
write(pickle.MARK)
self.save_reduce(django_settings_load, (django_mod.__name__,), obj=django_mod)
write(pickle.POP_MARK)
# if func is lambda, def'ed at prompt, is in main, or is nested, then
# we'll pickle the actual function object rather than simply saving a
# reference (as is done in default pickler), via save_function_tuple.
if islambda(obj) or obj.func_code.co_filename == '<stdin>' or themodule is None:
#Force server to import modules that have been imported in main
modList = None
if themodule is None and not self.savedForceImports:
mainmod = sys.modules['__main__']
if useForcedImports and hasattr(mainmod,'___pyc_forcedImports__'):
modList = list(mainmod.___pyc_forcedImports__)
self.savedForceImports = True
self.save_function_tuple(obj, modList)
if islambda(obj) or obj.__code__.co_filename == '<stdin>' or themodule is None:
#print("save global", islambda(obj), obj.__code__.co_filename, modname, themodule)
self.save_function_tuple(obj)
return
else: # func is nested
else:
# func is nested
klass = getattr(themodule, name, None)
if klass is None or klass is not obj:
self.save_function_tuple(obj, [themodule])
self.save_function_tuple(obj)
return
if obj.__dict__:
@ -266,7 +211,7 @@ class CloudPickler(pickle.Pickler):
self.memoize(obj)
dispatch[types.FunctionType] = save_function
def save_function_tuple(self, func, forced_imports):
def save_function_tuple(self, func):
""" Pickles an actual func object.
A func comprises: code, globals, defaults, closure, and dict. We
@ -281,19 +226,6 @@ class CloudPickler(pickle.Pickler):
save = self.save
write = self.write
# save the modules (if any)
if forced_imports:
write(pickle.MARK)
save(_modules_to_main)
#print 'forced imports are', forced_imports
forced_names = map(lambda m: m.__name__, forced_imports)
save((forced_names,))
#save((forced_imports,))
write(pickle.REDUCE)
write(pickle.POP_MARK)
code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func)
save(_fill_function) # skeleton function updater
@ -318,6 +250,8 @@ class CloudPickler(pickle.Pickler):
Find all globals names read or written to by codeblock co
"""
code = co.co_code
if not PY3:
code = [ord(c) for c in code]
names = co.co_names
out_names = set()
@ -327,18 +261,18 @@ class CloudPickler(pickle.Pickler):
while i < n:
op = code[i]
i = i+1
i += 1
if op >= HAVE_ARGUMENT:
oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg
oparg = code[i] + code[i+1] * 256 + extended_arg
extended_arg = 0
i = i+2
i += 2
if op == EXTENDED_ARG:
extended_arg = oparg*65536L
extended_arg = oparg*65536
if op in GLOBAL_OPS:
out_names.add(names[oparg])
#print 'extracted', out_names, ' from ', names
if co.co_consts: # see if nested function have any global refs
# see if nested function have any global refs
if co.co_consts:
for const in co.co_consts:
if type(const) is types.CodeType:
out_names |= CloudPickler.extract_code_globals(const)
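The `[ord(c) for c in code]` shim above exists because co_code is a bytes object on Python 3 (indexing yields ints) but a str on Python 2 (indexing yields one-character strings). A small illustration of the difference, outside the PR itself; the LOAD_GLOBAL result is what CPython 2.7/3.4 emit for this function:

    import dis
    import sys

    def f():
        return some_global  # only referenced, never called, so it can stay undefined

    code = f.__code__.co_code
    if sys.version_info[0] < 3:
        code = [ord(c) for c in code]   # normalize to ints, as CloudPickler does
    print(dis.opname[code[0]])          # LOAD_GLOBAL on CPython 2.7/3.4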
@ -350,46 +284,28 @@ class CloudPickler(pickle.Pickler):
Turn the function into a tuple of data necessary to recreate it:
code, globals, defaults, closure, dict
"""
code = func.func_code
code = func.__code__
# extract all global ref's
func_global_refs = CloudPickler.extract_code_globals(code)
func_global_refs = self.extract_code_globals(code)
# process all variables referenced by global environment
f_globals = {}
for var in func_global_refs:
#Some names, such as class functions are not global - we don't need them
if func.func_globals.has_key(var):
f_globals[var] = func.func_globals[var]
if var in func.__globals__:
f_globals[var] = func.__globals__[var]
# defaults requires no processing
defaults = func.func_defaults
def get_contents(cell):
try:
return cell.cell_contents
except ValueError, e: #cell is empty error on not yet assigned
raise pickle.PicklingError('Function to be pickled has free variables that are referenced before assignment in enclosing scope')
defaults = func.__defaults__
# process closure
if func.func_closure:
closure = map(get_contents, func.func_closure)
else:
closure = []
closure = [c.cell_contents for c in func.__closure__] if func.__closure__ else []
# save the dict
dct = func.func_dict
dct = func.__dict__
if printSerialization:
outvars = ['code: ' + str(code) ]
outvars.append('globals: ' + str(f_globals))
outvars.append('defaults: ' + str(defaults))
outvars.append('closure: ' + str(closure))
print 'function ', func, 'is extracted to: ', ', '.join(outvars)
base_globals = self.globals_ref.get(id(func.func_globals), {})
self.globals_ref[id(func.func_globals)] = base_globals
base_globals = self.globals_ref.get(id(func.__globals__), {})
self.globals_ref[id(func.__globals__)] = base_globals
return (code, f_globals, defaults, closure, dct, base_globals)
@ -400,8 +316,9 @@ class CloudPickler(pickle.Pickler):
dispatch[types.BuiltinFunctionType] = save_builtin_function
def save_global(self, obj, name=None, pack=struct.pack):
write = self.write
memo = self.memo
if obj.__module__ == "__builtin__" or obj.__module__ == "builtins":
if obj in _BUILTIN_TYPE_NAMES:
return self.save_reduce(_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
if name is None:
name = obj.__name__
@ -410,98 +327,57 @@ class CloudPickler(pickle.Pickler):
if modname is None:
modname = pickle.whichmodule(obj, name)
try:
__import__(modname)
themodule = sys.modules[modname]
except (ImportError, KeyError, AttributeError): #should never occur
raise pickle.PicklingError(
"Can't pickle %r: Module %s cannot be found" %
(obj, modname))
if modname == '__main__':
themodule = None
if themodule:
else:
__import__(modname)
themodule = sys.modules[modname]
self.modules.add(themodule)
sendRef = True
typ = type(obj)
#print 'saving', obj, typ
try:
try: #Deal with case when getattribute fails with exceptions
klass = getattr(themodule, name)
except (AttributeError):
if modname == '__builtin__': #new.* are misrepeported
modname = 'new'
__import__(modname)
themodule = sys.modules[modname]
try:
klass = getattr(themodule, name)
except AttributeError, a:
# print themodule, name, obj, type(obj)
raise pickle.PicklingError("Can't pickle builtin %s" % obj)
else:
raise
if hasattr(themodule, name) and getattr(themodule, name) is obj:
return Pickler.save_global(self, obj, name)
except (ImportError, KeyError, AttributeError):
if typ == types.TypeType or typ == types.ClassType:
sendRef = False
else: #we can't deal with this
raise
else:
if klass is not obj and (typ == types.TypeType or typ == types.ClassType):
sendRef = False
if not sendRef:
#note: Third party types might crash this - add better checks!
d = dict(obj.__dict__) #copy dict proxy to a dict
if not isinstance(d.get('__dict__', None), property): # don't extract dict that are properties
d.pop('__dict__',None)
d.pop('__weakref__',None)
typ = type(obj)
if typ is not obj and isinstance(obj, (type, types.ClassType)):
d = dict(obj.__dict__) # copy dict proxy to a dict
if not isinstance(d.get('__dict__', None), property):
# don't extract dict that are properties
d.pop('__dict__', None)
d.pop('__weakref__', None)
# hack as __new__ is stored differently in the __dict__
new_override = d.get('__new__', None)
if new_override:
d['__new__'] = obj.__new__
self.save_reduce(type(obj),(obj.__name__,obj.__bases__,
d),obj=obj)
#print 'internal reduce dask %s %s' % (obj, d)
return
self.save_reduce(typ, (obj.__name__, obj.__bases__, d), obj=obj)
else:
raise pickle.PicklingError("Can't pickle %r" % obj)
if self.proto >= 2:
code = _extension_registry.get((modname, name))
if code:
assert code > 0
if code <= 0xff:
write(pickle.EXT1 + chr(code))
elif code <= 0xffff:
write("%c%c%c" % (pickle.EXT2, code&0xff, code>>8))
else:
write(pickle.EXT4 + pack("<i", code))
return
write(pickle.GLOBAL + modname + '\n' + name + '\n')
self.memoize(obj)
dispatch[type] = save_global
dispatch[types.ClassType] = save_global
dispatch[types.TypeType] = save_global
def save_instancemethod(self, obj):
#Memoization rarely is ever useful due to python bounding
self.save_reduce(types.MethodType, (obj.im_func, obj.im_self,obj.im_class), obj=obj)
# Memoization rarely is ever useful due to python bounding
if PY3:
self.save_reduce(types.MethodType, (obj.__func__, obj.__self__), obj=obj)
else:
self.save_reduce(types.MethodType, (obj.__func__, obj.__self__, obj.__self__.__class__),
obj=obj)
dispatch[types.MethodType] = save_instancemethod
def save_inst_logic(self, obj):
def save_inst(self, obj):
"""Inner logic to save instance. Based off pickle.save_inst
Supports __transient__"""
cls = obj.__class__
memo = self.memo
memo = self.memo
write = self.write
save = self.save
save = self.save
if hasattr(obj, '__getinitargs__'):
args = obj.__getinitargs__()
len(args) # XXX Assert it's a sequence
len(args) # XXX Assert it's a sequence
pickle._keep_alive(args, memo)
else:
args = ()
@ -537,15 +413,8 @@ class CloudPickler(pickle.Pickler):
save(stuff)
write(pickle.BUILD)
def save_inst(self, obj):
# Hack to detect PIL Image instances without importing Imaging
# PIL can be loaded with multiple names, so we don't check sys.modules for it
if hasattr(obj,'im') and hasattr(obj,'palette') and 'Image' in obj.__module__:
self.save_image(obj)
else:
self.save_inst_logic(obj)
dispatch[types.InstanceType] = save_inst
if not PY3:
dispatch[types.InstanceType] = save_inst
def save_property(self, obj):
# properties not correctly saved in python
@ -592,7 +461,7 @@ class CloudPickler(pickle.Pickler):
"""Modified to support __transient__ on new objects
Change only affects protocol level 2 (which is always used by PiCloud"""
# Assert that args is a tuple or None
if not isinstance(args, types.TupleType):
if not isinstance(args, tuple):
raise pickle.PicklingError("args from reduce() should be a tuple")
# Assert that func is callable
@ -646,35 +515,23 @@ class CloudPickler(pickle.Pickler):
self._batch_setitems(dictitems)
if state is not None:
#print 'obj %s has state %s' % (obj, state)
save(state)
write(pickle.BUILD)
def save_xrange(self, obj):
"""Save an xrange object in python 2.5
Python 2.6 supports this natively
"""
range_params = xrange_params(obj)
self.save_reduce(_build_xrange,range_params)
#python2.6+ supports xrange pickling. some py2.5 extensions might as well. We just test it
try:
xrange(0).__reduce__()
except TypeError: #can't pickle -- use PiCloud pickler
dispatch[xrange] = save_xrange
def save_partial(self, obj):
"""Partial objects do not serialize correctly in python2.x -- this fixes the bugs"""
self.save_reduce(_genpartial, (obj.func, obj.args, obj.keywords))
if sys.version_info < (2,7): #2.7 supports partial pickling
if sys.version_info < (2,7): # 2.7 supports partial pickling
dispatch[partial] = save_partial
def save_file(self, obj):
"""Save a file"""
import StringIO as pystringIO #we can't use cStringIO as it lacks the name attribute
try:
import StringIO as pystringIO #we can't use cStringIO as it lacks the name attribute
except ImportError:
import io as pystringIO
if not hasattr(obj, 'name') or not hasattr(obj, 'mode'):
raise pickle.PicklingError("Cannot pickle files that do not map to an actual file")
@ -720,10 +577,14 @@ class CloudPickler(pickle.Pickler):
retval.seek(curloc)
retval.name = name
self.save(retval) #save stringIO
self.save(retval)
self.memoize(obj)
dispatch[file] = save_file
if PY3:
dispatch[io.TextIOWrapper] = save_file
else:
dispatch[file] = save_file
"""Special functions for Add-on libraries"""
def inject_numpy(self):
@ -732,76 +593,20 @@ class CloudPickler(pickle.Pickler):
return
self.dispatch[numpy.ufunc] = self.__class__.save_ufunc
numpy_tst_mods = ['numpy', 'scipy.special']
def save_ufunc(self, obj):
"""Hack function for saving numpy ufunc objects"""
name = obj.__name__
for tst_mod_name in self.numpy_tst_mods:
numpy_tst_mods = ['numpy', 'scipy.special']
for tst_mod_name in numpy_tst_mods:
tst_mod = sys.modules.get(tst_mod_name, None)
if tst_mod:
if name in tst_mod.__dict__:
self.save_reduce(_getobject, (tst_mod_name, name))
return
raise pickle.PicklingError('cannot save %s. Cannot resolve what module it is defined in' % str(obj))
def inject_timeseries(self):
"""Handle bugs with pickling scikits timeseries"""
tseries = sys.modules.get('scikits.timeseries.tseries')
if not tseries or not hasattr(tseries, 'Timeseries'):
return
self.dispatch[tseries.Timeseries] = self.__class__.save_timeseries
def save_timeseries(self, obj):
import scikits.timeseries.tseries as ts
func, reduce_args, state = obj.__reduce__()
if func != ts._tsreconstruct:
raise pickle.PicklingError('timeseries using unexpected reconstruction function %s' % str(func))
state = (1,
obj.shape,
obj.dtype,
obj.flags.fnc,
obj._data.tostring(),
ts.getmaskarray(obj).tostring(),
obj._fill_value,
obj._dates.shape,
obj._dates.__array__().tostring(),
obj._dates.dtype, #added -- preserve type
obj.freq,
obj._optinfo,
)
return self.save_reduce(_genTimeSeries, (reduce_args, state))
def inject_email(self):
"""Block email LazyImporters from being saved"""
email = sys.modules.get('email')
if not email:
return
self.dispatch[email.LazyImporter] = self.__class__.save_unsupported
if tst_mod and name in tst_mod.__dict__:
return self.save_reduce(_getobject, (tst_mod_name, name))
raise pickle.PicklingError('cannot save %s. Cannot resolve what module it is defined in'
% str(obj))
def inject_addons(self):
"""Plug in system. Register additional pickling functions if modules already loaded"""
self.inject_numpy()
self.inject_timeseries()
self.inject_email()
"""Python Imaging Library"""
def save_image(self, obj):
if not obj.im and obj.fp and 'r' in obj.fp.mode and obj.fp.name \
and not obj.fp.closed and (not hasattr(obj, 'isatty') or not obj.isatty()):
#if image not loaded yet -- lazy load
self.save_reduce(_lazyloadImage,(obj.fp,), obj=obj)
else:
#image is loaded - just transmit it over
self.save_reduce(_generateImage, (obj.size, obj.mode, obj.tostring()), obj=obj)
"""
def memoize(self, obj):
pickle.Pickler.memoize(self, obj)
if printMemoization:
print 'memoizing ' + str(obj)
"""
# Shorthands for legacy support
@ -809,14 +614,13 @@ class CloudPickler(pickle.Pickler):
def dump(obj, file, protocol=2):
CloudPickler(file, protocol).dump(obj)
def dumps(obj, protocol=2):
file = StringIO()
cp = CloudPickler(file,protocol)
cp.dump(obj)
#print 'cloud dumped', str(obj), str(cp.modules)
return file.getvalue()
@ -825,25 +629,6 @@ def subimport(name):
__import__(name)
return sys.modules[name]
#hack to load django settings:
def django_settings_load(name):
modified_env = False
if 'DJANGO_SETTINGS_MODULE' not in os.environ:
os.environ['DJANGO_SETTINGS_MODULE'] = name # must set name first due to circular deps
modified_env = True
try:
module = subimport(name)
except Exception, i:
print >> sys.stderr, 'Cloud not import django settings %s:' % (name)
print_exec(sys.stderr)
if modified_env:
del os.environ['DJANGO_SETTINGS_MODULE']
else:
#add project directory to sys,path:
if hasattr(module,'__file__'):
dirname = os.path.split(module.__file__)[0] + '/'
sys.path.append(dirname)
# restores function attributes
def _restore_attr(obj, attr):
@ -851,13 +636,16 @@ def _restore_attr(obj, attr):
setattr(obj, key, val)
return obj
def _get_module_builtins():
return pickle.__builtins__
def print_exec(stream):
ei = sys.exc_info()
traceback.print_exception(ei[0], ei[1], ei[2], None, stream)
def _modules_to_main(modList):
"""Force every module in modList to be placed into main"""
if not modList:
@ -868,22 +656,16 @@ def _modules_to_main(modList):
if type(modname) is str:
try:
mod = __import__(modname)
except Exception, i: #catch all...
sys.stderr.write('warning: could not import %s\n. Your function may unexpectedly error due to this import failing; \
A version mismatch is likely. Specific error was:\n' % modname)
except Exception as e:
sys.stderr.write('warning: could not import %s\n. '
'Your function may unexpectedly error due to this import failing;'
'A version mismatch is likely. Specific error was:\n' % modname)
print_exec(sys.stderr)
else:
setattr(main,mod.__name__, mod)
else:
#REVERSE COMPATIBILITY FOR CLOUD CLIENT 1.5 (WITH EPD)
#In old version actual module was sent
setattr(main,modname.__name__, modname)
setattr(main, mod.__name__, mod)
#object generators:
def _build_xrange(start, step, len):
"""Built xrange explicitly"""
return xrange(start, start + step*len, step)
def _genpartial(func, args, kwds):
if not args:
args = ()
@ -891,22 +673,26 @@ def _genpartial(func, args, kwds):
kwds = {}
return partial(func, *args, **kwds)
def _fill_function(func, globals, defaults, dict):
""" Fills in the rest of function data into the skeleton function object
that were created via _make_skel_func().
"""
func.func_globals.update(globals)
func.func_defaults = defaults
func.func_dict = dict
func.__globals__.update(globals)
func.__defaults__ = defaults
func.__dict__ = dict
return func
def _make_cell(value):
return (lambda: value).func_closure[0]
return (lambda: value).__closure__[0]
def _reconstruct_closure(values):
return tuple([_make_cell(v) for v in values])
def _make_skel_func(code, closures, base_globals = None):
""" Creates a skeleton function object that contains just the provided
code and the correct number of cells in func_closure. All other
@ -928,40 +714,3 @@ Note: These can never be renamed due to client compatibility issues"""
def _getobject(modname, attribute):
mod = __import__(modname, fromlist=[attribute])
return mod.__dict__[attribute]
def _generateImage(size, mode, str_rep):
"""Generate image from string representation"""
import Image
i = Image.new(mode, size)
i.fromstring(str_rep)
return i
def _lazyloadImage(fp):
import Image
fp.seek(0) #works in almost any case
return Image.open(fp)
"""Timeseries"""
def _genTimeSeries(reduce_args, state):
import scikits.timeseries.tseries as ts
from numpy import ndarray
from numpy.ma import MaskedArray
time_series = ts._tsreconstruct(*reduce_args)
#from setstate modified
(ver, shp, typ, isf, raw, msk, flv, dsh, dtm, dtyp, frq, infodict) = state
#print 'regenerating %s' % dtyp
MaskedArray.__setstate__(time_series, (ver, shp, typ, isf, raw, msk, flv))
_dates = time_series._dates
#_dates.__setstate__((ver, dsh, typ, isf, dtm, frq)) #use remote typ
ndarray.__setstate__(_dates,(dsh,dtyp, isf, dtm))
_dates.freq = frq
_dates._cachedinfo.update(dict(full=None, hasdups=None, steps=None,
toobj=None, toord=None, tostr=None))
# Update the _optinfo dictionary
time_series._optinfo.update(infodict)
return time_series


@ -44,7 +44,7 @@ u'/path'
<pyspark.conf.SparkConf object at ...>
>>> conf.get("spark.executorEnv.VAR1")
u'value1'
>>> print conf.toDebugString()
>>> print(conf.toDebugString())
spark.executorEnv.VAR1=value1
spark.executorEnv.VAR3=value3
spark.executorEnv.VAR4=value4
@ -56,6 +56,13 @@ spark.home=/path
__all__ = ['SparkConf']
import sys
import re
if sys.version > '3':
unicode = str
__doc__ = re.sub(r"(\W|^)[uU](['])", r'\1\2', __doc__)
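The re.sub over __doc__ rewrites the doctest expectations so that u'value1' reads as 'value1' when the suite runs under Python 3, where string reprs no longer carry the u prefix. Roughly how the substitution behaves (a standalone illustration, not code from the PR):

    import re

    doc = ">>> conf.get(\"spark.executorEnv.VAR1\")\nu'value1'"
    print(re.sub(r"(\W|^)[uU](['])", r'\1\2', doc))
    # >>> conf.get("spark.executorEnv.VAR1")
    # 'value1'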
class SparkConf(object):


@ -15,6 +15,8 @@
# limitations under the License.
#
from __future__ import print_function
import os
import shutil
import sys
@ -32,11 +34,14 @@ from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
PairDeserializer, AutoBatchedSerializer, NoOpSerializer
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD, _load_from_socket
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
from pyspark.traceback_utils import CallSite, first_spark_call
from pyspark.status import StatusTracker
from pyspark.profiler import ProfilerCollector, BasicProfiler
if sys.version > '3':
xrange = range
__all__ = ['SparkContext']
@ -133,7 +138,7 @@ class SparkContext(object):
if sparkHome:
self._conf.setSparkHome(sparkHome)
if environment:
for key, value in environment.iteritems():
for key, value in environment.items():
self._conf.setExecutorEnv(key, value)
for key, value in DEFAULT_CONFIGS.items():
self._conf.setIfMissing(key, value)
@ -153,6 +158,10 @@ class SparkContext(object):
if k.startswith("spark.executorEnv."):
varName = k[len("spark.executorEnv."):]
self.environment[varName] = v
if sys.version >= '3.3' and 'PYTHONHASHSEED' not in os.environ:
# disable randomness of hash of string in worker, if this is not
# launched by spark-submit
self.environment["PYTHONHASHSEED"] = "0"
# Create the Java SparkContext through Py4J
self._jsc = jsc or self._initialize_context(self._conf._jconf)
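Pinning PYTHONHASHSEED matters because Python 3.3+ randomizes the hash of str per interpreter process, so two workers could send identical keys to different partitions unless every process shares a seed. A quick way to observe the effect (illustrative only; each subprocess gets its own seed):

    import os
    import subprocess
    import sys

    cmd = [sys.executable, "-c", "print(hash('spark'))"]
    print(subprocess.check_output(cmd))              # differs on every run under 3.3+
    env = dict(os.environ, PYTHONHASHSEED="0")
    print(subprocess.check_output(cmd, env=env))     # stable across processes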
@ -323,7 +332,7 @@ class SparkContext(object):
start0 = c[0]
def getStart(split):
return start0 + (split * size / numSlices) * step
return start0 + int((split * size / numSlices)) * step
def f(split, iterator):
return xrange(getStart(split), getStart(split + 1), step)
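The int() added in getStart is needed because `/` is true division on Python 3, and range()/xrange() reject float arguments. A tiny sketch of the slicing arithmetic (the sizes here are invented):

    size, numSlices, step, start0 = 10, 4, 1, 0

    def getStart(split):
        # without int(), 1 * 10 / 4 == 2.5 on Python 3 and range() raises TypeError
        return start0 + int(split * size / numSlices) * step

    print(list(range(getStart(1), getStart(2), step)))  # [2, 3, 4]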
@ -357,6 +366,7 @@ class SparkContext(object):
minPartitions = minPartitions or self.defaultMinPartitions
return RDD(self._jsc.objectFile(name, minPartitions), self)
@ignore_unicode_prefix
def textFile(self, name, minPartitions=None, use_unicode=True):
"""
Read a text file from HDFS, a local file system (available on all
@ -369,7 +379,7 @@ class SparkContext(object):
>>> path = os.path.join(tempdir, "sample-text.txt")
>>> with open(path, "w") as testFile:
... testFile.write("Hello world!")
... _ = testFile.write("Hello world!")
>>> textFile = sc.textFile(path)
>>> textFile.collect()
[u'Hello world!']
@ -378,6 +388,7 @@ class SparkContext(object):
return RDD(self._jsc.textFile(name, minPartitions), self,
UTF8Deserializer(use_unicode))
@ignore_unicode_prefix
def wholeTextFiles(self, path, minPartitions=None, use_unicode=True):
"""
Read a directory of text files from HDFS, a local file system
@ -411,9 +422,9 @@ class SparkContext(object):
>>> dirPath = os.path.join(tempdir, "files")
>>> os.mkdir(dirPath)
>>> with open(os.path.join(dirPath, "1.txt"), "w") as file1:
... file1.write("1")
... _ = file1.write("1")
>>> with open(os.path.join(dirPath, "2.txt"), "w") as file2:
... file2.write("2")
... _ = file2.write("2")
>>> textFiles = sc.wholeTextFiles(dirPath)
>>> sorted(textFiles.collect())
[(u'.../1.txt', u'1'), (u'.../2.txt', u'2')]
@ -456,7 +467,7 @@ class SparkContext(object):
jm = self._jvm.java.util.HashMap()
if not d:
d = {}
for k, v in d.iteritems():
for k, v in d.items():
jm[k] = v
return jm
@ -608,6 +619,7 @@ class SparkContext(object):
jrdd = self._jsc.checkpointFile(name)
return RDD(jrdd, self, input_deserializer)
@ignore_unicode_prefix
def union(self, rdds):
"""
Build the union of a list of RDDs.
@ -618,7 +630,7 @@ class SparkContext(object):
>>> path = os.path.join(tempdir, "union-text.txt")
>>> with open(path, "w") as testFile:
... testFile.write("Hello")
... _ = testFile.write("Hello")
>>> textFile = sc.textFile(path)
>>> textFile.collect()
[u'Hello']
@ -677,7 +689,7 @@ class SparkContext(object):
>>> from pyspark import SparkFiles
>>> path = os.path.join(tempdir, "test.txt")
>>> with open(path, "w") as testFile:
... testFile.write("100")
... _ = testFile.write("100")
>>> sc.addFile(path)
>>> def func(iterator):
... with open(SparkFiles.get("test.txt")) as testFile:
@ -705,11 +717,13 @@ class SparkContext(object):
"""
self.addFile(path)
(dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix
if filename[-4:].lower() in self.PACKAGE_EXTENSIONS:
self._python_includes.append(filename)
# for tests in local mode
sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename))
if sys.version > '3':
import importlib
importlib.invalidate_caches()
def setCheckpointDir(self, dirName):
"""
@ -744,7 +758,7 @@ class SparkContext(object):
The application can use L{SparkContext.cancelJobGroup} to cancel all
running jobs in this group.
>>> import thread, threading
>>> import threading
>>> from time import sleep
>>> result = "Not Set"
>>> lock = threading.Lock()
@ -763,10 +777,10 @@ class SparkContext(object):
... sleep(5)
... sc.cancelJobGroup("job_to_cancel")
>>> supress = lock.acquire()
>>> supress = thread.start_new_thread(start_job, (10,))
>>> supress = thread.start_new_thread(stop_job, tuple())
>>> supress = threading.Thread(target=start_job, args=(10,)).start()
>>> supress = threading.Thread(target=stop_job).start()
>>> supress = lock.acquire()
>>> print result
>>> print(result)
Cancelled
If interruptOnCancel is set to true for the job group, then job cancellation will result

View file

@ -24,9 +24,10 @@ import sys
import traceback
import time
import gc
from errno import EINTR, ECHILD, EAGAIN
from errno import EINTR, EAGAIN
from socket import AF_INET, SOCK_STREAM, SOMAXCONN
from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT
from pyspark.worker import main as worker_main
from pyspark.serializers import read_int, write_int
@ -53,8 +54,8 @@ def worker(sock):
# Read the socket using fdopen instead of socket.makefile() because the latter
# seems to be very slow; note that we need to dup() the file descriptor because
# otherwise writes also cause a seek that makes us miss data on the read side.
infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
infile = os.fdopen(os.dup(sock.fileno()), "rb", 65536)
outfile = os.fdopen(os.dup(sock.fileno()), "wb", 65536)
exit_code = 0
try:
worker_main(infile, outfile)
@ -68,17 +69,6 @@ def worker(sock):
return exit_code
# Cleanup zombie children
def cleanup_dead_children():
try:
while True:
pid, _ = os.waitpid(0, os.WNOHANG)
if not pid:
break
except:
pass
def manager():
# Create a new process group to corral our children
os.setpgid(0, 0)
@ -88,8 +78,12 @@ def manager():
listen_sock.bind(('127.0.0.1', 0))
listen_sock.listen(max(1024, SOMAXCONN))
listen_host, listen_port = listen_sock.getsockname()
write_int(listen_port, sys.stdout)
sys.stdout.flush()
# re-open stdin/stdout in 'wb' mode
stdin_bin = os.fdopen(sys.stdin.fileno(), 'rb', 4)
stdout_bin = os.fdopen(sys.stdout.fileno(), 'wb', 4)
write_int(listen_port, stdout_bin)
stdout_bin.flush()
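Re-opening stdin/stdout in binary mode is required because write_int/read_int exchange raw 4-byte integers, and on Python 3 sys.stdout.write accepts only str, not bytes. Roughly what that framing looks like (a sketch with the struct module, not the actual pyspark.serializers code):

    import io
    import struct

    buf = io.BytesIO()                           # stands in for the binary stdout
    buf.write(struct.pack("!i", 50123))          # write_int: 4-byte big-endian value
    buf.seek(0)
    print(struct.unpack("!i", buf.read(4))[0])   # read_int on the other end -> 50123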
def shutdown(code):
signal.signal(SIGTERM, SIG_DFL)
@ -101,6 +95,7 @@ def manager():
shutdown(1)
signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM
signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP
signal.signal(SIGCHLD, SIG_IGN)
reuse = os.environ.get("SPARK_REUSE_WORKER")
@ -115,12 +110,9 @@ def manager():
else:
raise
# cleanup in signal handler will cause deadlock
cleanup_dead_children()
if 0 in ready_fds:
try:
worker_pid = read_int(sys.stdin)
worker_pid = read_int(stdin_bin)
except EOFError:
# Spark told us to exit by closing stdin
shutdown(0)
@ -145,7 +137,7 @@ def manager():
time.sleep(1)
pid = os.fork() # error here will shutdown daemon
else:
outfile = sock.makefile('w')
outfile = sock.makefile(mode='wb')
write_int(e.errno, outfile) # Signal that the fork failed
outfile.flush()
outfile.close()
@ -157,7 +149,7 @@ def manager():
listen_sock.close()
try:
# Acknowledge that the fork was successful
outfile = sock.makefile("w")
outfile = sock.makefile(mode="wb")
write_int(os.getpid(), outfile)
outfile.flush()
outfile.close()


@ -627,51 +627,49 @@ def merge(iterables, key=None, reverse=False):
if key is None:
for order, it in enumerate(map(iter, iterables)):
try:
next = it.next
h_append([next(), order * direction, next])
h_append([next(it), order * direction, it])
except StopIteration:
pass
_heapify(h)
while len(h) > 1:
try:
while True:
value, order, next = s = h[0]
value, order, it = s = h[0]
yield value
s[0] = next() # raises StopIteration when exhausted
s[0] = next(it) # raises StopIteration when exhausted
_heapreplace(h, s) # restore heap condition
except StopIteration:
_heappop(h) # remove empty iterator
if h:
# fast case when only a single iterator remains
value, order, next = h[0]
value, order, it = h[0]
yield value
for value in next.__self__:
for value in it:
yield value
return
for order, it in enumerate(map(iter, iterables)):
try:
next = it.next
value = next()
h_append([key(value), order * direction, value, next])
value = next(it)
h_append([key(value), order * direction, value, it])
except StopIteration:
pass
_heapify(h)
while len(h) > 1:
try:
while True:
key_value, order, value, next = s = h[0]
key_value, order, value, it = s = h[0]
yield value
value = next()
value = next(it)
s[0] = key(value)
s[2] = value
_heapreplace(h, s)
except StopIteration:
_heappop(h)
if h:
key_value, order, value, next = h[0]
key_value, order, value, it = h[0]
yield value
for value in next.__self__:
for value in it:
yield value
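The heapq3 change swaps the Python-2-only bound method it.next for the builtin next(it), which works on both lines (next() has existed since 2.6 and drives __next__ on Python 3). For example (illustrative only):

    it = iter([1, 2, 3])
    print(next(it))   # 1 -- fine on Python 2.6+ and 3.x
    print(next(it))   # 2
    # it.next() would raise AttributeError on Python 3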


@ -69,7 +69,7 @@ def launch_gateway():
if callback_socket in readable:
gateway_connection = callback_socket.accept()[0]
# Determine which ephemeral port the server started on:
gateway_port = read_int(gateway_connection.makefile())
gateway_port = read_int(gateway_connection.makefile(mode="rb"))
gateway_connection.close()
callback_socket.close()
if gateway_port is None:


@ -32,6 +32,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
from pyspark.resultiterable import ResultIterable
from functools import reduce
def _do_python_join(rdd, other, numPartitions, dispatch):


@ -39,10 +39,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
>>> lr = LogisticRegression(maxIter=5, regParam=0.01)
>>> model = lr.fit(df)
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF()
>>> print model.transform(test0).head().prediction
>>> model.transform(test0).head().prediction
0.0
>>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF()
>>> print model.transform(test1).head().prediction
>>> model.transform(test1).head().prediction
1.0
>>> lr.setParams("vector")
Traceback (most recent call last):


@ -15,6 +15,7 @@
# limitations under the License.
#
from pyspark.rdd import ignore_unicode_prefix
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures
from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaTransformer
@ -24,6 +25,7 @@ __all__ = ['Tokenizer', 'HashingTF']
@inherit_doc
@ignore_unicode_prefix
class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
"""
A tokenizer that converts the input string to lowercase and then
@ -32,15 +34,15 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
>>> from pyspark.sql import Row
>>> df = sc.parallelize([Row(text="a b c")]).toDF()
>>> tokenizer = Tokenizer(inputCol="text", outputCol="words")
>>> print tokenizer.transform(df).head()
>>> tokenizer.transform(df).head()
Row(text=u'a b c', words=[u'a', u'b', u'c'])
>>> # Change a parameter.
>>> print tokenizer.setParams(outputCol="tokens").transform(df).head()
>>> tokenizer.setParams(outputCol="tokens").transform(df).head()
Row(text=u'a b c', tokens=[u'a', u'b', u'c'])
>>> # Temporarily modify a parameter.
>>> print tokenizer.transform(df, {tokenizer.outputCol: "words"}).head()
>>> tokenizer.transform(df, {tokenizer.outputCol: "words"}).head()
Row(text=u'a b c', words=[u'a', u'b', u'c'])
>>> print tokenizer.transform(df).head()
>>> tokenizer.transform(df).head()
Row(text=u'a b c', tokens=[u'a', u'b', u'c'])
>>> # Must use keyword arguments to specify params.
>>> tokenizer.setParams("text")
@ -79,13 +81,13 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
>>> from pyspark.sql import Row
>>> df = sc.parallelize([Row(words=["a", "b", "c"])]).toDF()
>>> hashingTF = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
>>> print hashingTF.transform(df).head().features
(10,[7,8,9],[1.0,1.0,1.0])
>>> print hashingTF.setParams(outputCol="freqs").transform(df).head().freqs
(10,[7,8,9],[1.0,1.0,1.0])
>>> hashingTF.transform(df).head().features
SparseVector(10, {7: 1.0, 8: 1.0, 9: 1.0})
>>> hashingTF.setParams(outputCol="freqs").transform(df).head().freqs
SparseVector(10, {7: 1.0, 8: 1.0, 9: 1.0})
>>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"}
>>> print hashingTF.transform(df, params).head().vector
(5,[2,3,4],[1.0,1.0,1.0])
>>> hashingTF.transform(df, params).head().vector
SparseVector(5, {2: 1.0, 3: 1.0, 4: 1.0})
"""
_java_class = "org.apache.spark.ml.feature.HashingTF"

View file

@ -63,8 +63,8 @@ class Params(Identifiable):
uses :py:func:`dir` to get all attributes of type
:py:class:`Param`.
"""
return filter(lambda attr: isinstance(attr, Param),
[getattr(self, x) for x in dir(self) if x != "params"])
return list(filter(lambda attr: isinstance(attr, Param),
[getattr(self, x) for x in dir(self) if x != "params"]))
def _explain(self, param):
"""
@ -185,7 +185,7 @@ class Params(Identifiable):
"""
Sets user-supplied params.
"""
for param, value in kwargs.iteritems():
for param, value in kwargs.items():
self.paramMap[getattr(self, param)] = value
return self
@ -193,6 +193,6 @@ class Params(Identifiable):
"""
Sets default params.
"""
for param, value in kwargs.iteritems():
for param, value in kwargs.items():
self.defaultParamMap[getattr(self, param)] = value
return self
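Both edits above follow from the same Python 3 behavior change: filter() and map() return lazy iterators instead of lists, and dict.iteritems() is gone. A small sketch (the names here are illustrative, not from params.py):

attrs = ["params", "maxIter", "regParam"]
# Wrapping in list() restores the Python 2 behavior of getting an indexable result.
params = list(filter(lambda a: a != "params", attrs))
print(params[0])                      # maxIter -- a bare filter object cannot be indexed

# .items() exists on both interpreters (a list on Python 2, a view on Python 3),
# so it replaces .iteritems() with no change in iteration behavior.
kwargs = {"maxIter": 5, "regParam": 0.01}
for param, value in sorted(kwargs.items()):
    print(param, value)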

View file

@ -15,6 +15,8 @@
# limitations under the License.
#
from __future__ import print_function
header = """#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
@ -82,9 +84,9 @@ def _gen_param_code(name, doc, defaultValueStr):
.replace("$defaultValueStr", str(defaultValueStr))
if __name__ == "__main__":
print header
print "\n# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.\n"
print "from pyspark.ml.param import Param, Params\n\n"
print(header)
print("\n# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.\n")
print("from pyspark.ml.param import Param, Params\n\n")
shared = [
("maxIter", "max number of iterations", None),
("regParam", "regularization constant", None),
@ -97,4 +99,4 @@ if __name__ == "__main__":
code = []
for name, doc, defaultValueStr in shared:
code.append(_gen_param_code(name, doc, defaultValueStr))
print "\n\n\n".join(code)
print("\n\n\n".join(code))

View file

@ -18,6 +18,7 @@
"""
Python bindings for MLlib.
"""
from __future__ import absolute_import
# MLlib currently needs NumPy 1.4+, so complain if lower
@ -29,7 +30,9 @@ __all__ = ['classification', 'clustering', 'feature', 'fpm', 'linalg', 'random',
'recommendation', 'regression', 'stat', 'tree', 'util']
import sys
import rand as random
random.__name__ = 'random'
random.RandomRDDs.__module__ = __name__ + '.random'
sys.modules[__name__ + '.random'] = random
from . import rand as random
modname = __name__ + '.random'
random.__name__ = modname
random.RandomRDDs.__module__ = modname
sys.modules[modname] = random
del modname, sys
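The block above publishes the submodule written as rand under the public name pyspark.mllib.random by registering it in sys.modules, while the absolute_import future keeps the stdlib random importable inside it. A toy, self-contained version of the same trick (mypkg and RandomRDDs here are hypothetical stand-ins):

import sys
import types

mypkg = types.ModuleType("mypkg")
rand = types.ModuleType("mypkg.rand")
rand.RandomRDDs = type("RandomRDDs", (), {})

modname = "mypkg.random"
rand.__name__ = modname                 # repr() and pickling report the public name
rand.RandomRDDs.__module__ = modname
sys.modules["mypkg"] = mypkg
sys.modules[modname] = rand

from mypkg.random import RandomRDDs     # resolved purely through sys.modules
print(RandomRDDs.__module__)            # mypkg.random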

View file

@ -510,9 +510,10 @@ class NaiveBayesModel(Saveable, Loader):
def load(cls, sc, path):
java_model = sc._jvm.org.apache.spark.mllib.classification.NaiveBayesModel.load(
sc._jsc.sc(), path)
py_labels = _java2py(sc, java_model.labels())
py_pi = _java2py(sc, java_model.pi())
py_theta = _java2py(sc, java_model.theta())
# Can not unpickle array.array from Pyrolite in Python3 with "bytes"
py_labels = _java2py(sc, java_model.labels(), "latin1")
py_pi = _java2py(sc, java_model.pi(), "latin1")
py_theta = _java2py(sc, java_model.theta(), "latin1")
return NaiveBayesModel(py_labels, py_pi, numpy.array(py_theta))

View file

@ -15,6 +15,12 @@
# limitations under the License.
#
import sys
import array as pyarray
if sys.version > '3':
xrange = range
from numpy import array
from pyspark import RDD
@ -55,8 +61,8 @@ class KMeansModel(Saveable, Loader):
True
>>> model.predict(sparse_data[2]) == model.predict(sparse_data[3])
True
>>> type(model.clusterCenters)
<type 'list'>
>>> isinstance(model.clusterCenters, list)
True
>>> import os, tempfile
>>> path = tempfile.mkdtemp()
>>> model.save(sc, path)
@ -90,7 +96,7 @@ class KMeansModel(Saveable, Loader):
return best
def save(self, sc, path):
java_centers = _py2java(sc, map(_convert_to_vector, self.centers))
java_centers = _py2java(sc, [_convert_to_vector(c) for c in self.centers])
java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel(java_centers)
java_model.save(sc._jsc.sc(), path)
@ -133,7 +139,7 @@ class GaussianMixtureModel(object):
... 5.7048, 4.6567, 5.5026,
... 4.5605, 5.2043, 6.2734]).reshape(5, 3))
>>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001,
... maxIterations=150, seed=10)
... maxIterations=150, seed=10)
>>> labels = model.predict(clusterdata_2).collect()
>>> labels[0]==labels[1]==labels[2]
True
@ -168,8 +174,8 @@ class GaussianMixtureModel(object):
if isinstance(x, RDD):
means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians])
membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector),
self.weights, means, sigmas)
return membership_matrix
_convert_to_vector(self.weights), means, sigmas)
return membership_matrix.map(lambda x: pyarray.array('d', x))
class GaussianMixture(object):

View file

@ -15,6 +15,11 @@
# limitations under the License.
#
import sys
if sys.version >= '3':
long = int
unicode = str
import py4j.protocol
from py4j.protocol import Py4JJavaError
from py4j.java_gateway import JavaObject
@ -36,7 +41,7 @@ _float_str_mapping = {
def _new_smart_decode(obj):
if isinstance(obj, float):
s = unicode(obj)
s = str(obj)
return _float_str_mapping.get(s, s)
return _old_smart_decode(obj)
@ -74,15 +79,15 @@ def _py2java(sc, obj):
obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client)
elif isinstance(obj, JavaObject):
pass
elif isinstance(obj, (int, long, float, bool, basestring)):
elif isinstance(obj, (int, long, float, bool, bytes, unicode)):
pass
else:
bytes = bytearray(PickleSerializer().dumps(obj))
obj = sc._jvm.SerDe.loads(bytes)
data = bytearray(PickleSerializer().dumps(obj))
obj = sc._jvm.SerDe.loads(data)
return obj
def _java2py(sc, r):
def _java2py(sc, r, encoding="bytes"):
if isinstance(r, JavaObject):
clsName = r.getClass().getSimpleName()
# convert RDD into JavaRDD
@ -102,8 +107,8 @@ def _java2py(sc, r):
except Py4JJavaError:
pass # not pickable
if isinstance(r, bytearray):
r = PickleSerializer().loads(str(r))
if isinstance(r, (bytearray, bytes)):
r = PickleSerializer().loads(bytes(r), encoding=encoding)
return r
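The new encoding parameter is threaded straight into pickle.loads, which on Python 3 decides how 8-bit str payloads written by Python 2 (or by Pyrolite) come back. A tiny demonstration with a hand-written protocol-0 pickle of the Python 2 str 'abc':

import pickle

py2_pickle = b"S'abc'\np0\n."                        # what Python 2's pickle.dumps('abc', 0) emits

print(pickle.loads(py2_pickle, encoding="bytes"))    # b'abc' -- the default used by _java2py
print(pickle.loads(py2_pickle, encoding="latin1"))   # abc    -- what the array.array payloads above need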

View file

@ -23,12 +23,17 @@ from __future__ import absolute_import
import sys
import warnings
import random
import binascii
if sys.version >= '3':
basestring = str
unicode = str
from py4j.protocol import Py4JJavaError
from pyspark import RDD, SparkContext
from pyspark import SparkContext
from pyspark.rdd import RDD, ignore_unicode_prefix
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
from pyspark.mllib.linalg import Vectors, Vector, _convert_to_vector
from pyspark.mllib.linalg import Vectors, _convert_to_vector
__all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel']
@ -206,7 +211,7 @@ class HashingTF(object):
>>> htf = HashingTF(100)
>>> doc = "a a b b c d".split(" ")
>>> htf.transform(doc)
SparseVector(100, {1: 1.0, 14: 1.0, 31: 2.0, 44: 2.0})
SparseVector(100, {...})
"""
def __init__(self, numFeatures=1 << 20):
"""
@ -360,6 +365,7 @@ class Word2VecModel(JavaVectorTransformer):
return self.call("getVectors")
@ignore_unicode_prefix
class Word2Vec(object):
"""
Word2Vec creates vector representation of words in a text corpus.
@ -382,7 +388,7 @@ class Word2Vec(object):
>>> sentence = "a b " * 100 + "a c " * 10
>>> localDoc = [sentence, sentence]
>>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" "))
>>> model = Word2Vec().setVectorSize(10).setSeed(42L).fit(doc)
>>> model = Word2Vec().setVectorSize(10).setSeed(42).fit(doc)
>>> syms = model.findSynonyms("a", 2)
>>> [s[0] for s in syms]
@ -400,7 +406,7 @@ class Word2Vec(object):
self.learningRate = 0.025
self.numPartitions = 1
self.numIterations = 1
self.seed = random.randint(0, sys.maxint)
self.seed = random.randint(0, sys.maxsize)
self.minCount = 5
def setVectorSize(self, vectorSize):
@ -459,7 +465,7 @@ class Word2Vec(object):
raise TypeError("data should be an RDD of list of string")
jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize),
float(self.learningRate), int(self.numPartitions),
int(self.numIterations), long(self.seed),
int(self.numIterations), int(self.seed),
int(self.minCount))
return Word2VecModel(jmodel)
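Several changes in this file (42L to 42, sys.maxint to sys.maxsize, long(self.seed) to int(self.seed)) come from Python 3 folding long into int and dropping the L literal; the long = int shim near the top of the file keeps a single spelling working on both interpreters. A minimal sketch:

import sys

if sys.version >= '3':
    long = int                    # same shim as above

print(long(42) == 42)             # True on both interpreters
print(sys.maxsize)                # exists on 2.x and 3.x, unlike sys.maxint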

View file

@ -16,12 +16,14 @@
#
from pyspark import SparkContext
from pyspark.rdd import ignore_unicode_prefix
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
__all__ = ['FPGrowth', 'FPGrowthModel']
@inherit_doc
@ignore_unicode_prefix
class FPGrowthModel(JavaModelWrapper):
"""

View file

@ -25,7 +25,13 @@ SciPy is available in their environment.
import sys
import array
import copy_reg
if sys.version >= '3':
basestring = str
xrange = range
import copyreg as copy_reg
else:
import copy_reg
import numpy as np
@ -57,7 +63,7 @@ except:
def _convert_to_vector(l):
if isinstance(l, Vector):
return l
elif type(l) in (array.array, np.array, np.ndarray, list, tuple):
elif type(l) in (array.array, np.array, np.ndarray, list, tuple, xrange):
return DenseVector(l)
elif _have_scipy and scipy.sparse.issparse(l):
assert l.shape[1] == 1, "Expected column vector"
@ -88,7 +94,7 @@ def _vector_size(v):
"""
if isinstance(v, Vector):
return len(v)
elif type(v) in (array.array, list, tuple):
elif type(v) in (array.array, list, tuple, xrange):
return len(v)
elif type(v) == np.ndarray:
if v.ndim == 1 or (v.ndim == 2 and v.shape[1] == 1):
@ -193,7 +199,7 @@ class DenseVector(Vector):
DenseVector([1.0, 0.0])
"""
def __init__(self, ar):
if isinstance(ar, basestring):
if isinstance(ar, bytes):
ar = np.frombuffer(ar, dtype=np.float64)
elif not isinstance(ar, np.ndarray):
ar = np.array(ar, dtype=np.float64)
@ -321,11 +327,13 @@ class DenseVector(Vector):
__sub__ = _delegate("__sub__")
__mul__ = _delegate("__mul__")
__div__ = _delegate("__div__")
__truediv__ = _delegate("__truediv__")
__mod__ = _delegate("__mod__")
__radd__ = _delegate("__radd__")
__rsub__ = _delegate("__rsub__")
__rmul__ = _delegate("__rmul__")
__rdiv__ = _delegate("__rdiv__")
__rtruediv__ = _delegate("__rtruediv__")
__rmod__ = _delegate("__rmod__")
@ -344,12 +352,12 @@ class SparseVector(Vector):
:param args: Non-zero entries, as a dictionary, list of tuples,
or two sorted lists containing indices and values.
>>> print SparseVector(4, {1: 1.0, 3: 5.5})
(4,[1,3],[1.0,5.5])
>>> print SparseVector(4, [(1, 1.0), (3, 5.5)])
(4,[1,3],[1.0,5.5])
>>> print SparseVector(4, [1, 3], [1.0, 5.5])
(4,[1,3],[1.0,5.5])
>>> SparseVector(4, {1: 1.0, 3: 5.5})
SparseVector(4, {1: 1.0, 3: 5.5})
>>> SparseVector(4, [(1, 1.0), (3, 5.5)])
SparseVector(4, {1: 1.0, 3: 5.5})
>>> SparseVector(4, [1, 3], [1.0, 5.5])
SparseVector(4, {1: 1.0, 3: 5.5})
"""
self.size = int(size)
assert 1 <= len(args) <= 2, "must pass either 2 or 3 arguments"
@ -361,8 +369,8 @@ class SparseVector(Vector):
self.indices = np.array([p[0] for p in pairs], dtype=np.int32)
self.values = np.array([p[1] for p in pairs], dtype=np.float64)
else:
if isinstance(args[0], basestring):
assert isinstance(args[1], str), "values should be string too"
if isinstance(args[0], bytes):
assert isinstance(args[1], bytes), "values should be string too"
if args[0]:
self.indices = np.frombuffer(args[0], np.int32)
self.values = np.frombuffer(args[1], np.float64)
@ -591,12 +599,12 @@ class Vectors(object):
:param args: Non-zero entries, as a dictionary, list of tuples,
or two sorted lists containing indices and values.
>>> print Vectors.sparse(4, {1: 1.0, 3: 5.5})
(4,[1,3],[1.0,5.5])
>>> print Vectors.sparse(4, [(1, 1.0), (3, 5.5)])
(4,[1,3],[1.0,5.5])
>>> print Vectors.sparse(4, [1, 3], [1.0, 5.5])
(4,[1,3],[1.0,5.5])
>>> Vectors.sparse(4, {1: 1.0, 3: 5.5})
SparseVector(4, {1: 1.0, 3: 5.5})
>>> Vectors.sparse(4, [(1, 1.0), (3, 5.5)])
SparseVector(4, {1: 1.0, 3: 5.5})
>>> Vectors.sparse(4, [1, 3], [1.0, 5.5])
SparseVector(4, {1: 1.0, 3: 5.5})
"""
return SparseVector(size, *args)
@ -645,7 +653,7 @@ class Matrix(object):
"""
Convert Matrix attributes which are array-like or buffer to array.
"""
if isinstance(array_like, basestring):
if isinstance(array_like, bytes):
return np.frombuffer(array_like, dtype=dtype)
return np.asarray(array_like, dtype=dtype)
@ -677,7 +685,7 @@ class DenseMatrix(Matrix):
def toSparse(self):
"""Convert to SparseMatrix"""
indices = np.nonzero(self.values)[0]
colCounts = np.bincount(indices / self.numRows)
colCounts = np.bincount(indices // self.numRows)
colPtrs = np.cumsum(np.hstack(
(0, colCounts, np.zeros(self.numCols - colCounts.size))))
values = self.values[indices]
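The // in toSparse is needed because / is true division on Python 3, and np.bincount rejects the float array that plain division would produce. A short check with a 3x3 column-major matrix:

import numpy as np

numRows = 3
values = np.array([1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0])   # column-major storage
indices = np.nonzero(values)[0]          # [0 2 4 7]

print(indices // numRows)                # [0 0 1 2] -- column of each nonzero entry
print(np.bincount(indices // numRows))   # [2 1 1]   -- nonzeros per column, as colCounts above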

View file

@ -88,10 +88,10 @@ class RandomRDDs(object):
:param seed: Random seed (default: a random long integer).
:return: RDD of float comprised of i.i.d. samples ~ N(0.0, 1.0).
>>> x = RandomRDDs.normalRDD(sc, 1000, seed=1L)
>>> x = RandomRDDs.normalRDD(sc, 1000, seed=1)
>>> stats = x.stats()
>>> stats.count()
1000L
1000
>>> abs(stats.mean() - 0.0) < 0.1
True
>>> abs(stats.stdev() - 1.0) < 0.1
@ -118,10 +118,10 @@ class RandomRDDs(object):
>>> std = 1.0
>>> expMean = exp(mean + 0.5 * std * std)
>>> expStd = sqrt((exp(std * std) - 1.0) * exp(2.0 * mean + std * std))
>>> x = RandomRDDs.logNormalRDD(sc, mean, std, 1000, seed=2L)
>>> x = RandomRDDs.logNormalRDD(sc, mean, std, 1000, seed=2)
>>> stats = x.stats()
>>> stats.count()
1000L
1000
>>> abs(stats.mean() - expMean) < 0.5
True
>>> from math import sqrt
@ -145,10 +145,10 @@ class RandomRDDs(object):
:return: RDD of float comprised of i.i.d. samples ~ Pois(mean).
>>> mean = 100.0
>>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=2L)
>>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=2)
>>> stats = x.stats()
>>> stats.count()
1000L
1000
>>> abs(stats.mean() - mean) < 0.5
True
>>> from math import sqrt
@ -171,10 +171,10 @@ class RandomRDDs(object):
:return: RDD of float comprised of i.i.d. samples ~ Exp(mean).
>>> mean = 2.0
>>> x = RandomRDDs.exponentialRDD(sc, mean, 1000, seed=2L)
>>> x = RandomRDDs.exponentialRDD(sc, mean, 1000, seed=2)
>>> stats = x.stats()
>>> stats.count()
1000L
1000
>>> abs(stats.mean() - mean) < 0.5
True
>>> from math import sqrt
@ -202,10 +202,10 @@ class RandomRDDs(object):
>>> scale = 2.0
>>> expMean = shape * scale
>>> expStd = sqrt(shape * scale * scale)
>>> x = RandomRDDs.gammaRDD(sc, shape, scale, 1000, seed=2L)
>>> x = RandomRDDs.gammaRDD(sc, shape, scale, 1000, seed=2)
>>> stats = x.stats()
>>> stats.count()
1000L
1000
>>> abs(stats.mean() - expMean) < 0.5
True
>>> abs(stats.stdev() - expStd) < 0.5
@ -254,7 +254,7 @@ class RandomRDDs(object):
:return: RDD of Vector with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`.
>>> import numpy as np
>>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1L).collect())
>>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1).collect())
>>> mat.shape
(100, 100)
>>> abs(mat.mean() - 0.0) < 0.1
@ -286,8 +286,8 @@ class RandomRDDs(object):
>>> std = 1.0
>>> expMean = exp(mean + 0.5 * std * std)
>>> expStd = sqrt((exp(std * std) - 1.0) * exp(2.0 * mean + std * std))
>>> mat = np.matrix(RandomRDDs.logNormalVectorRDD(sc, mean, std, \
100, 100, seed=1L).collect())
>>> m = RandomRDDs.logNormalVectorRDD(sc, mean, std, 100, 100, seed=1).collect()
>>> mat = np.matrix(m)
>>> mat.shape
(100, 100)
>>> abs(mat.mean() - expMean) < 0.1
@ -315,7 +315,7 @@ class RandomRDDs(object):
>>> import numpy as np
>>> mean = 100.0
>>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1L)
>>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1)
>>> mat = np.mat(rdd.collect())
>>> mat.shape
(100, 100)
@ -345,7 +345,7 @@ class RandomRDDs(object):
>>> import numpy as np
>>> mean = 0.5
>>> rdd = RandomRDDs.exponentialVectorRDD(sc, mean, 100, 100, seed=1L)
>>> rdd = RandomRDDs.exponentialVectorRDD(sc, mean, 100, 100, seed=1)
>>> mat = np.mat(rdd.collect())
>>> mat.shape
(100, 100)
@ -380,8 +380,7 @@ class RandomRDDs(object):
>>> scale = 2.0
>>> expMean = shape * scale
>>> expStd = sqrt(shape * scale * scale)
>>> mat = np.matrix(RandomRDDs.gammaVectorRDD(sc, shape, scale, \
100, 100, seed=1L).collect())
>>> mat = np.matrix(RandomRDDs.gammaVectorRDD(sc, shape, scale, 100, 100, seed=1).collect())
>>> mat.shape
(100, 100)
>>> abs(mat.mean() - expMean) < 0.1

View file

@ -15,6 +15,7 @@
# limitations under the License.
#
import array
from collections import namedtuple
from pyspark import SparkContext
@ -104,14 +105,14 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
assert isinstance(user_product, RDD), "user_product should be RDD of (user, product)"
first = user_product.first()
assert len(first) == 2, "user_product should be RDD of (user, product)"
user_product = user_product.map(lambda (u, p): (int(u), int(p)))
user_product = user_product.map(lambda u_p: (int(u_p[0]), int(u_p[1])))
return self.call("predict", user_product)
def userFeatures(self):
return self.call("getUserFeatures")
return self.call("getUserFeatures").mapValues(lambda v: array.array('d', v))
def productFeatures(self):
return self.call("getProductFeatures")
return self.call("getProductFeatures").mapValues(lambda v: array.array('d', v))
@classmethod
def load(cls, sc, path):

View file

@ -15,7 +15,7 @@
# limitations under the License.
#
from pyspark import RDD
from pyspark.rdd import RDD, ignore_unicode_prefix
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
from pyspark.mllib.linalg import Matrix, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
@ -38,7 +38,7 @@ class MultivariateStatisticalSummary(JavaModelWrapper):
return self.call("variance").toArray()
def count(self):
return self.call("count")
return int(self.call("count"))
def numNonzeros(self):
return self.call("numNonzeros").toArray()
@ -78,7 +78,7 @@ class Statistics(object):
>>> cStats.variance()
array([ 4., 13., 0., 25.])
>>> cStats.count()
3L
3
>>> cStats.numNonzeros()
array([ 3., 2., 0., 3.])
>>> cStats.max()
@ -124,20 +124,20 @@ class Statistics(object):
>>> rdd = sc.parallelize([Vectors.dense([1, 0, 0, -2]), Vectors.dense([4, 5, 0, 3]),
... Vectors.dense([6, 7, 0, 8]), Vectors.dense([9, 0, 0, 1])])
>>> pearsonCorr = Statistics.corr(rdd)
>>> print str(pearsonCorr).replace('nan', 'NaN')
>>> print(str(pearsonCorr).replace('nan', 'NaN'))
[[ 1. 0.05564149 NaN 0.40047142]
[ 0.05564149 1. NaN 0.91359586]
[ NaN NaN 1. NaN]
[ 0.40047142 0.91359586 NaN 1. ]]
>>> spearmanCorr = Statistics.corr(rdd, method="spearman")
>>> print str(spearmanCorr).replace('nan', 'NaN')
>>> print(str(spearmanCorr).replace('nan', 'NaN'))
[[ 1. 0.10540926 NaN 0.4 ]
[ 0.10540926 1. NaN 0.9486833 ]
[ NaN NaN 1. NaN]
[ 0.4 0.9486833 NaN 1. ]]
>>> try:
... Statistics.corr(rdd, "spearman")
... print "Method name as second argument without 'method=' shouldn't be allowed."
... print("Method name as second argument without 'method=' shouldn't be allowed.")
... except TypeError:
... pass
"""
@ -153,6 +153,7 @@ class Statistics(object):
return callMLlibFunc("corr", x.map(float), y.map(float), method)
@staticmethod
@ignore_unicode_prefix
def chiSqTest(observed, expected=None):
"""
.. note:: Experimental
@ -188,11 +189,11 @@ class Statistics(object):
>>> from pyspark.mllib.linalg import Vectors, Matrices
>>> observed = Vectors.dense([4, 6, 5])
>>> pearson = Statistics.chiSqTest(observed)
>>> print pearson.statistic
>>> print(pearson.statistic)
0.4
>>> pearson.degreesOfFreedom
2
>>> print round(pearson.pValue, 4)
>>> print(round(pearson.pValue, 4))
0.8187
>>> pearson.method
u'pearson'
@ -202,12 +203,12 @@ class Statistics(object):
>>> observed = Vectors.dense([21, 38, 43, 80])
>>> expected = Vectors.dense([3, 5, 7, 20])
>>> pearson = Statistics.chiSqTest(observed, expected)
>>> print round(pearson.pValue, 4)
>>> print(round(pearson.pValue, 4))
0.0027
>>> data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0]
>>> chi = Statistics.chiSqTest(Matrices.dense(3, 4, data))
>>> print round(chi.statistic, 4)
>>> print(round(chi.statistic, 4))
21.9958
>>> data = [LabeledPoint(0.0, Vectors.dense([0.5, 10.0])),
@ -218,9 +219,9 @@ class Statistics(object):
... LabeledPoint(1.0, Vectors.dense([3.5, 40.0])),]
>>> rdd = sc.parallelize(data, 4)
>>> chi = Statistics.chiSqTest(rdd)
>>> print chi[0].statistic
>>> print(chi[0].statistic)
0.75
>>> print chi[1].statistic
>>> print(chi[1].statistic)
1.5
"""
if isinstance(observed, RDD):

View file

@ -72,11 +72,11 @@ class VectorTests(PySparkTestCase):
def _test_serialize(self, v):
self.assertEqual(v, ser.loads(ser.dumps(v)))
jvec = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(v)))
nv = ser.loads(str(self.sc._jvm.SerDe.dumps(jvec)))
nv = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvec)))
self.assertEqual(v, nv)
vs = [v] * 100
jvecs = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(vs)))
nvs = ser.loads(str(self.sc._jvm.SerDe.dumps(jvecs)))
nvs = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvecs)))
self.assertEqual(vs, nvs)
def test_serialize(self):
@ -412,11 +412,11 @@ class StatTests(PySparkTestCase):
self.assertEqual(10, len(summary.normL1()))
self.assertEqual(10, len(summary.normL2()))
data2 = self.sc.parallelize(xrange(10)).map(lambda x: Vectors.dense(x))
data2 = self.sc.parallelize(range(10)).map(lambda x: Vectors.dense(x))
summary2 = Statistics.colStats(data2)
self.assertEqual(array([45.0]), summary2.normL1())
import math
expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, xrange(10))))
expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, range(10))))
self.assertTrue(math.fabs(summary2.normL2()[0] - expectedNormL2) < 1e-14)
@ -438,11 +438,11 @@ class VectorUDTTests(PySparkTestCase):
def test_infer_schema(self):
sqlCtx = SQLContext(self.sc)
rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)])
srdd = sqlCtx.inferSchema(rdd)
schema = srdd.schema
df = rdd.toDF()
schema = df.schema
field = [f for f in schema.fields if f.name == "features"][0]
self.assertEqual(field.dataType, self.udt)
vectors = srdd.map(lambda p: p.features).collect()
vectors = df.map(lambda p: p.features).collect()
self.assertEqual(len(vectors), 2)
for v in vectors:
if isinstance(v, SparseVector):
@ -695,7 +695,7 @@ class ChiSqTestTests(PySparkTestCase):
class SerDeTest(PySparkTestCase):
def test_to_java_object_rdd(self): # SPARK-6660
data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L)
data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0)
self.assertEqual(_to_java_object_rdd(data).count(), 10)
@ -771,7 +771,7 @@ class StandardScalerTests(PySparkTestCase):
if __name__ == "__main__":
if not _have_scipy:
print "NOTE: Skipping SciPy tests as it does not seem to be installed"
print("NOTE: Skipping SciPy tests as it does not seem to be installed")
unittest.main()
if not _have_scipy:
print "NOTE: SciPy tests were skipped as it does not seem to be installed"
print("NOTE: SciPy tests were skipped as it does not seem to be installed")

View file

@ -163,14 +163,16 @@ class DecisionTree(object):
... LabeledPoint(1.0, [3.0])
... ]
>>> model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {})
>>> print model, # it already has newline
>>> print(model)
DecisionTreeModel classifier of depth 1 with 3 nodes
>>> print model.toDebugString(), # it already has newline
>>> print(model.toDebugString())
DecisionTreeModel classifier of depth 1 with 3 nodes
If (feature 0 <= 0.0)
Predict: 0.0
Else (feature 0 > 0.0)
Predict: 1.0
<BLANKLINE>
>>> model.predict(array([1.0]))
1.0
>>> model.predict(array([0.0]))
@ -318,9 +320,10 @@ class RandomForest(object):
3
>>> model.totalNumNodes()
7
>>> print model,
>>> print(model)
TreeEnsembleModel classifier with 3 trees
>>> print model.toDebugString(),
<BLANKLINE>
>>> print(model.toDebugString())
TreeEnsembleModel classifier with 3 trees
<BLANKLINE>
Tree 0:
@ -335,6 +338,7 @@ class RandomForest(object):
Predict: 0.0
Else (feature 0 > 1.0)
Predict: 1.0
<BLANKLINE>
>>> model.predict([2.0])
1.0
>>> model.predict([0.0])
@ -483,8 +487,9 @@ class GradientBoostedTrees(object):
100
>>> model.totalNumNodes()
300
>>> print model, # it already has newline
>>> print(model) # it already has newline
TreeEnsembleModel classifier with 100 trees
<BLANKLINE>
>>> model.predict([2.0])
1.0
>>> model.predict([0.0])

View file

@ -15,10 +15,14 @@
# limitations under the License.
#
import sys
import numpy as np
import warnings
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper, inherit_doc
if sys.version > '3':
xrange = range
from pyspark.mllib.common import callMLlibFunc, inherit_doc
from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector
@ -94,22 +98,16 @@ class MLUtils(object):
>>> from pyspark.mllib.util import MLUtils
>>> from pyspark.mllib.regression import LabeledPoint
>>> tempFile = NamedTemporaryFile(delete=True)
>>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0")
>>> _ = tempFile.write(b"+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0")
>>> tempFile.flush()
>>> examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect()
>>> tempFile.close()
>>> type(examples[0]) == LabeledPoint
True
>>> print examples[0]
(1.0,(6,[0,2,4],[1.0,2.0,3.0]))
>>> type(examples[1]) == LabeledPoint
True
>>> print examples[1]
(-1.0,(6,[],[]))
>>> type(examples[2]) == LabeledPoint
True
>>> print examples[2]
(-1.0,(6,[1,3,5],[4.0,5.0,6.0]))
>>> examples[0]
LabeledPoint(1.0, (6,[0,2,4],[1.0,2.0,3.0]))
>>> examples[1]
LabeledPoint(-1.0, (6,[],[]))
>>> examples[2]
LabeledPoint(-1.0, (6,[1,3,5],[4.0,5.0,6.0]))
"""
from pyspark.mllib.regression import LabeledPoint
if multiclass is not None:

View file

@ -84,11 +84,11 @@ class Profiler(object):
>>> from pyspark import BasicProfiler
>>> class MyCustomProfiler(BasicProfiler):
... def show(self, id):
... print "My custom profiles for RDD:%s" % id
... print("My custom profiles for RDD:%s" % id)
...
>>> conf = SparkConf().set("spark.python.profile", "true")
>>> sc = SparkContext('local', 'test', conf=conf, profiler_cls=MyCustomProfiler)
>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
>>> sc.parallelize(range(1000)).map(lambda x: 2 * x).take(10)
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
>>> sc.show_profiles()
My custom profiles for RDD:1
@ -111,9 +111,9 @@ class Profiler(object):
""" Print the profile stats to stdout, id is the RDD id """
stats = self.stats()
if stats:
print "=" * 60
print "Profile of RDD<id=%d>" % id
print "=" * 60
print("=" * 60)
print("Profile of RDD<id=%d>" % id)
print("=" * 60)
stats.sort_stats("time", "cumulative").print_stats()
def dump(self, id, path):

View file

@ -16,21 +16,29 @@
#
import copy
from collections import defaultdict
from itertools import chain, ifilter, imap
import operator
import sys
import os
import re
import operator
import shlex
from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
from threading import Thread
import warnings
import heapq
import bisect
import random
import socket
from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
from threading import Thread
from collections import defaultdict
from itertools import chain
from functools import reduce
from math import sqrt, log, isinf, isnan, pow, ceil
if sys.version > '3':
basestring = unicode = str
else:
from itertools import imap as map, ifilter as filter
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
PickleSerializer, pack_long, AutoBatchedSerializer
@ -50,20 +58,21 @@ from py4j.java_collections import ListConverter, MapConverter
__all__ = ["RDD"]
# TODO: for Python 3.3+, PYTHONHASHSEED should be reset to disable randomized
# hash for string
def portable_hash(x):
"""
This function returns consistant hash code for builtin types, especially
This function returns consistent hash code for builtin types, especially
for None and tuple with None.
The algrithm is similar to that one used by CPython 2.7
The algorithm is similar to that one used by CPython 2.7
>>> portable_hash(None)
0
>>> portable_hash((None, 1)) & 0xffffffff
219750521
"""
if sys.version >= '3.3' and 'PYTHONHASHSEED' not in os.environ:
raise Exception("Randomness of hash of string should be disabled via PYTHONHASHSEED")
if x is None:
return 0
if isinstance(x, tuple):
@ -71,7 +80,7 @@ def portable_hash(x):
for i in x:
h ^= portable_hash(i)
h *= 1000003
h &= sys.maxint
h &= sys.maxsize
h ^= len(x)
if h == -1:
h = -2
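The PYTHONHASHSEED guard exists because Python 3.3+ salts str hashes per process, so two workers would hash the same key to different partitions unless the seed is pinned. A quick way to see the drift (spawns two fresh interpreters; the second comparison is usually, though not always, False):

import os
import subprocess
import sys

def hash_in_fresh_interpreter(seed=None):
    env = dict(os.environ)
    env.pop("PYTHONHASHSEED", None)
    if seed is not None:
        env["PYTHONHASHSEED"] = seed
    out = subprocess.check_output([sys.executable, "-c", "print(hash('spark'))"], env=env)
    return int(out)

print(hash_in_fresh_interpreter("0") == hash_in_fresh_interpreter("0"))   # True: fixed seed
print(hash_in_fresh_interpreter() == hash_in_fresh_interpreter())         # usually False on 3.3+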
@ -123,6 +132,19 @@ def _load_from_socket(port, serializer):
sock.close()
def ignore_unicode_prefix(f):
"""
Ignore the 'u' prefix of string in doc tests, to make it work
in both Python 2 and 3
"""
if sys.version >= '3':
# the representation of unicode string in Python 3 does not have prefix 'u',
# so remove the prefix 'u' for doc tests
literal_re = re.compile(r"(\W|^)[uU](['])", re.UNICODE)
f.__doc__ = literal_re.sub(r'\1\2', f.__doc__)
return f
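In use, the decorator lets one docstring serve both interpreters: on Python 3 the expected u'...' output is rewritten without the prefix. A standalone toy (the decorator body is repeated so the snippet runs on its own; the example function is not from the patch):

import re
import sys

def ignore_unicode_prefix(f):
    if sys.version >= '3':
        f.__doc__ = re.compile(r"(\W|^)[uU](['])", re.UNICODE).sub(r'\1\2', f.__doc__)
    return f

@ignore_unicode_prefix
def example():
    """
    >>> example()
    u'RDD1'
    """
    return 'RDD1'

print(example.__doc__)        # on Python 3 the doctest now expects 'RDD1'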
class Partitioner(object):
def __init__(self, numPartitions, partitionFunc):
self.numPartitions = numPartitions
@ -251,7 +273,7 @@ class RDD(object):
[('a', 1), ('b', 1), ('c', 1)]
"""
def func(_, iterator):
return imap(f, iterator)
return map(f, iterator)
return self.mapPartitionsWithIndex(func, preservesPartitioning)
def flatMap(self, f, preservesPartitioning=False):
@ -266,7 +288,7 @@ class RDD(object):
[(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
"""
def func(s, iterator):
return chain.from_iterable(imap(f, iterator))
return chain.from_iterable(map(f, iterator))
return self.mapPartitionsWithIndex(func, preservesPartitioning)
def mapPartitions(self, f, preservesPartitioning=False):
@ -329,7 +351,7 @@ class RDD(object):
[2, 4]
"""
def func(iterator):
return ifilter(f, iterator)
return filter(f, iterator)
return self.mapPartitions(func, True)
def distinct(self, numPartitions=None):
@ -341,7 +363,7 @@ class RDD(object):
"""
return self.map(lambda x: (x, None)) \
.reduceByKey(lambda x, _: x, numPartitions) \
.map(lambda (x, _): x)
.map(lambda x: x[0])
def sample(self, withReplacement, fraction, seed=None):
"""
@ -354,8 +376,8 @@ class RDD(object):
:param seed: seed for the random number generator
>>> rdd = sc.parallelize(range(100), 4)
>>> rdd.sample(False, 0.1, 81).count()
10
>>> 6 <= rdd.sample(False, 0.1, 81).count() <= 14
True
"""
assert fraction >= 0.0, "Negative fraction value: %s" % fraction
return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
@ -368,12 +390,14 @@ class RDD(object):
:param seed: random seed
:return: split RDDs in a list
>>> rdd = sc.parallelize(range(5), 1)
>>> rdd = sc.parallelize(range(500), 1)
>>> rdd1, rdd2 = rdd.randomSplit([2, 3], 17)
>>> rdd1.collect()
[1, 3]
>>> rdd2.collect()
[0, 2, 4]
>>> len(rdd1.collect() + rdd2.collect())
500
>>> 150 < rdd1.count() < 250
True
>>> 250 < rdd2.count() < 350
True
"""
s = float(sum(weights))
cweights = [0.0]
@ -416,7 +440,7 @@ class RDD(object):
rand.shuffle(samples)
return samples
maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint))
maxSampleSize = sys.maxsize - int(numStDev * sqrt(sys.maxsize))
if num > maxSampleSize:
raise ValueError(
"Sample size cannot be greater than %d." % maxSampleSize)
@ -430,7 +454,7 @@ class RDD(object):
# See: scala/spark/RDD.scala
while len(samples) < num:
# TODO: add log warning for when more than one iteration was run
seed = rand.randint(0, sys.maxint)
seed = rand.randint(0, sys.maxsize)
samples = self.sample(withReplacement, fraction, seed).collect()
rand.shuffle(samples)
@ -507,7 +531,7 @@ class RDD(object):
"""
return self.map(lambda v: (v, None)) \
.cogroup(other.map(lambda v: (v, None))) \
.filter(lambda (k, vs): all(vs)) \
.filter(lambda k_vs: all(k_vs[1])) \
.keys()
def _reserialize(self, serializer=None):
@ -549,7 +573,7 @@ class RDD(object):
def sortPartition(iterator):
sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted
return iter(sort(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending)))
return iter(sort(iterator, key=lambda k_v: keyfunc(k_v[0]), reverse=(not ascending)))
return self.partitionBy(numPartitions, partitionFunc).mapPartitions(sortPartition, True)
@ -579,7 +603,7 @@ class RDD(object):
def sortPartition(iterator):
sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted
return iter(sort(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending)))
return iter(sort(iterator, key=lambda kv: keyfunc(kv[0]), reverse=(not ascending)))
if numPartitions == 1:
if self.getNumPartitions() > 1:
@ -594,12 +618,12 @@ class RDD(object):
return self # empty RDD
maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner
fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect()
samples = self.sample(False, fraction, 1).map(lambda kv: kv[0]).collect()
samples = sorted(samples, key=keyfunc)
# we have numPartitions many parts but one of the them has
# an implicit boundary
bounds = [samples[len(samples) * (i + 1) / numPartitions]
bounds = [samples[int(len(samples) * (i + 1) / numPartitions)]
for i in range(0, numPartitions - 1)]
def rangePartitioner(k):
@ -662,12 +686,13 @@ class RDD(object):
"""
return self.map(lambda x: (f(x), x)).groupByKey(numPartitions)
@ignore_unicode_prefix
def pipe(self, command, env={}):
"""
Return an RDD created by piping elements to a forked external process.
>>> sc.parallelize(['1', '2', '', '3']).pipe('cat').collect()
['1', '2', '', '3']
[u'1', u'2', u'', u'3']
"""
def func(iterator):
pipe = Popen(
@ -675,17 +700,18 @@ class RDD(object):
def pipe_objs(out):
for obj in iterator:
out.write(str(obj).rstrip('\n') + '\n')
s = str(obj).rstrip('\n') + '\n'
out.write(s.encode('utf-8'))
out.close()
Thread(target=pipe_objs, args=[pipe.stdin]).start()
return (x.rstrip('\n') for x in iter(pipe.stdout.readline, ''))
return (x.rstrip(b'\n').decode('utf-8') for x in iter(pipe.stdout.readline, b''))
return self.mapPartitions(func)
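pipe() now encodes each object to utf-8 before writing and decodes each output line, because Popen pipes carry bytes on Python 3 and the sentinel for iter() must be b'' rather than ''. The same round trip in isolation (assumes a Unix cat on the PATH):

from subprocess import PIPE, Popen

proc = Popen(["cat"], stdin=PIPE, stdout=PIPE)
for obj in ['1', '2', '', '3']:
    proc.stdin.write((str(obj).rstrip('\n') + '\n').encode('utf-8'))
proc.stdin.close()

print([x.rstrip(b'\n').decode('utf-8') for x in iter(proc.stdout.readline, b'')])
# ['1', '2', '', '3']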
def foreach(self, f):
"""
Applies a function to all elements of this RDD.
>>> def f(x): print x
>>> def f(x): print(x)
>>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
"""
def processPartition(iterator):
@ -700,7 +726,7 @@ class RDD(object):
>>> def f(iterator):
... for x in iterator:
... print x
... print(x)
>>> sc.parallelize([1, 2, 3, 4, 5]).foreachPartition(f)
"""
def func(it):
@ -874,7 +900,7 @@ class RDD(object):
# aggregation.
while numPartitions > scale + numPartitions / scale:
numPartitions /= scale
curNumPartitions = numPartitions
curNumPartitions = int(numPartitions)
def mapPartition(i, iterator):
for obj in iterator:
@ -984,7 +1010,7 @@ class RDD(object):
(('a', 'b', 'c'), [2, 2])
"""
if isinstance(buckets, (int, long)):
if isinstance(buckets, int):
if buckets < 1:
raise ValueError("number of buckets must be >= 1")
@ -1020,6 +1046,7 @@ class RDD(object):
raise ValueError("Can not generate buckets with infinite value")
# keep them as integer if possible
inc = int(inc)
if inc * buckets != maxv - minv:
inc = (maxv - minv) * 1.0 / buckets
@ -1137,7 +1164,7 @@ class RDD(object):
yield counts
def mergeMaps(m1, m2):
for k, v in m2.iteritems():
for k, v in m2.items():
m1[k] += v
return m1
return self.mapPartitions(countPartition).reduce(mergeMaps)
@ -1378,8 +1405,8 @@ class RDD(object):
>>> tmpFile = NamedTemporaryFile(delete=True)
>>> tmpFile.close()
>>> sc.parallelize([1, 2, 'spark', 'rdd']).saveAsPickleFile(tmpFile.name, 3)
>>> sorted(sc.pickleFile(tmpFile.name, 5).collect())
[1, 2, 'rdd', 'spark']
>>> sorted(sc.pickleFile(tmpFile.name, 5).map(str).collect())
['1', '2', 'rdd', 'spark']
"""
if batchSize == 0:
ser = AutoBatchedSerializer(PickleSerializer())
@ -1387,6 +1414,7 @@ class RDD(object):
ser = BatchedSerializer(PickleSerializer(), batchSize)
self._reserialize(ser)._jrdd.saveAsObjectFile(path)
@ignore_unicode_prefix
def saveAsTextFile(self, path, compressionCodecClass=None):
"""
Save this RDD as a text file, using string representations of elements.
@ -1418,12 +1446,13 @@ class RDD(object):
>>> codec = "org.apache.hadoop.io.compress.GzipCodec"
>>> sc.parallelize(['foo', 'bar']).saveAsTextFile(tempFile3.name, codec)
>>> from fileinput import input, hook_compressed
>>> ''.join(sorted(input(glob(tempFile3.name + "/part*.gz"), openhook=hook_compressed)))
'bar\\nfoo\\n'
>>> result = sorted(input(glob(tempFile3.name + "/part*.gz"), openhook=hook_compressed))
>>> b''.join(result).decode('utf-8')
u'bar\\nfoo\\n'
"""
def func(split, iterator):
for x in iterator:
if not isinstance(x, basestring):
if not isinstance(x, (unicode, bytes)):
x = unicode(x)
if isinstance(x, unicode):
x = x.encode("utf-8")
@ -1458,7 +1487,7 @@ class RDD(object):
>>> m.collect()
[1, 3]
"""
return self.map(lambda (k, v): k)
return self.map(lambda x: x[0])
def values(self):
"""
@ -1468,7 +1497,7 @@ class RDD(object):
>>> m.collect()
[2, 4]
"""
return self.map(lambda (k, v): v)
return self.map(lambda x: x[1])
def reduceByKey(self, func, numPartitions=None):
"""
@ -1507,7 +1536,7 @@ class RDD(object):
yield m
def mergeMaps(m1, m2):
for k, v in m2.iteritems():
for k, v in m2.items():
m1[k] = func(m1[k], v) if k in m1 else v
return m1
return self.mapPartitions(reducePartition).reduce(mergeMaps)
@ -1604,8 +1633,8 @@ class RDD(object):
>>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x))
>>> sets = pairs.partitionBy(2).glom().collect()
>>> set(sets[0]).intersection(set(sets[1]))
set([])
>>> len(set(sets[0]).intersection(set(sets[1])))
0
"""
if numPartitions is None:
numPartitions = self._defaultReducePartitions()
@ -1637,22 +1666,22 @@ class RDD(object):
if (c % 1000 == 0 and get_used_memory() > limit
or c > batch):
n, size = len(buckets), 0
for split in buckets.keys():
for split in list(buckets.keys()):
yield pack_long(split)
d = outputSerializer.dumps(buckets[split])
del buckets[split]
yield d
size += len(d)
avg = (size / n) >> 20
avg = int(size / n) >> 20
# let 1M < avg < 10M
if avg < 1:
batch *= 1.5
elif avg > 10:
batch = max(batch / 1.5, 1)
batch = max(int(batch / 1.5), 1)
c = 0
for split, items in buckets.iteritems():
for split, items in buckets.items():
yield pack_long(split)
yield outputSerializer.dumps(items)
@ -1707,7 +1736,7 @@ class RDD(object):
merger = ExternalMerger(agg, memory * 0.9, serializer) \
if spill else InMemoryMerger(agg)
merger.mergeValues(iterator)
return merger.iteritems()
return merger.items()
locally_combined = self.mapPartitions(combineLocally, preservesPartitioning=True)
shuffled = locally_combined.partitionBy(numPartitions)
@ -1716,7 +1745,7 @@ class RDD(object):
merger = ExternalMerger(agg, memory, serializer) \
if spill else InMemoryMerger(agg)
merger.mergeCombiners(iterator)
return merger.iteritems()
return merger.items()
return shuffled.mapPartitions(_mergeCombiners, preservesPartitioning=True)
@ -1745,7 +1774,7 @@ class RDD(object):
>>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
>>> from operator import add
>>> rdd.foldByKey(0, add).collect()
>>> sorted(rdd.foldByKey(0, add).collect())
[('a', 2), ('b', 1)]
"""
def createZero():
@ -1769,10 +1798,10 @@ class RDD(object):
sum or average) over each key, using reduceByKey or aggregateByKey will
provide much better performance.
>>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
>>> sorted(x.groupByKey().mapValues(len).collect())
>>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
>>> sorted(rdd.groupByKey().mapValues(len).collect())
[('a', 2), ('b', 1)]
>>> sorted(x.groupByKey().mapValues(list).collect())
>>> sorted(rdd.groupByKey().mapValues(list).collect())
[('a', [1, 1]), ('b', [1])]
"""
def createCombiner(x):
@ -1795,7 +1824,7 @@ class RDD(object):
merger = ExternalMerger(agg, memory * 0.9, serializer) \
if spill else InMemoryMerger(agg)
merger.mergeValues(iterator)
return merger.iteritems()
return merger.items()
locally_combined = self.mapPartitions(combine, preservesPartitioning=True)
shuffled = locally_combined.partitionBy(numPartitions)
@ -1804,7 +1833,7 @@ class RDD(object):
merger = ExternalGroupBy(agg, memory, serializer)\
if spill else InMemoryMerger(agg)
merger.mergeCombiners(it)
return merger.iteritems()
return merger.items()
return shuffled.mapPartitions(groupByKey, True).mapValues(ResultIterable)
@ -1819,7 +1848,7 @@ class RDD(object):
>>> x.flatMapValues(f).collect()
[('a', 'x'), ('a', 'y'), ('a', 'z'), ('b', 'p'), ('b', 'r')]
"""
flat_map_fn = lambda (k, v): ((k, x) for x in f(v))
flat_map_fn = lambda kv: ((kv[0], x) for x in f(kv[1]))
return self.flatMap(flat_map_fn, preservesPartitioning=True)
def mapValues(self, f):
@ -1833,7 +1862,7 @@ class RDD(object):
>>> x.mapValues(f).collect()
[('a', 3), ('b', 1)]
"""
map_values_fn = lambda (k, v): (k, f(v))
map_values_fn = lambda kv: (kv[0], f(kv[1]))
return self.map(map_values_fn, preservesPartitioning=True)
def groupWith(self, other, *others):
@ -1844,8 +1873,7 @@ class RDD(object):
>>> x = sc.parallelize([("a", 1), ("b", 4)])
>>> y = sc.parallelize([("a", 2)])
>>> z = sc.parallelize([("b", 42)])
>>> map((lambda (x,y): (x, (list(y[0]), list(y[1]), list(y[2]), list(y[3])))), \
sorted(list(w.groupWith(x, y, z).collect())))
>>> [(x, tuple(map(list, y))) for x, y in sorted(list(w.groupWith(x, y, z).collect()))]
[('a', ([5], [1], [2], [])), ('b', ([6], [4], [], [42]))]
"""
@ -1860,7 +1888,7 @@ class RDD(object):
>>> x = sc.parallelize([("a", 1), ("b", 4)])
>>> y = sc.parallelize([("a", 2)])
>>> map((lambda (x,y): (x, (list(y[0]), list(y[1])))), sorted(list(x.cogroup(y).collect())))
>>> [(x, tuple(map(list, y))) for x, y in sorted(list(x.cogroup(y).collect()))]
[('a', ([1], [2])), ('b', ([4], []))]
"""
return python_cogroup((self, other), numPartitions)
@ -1896,8 +1924,9 @@ class RDD(object):
>>> sorted(x.subtractByKey(y).collect())
[('b', 4), ('b', 5)]
"""
def filter_func((key, vals)):
return vals[0] and not vals[1]
def filter_func(pair):
key, (val1, val2) = pair
return val1 and not val2
return self.cogroup(other, numPartitions).filter(filter_func).flatMapValues(lambda x: x[0])
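filter_func is rewritten this way because PEP 3113 removed tuple parameters: def filter_func((key, vals)) and lambda (k, v): ... are syntax errors on Python 3, which is also why the lambdas earlier in this file now take a single argument and index into it. A tiny standalone version of the rewrite:

pairs = [("a", ([1], [])), ("b", ([], [2]))]

def filter_func(pair):
    key, (vals0, vals1) = pair          # explicit unpacking replaces the tuple parameter
    return bool(vals0) and not vals1

print([k for k, v in pairs if filter_func((k, v))])    # ['a']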
def subtract(self, other, numPartitions=None):
@ -1919,8 +1948,8 @@ class RDD(object):
>>> x = sc.parallelize(range(0,3)).keyBy(lambda x: x*x)
>>> y = sc.parallelize(zip(range(0,5), range(0,5)))
>>> map((lambda (x,y): (x, (list(y[0]), (list(y[1]))))), sorted(x.cogroup(y).collect()))
[(0, ([0], [0])), (1, ([1], [1])), (2, ([], [2])), (3, ([], [3])), (4, ([2], [4]))]
>>> [(x, list(map(list, y))) for x, y in sorted(x.cogroup(y).collect())]
[(0, [[0], [0]]), (1, [[1], [1]]), (2, [[], [2]]), (3, [[], [3]]), (4, [[2], [4]])]
"""
return self.map(lambda x: (f(x), x))
@ -2049,17 +2078,18 @@ class RDD(object):
"""
Return the name of this RDD.
"""
name_ = self._jrdd.name()
if name_:
return name_.encode('utf-8')
n = self._jrdd.name()
if n:
return n
@ignore_unicode_prefix
def setName(self, name):
"""
Assign a name to this RDD.
>>> rdd1 = sc.parallelize([1,2])
>>> rdd1 = sc.parallelize([1, 2])
>>> rdd1.setName('RDD1').name()
'RDD1'
u'RDD1'
"""
self._jrdd.setName(name)
return self
@ -2121,7 +2151,7 @@ class RDD(object):
>>> sorted.lookup(1024)
[]
"""
values = self.filter(lambda (k, v): k == key).values()
values = self.filter(lambda kv: kv[0] == key).values()
if self.partitioner is not None:
return self.ctx.runJob(values, lambda x: x, [self.partitioner(key)], False)
@ -2159,7 +2189,7 @@ class RDD(object):
or meet the confidence.
>>> rdd = sc.parallelize(range(1000), 10)
>>> r = sum(xrange(1000))
>>> r = sum(range(1000))
>>> (rdd.sumApprox(1000) - r) / r < 0.05
True
"""
@ -2176,7 +2206,7 @@ class RDD(object):
or meet the confidence.
>>> rdd = sc.parallelize(range(1000), 10)
>>> r = sum(xrange(1000)) / 1000.0
>>> r = sum(range(1000)) / 1000.0
>>> (rdd.meanApprox(1000) - r) / r < 0.05
True
"""
@ -2201,10 +2231,10 @@ class RDD(object):
It must be greater than 0.000017.
>>> n = sc.parallelize(range(1000)).map(str).countApproxDistinct()
>>> 950 < n < 1050
>>> 900 < n < 1100
True
>>> n = sc.parallelize([i % 20 for i in range(1000)]).countApproxDistinct()
>>> 18 < n < 22
>>> 16 < n < 24
True
"""
if relativeSD < 0.000017:
@ -2223,8 +2253,7 @@ class RDD(object):
>>> [x for x in rdd.toLocalIterator()]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
"""
partitions = xrange(self.getNumPartitions())
for partition in partitions:
for partition in range(self.getNumPartitions()):
rows = self.context.runJob(self, lambda x: x, [partition])
for row in rows:
yield row

View file

@ -23,7 +23,7 @@ import math
class RDDSamplerBase(object):
def __init__(self, withReplacement, seed=None):
self._seed = seed if seed is not None else random.randint(0, sys.maxint)
self._seed = seed if seed is not None else random.randint(0, sys.maxsize)
self._withReplacement = withReplacement
self._random = None
@ -31,7 +31,7 @@ class RDDSamplerBase(object):
self._random = random.Random(self._seed ^ split)
# mixing because the initial seeds are close to each other
for _ in xrange(10):
for _ in range(10):
self._random.randint(0, 1)
def getUniformSample(self):

View file

@ -49,16 +49,24 @@ which contains two batches of two objects:
>>> sc.stop()
"""
import cPickle
from itertools import chain, izip, product
import sys
from itertools import chain, product
import marshal
import struct
import sys
import types
import collections
import zlib
import itertools
if sys.version < '3':
import cPickle as pickle
protocol = 2
from itertools import izip as zip
else:
import pickle
protocol = 3
xrange = range
from pyspark import cloudpickle
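The version switch above picks both the pickle module and the protocol: 2 is the newest protocol Python 2 can read, while 3 is the first with native bytes support and requires Python 3. Used like this:

import pickle
import sys

protocol = 3 if sys.version_info[0] >= 3 else 2    # mirrors the switch above

blob = pickle.dumps({"key": b"\x00\x01"}, protocol)
print(pickle.loads(blob))                          # {'key': b'\x00\x01'}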
@ -97,7 +105,7 @@ class Serializer(object):
# subclasses should override __eq__ as appropriate.
def __eq__(self, other):
return isinstance(other, self.__class__)
return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
def __ne__(self, other):
return not self.__eq__(other)
@ -212,10 +220,6 @@ class BatchedSerializer(Serializer):
def _load_stream_without_unbatching(self, stream):
return self.serializer.load_stream(stream)
def __eq__(self, other):
return (isinstance(other, BatchedSerializer) and
other.serializer == self.serializer and other.batchSize == self.batchSize)
def __repr__(self):
return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize)
@ -233,14 +237,14 @@ class FlattenedValuesSerializer(BatchedSerializer):
def _batched(self, iterator):
n = self.batchSize
for key, values in iterator:
for i in xrange(0, len(values), n):
for i in range(0, len(values), n):
yield key, values[i:i + n]
def load_stream(self, stream):
return self.serializer.load_stream(stream)
def __repr__(self):
return "FlattenedValuesSerializer(%d)" % self.batchSize
return "FlattenedValuesSerializer(%s, %d)" % (self.serializer, self.batchSize)
class AutoBatchedSerializer(BatchedSerializer):
@ -270,12 +274,8 @@ class AutoBatchedSerializer(BatchedSerializer):
elif size > best * 10 and batch > 1:
batch /= 2
def __eq__(self, other):
return (isinstance(other, AutoBatchedSerializer) and
other.serializer == self.serializer and other.bestSize == self.bestSize)
def __repr__(self):
return "AutoBatchedSerializer(%s)" % str(self.serializer)
return "AutoBatchedSerializer(%s)" % self.serializer
class CartesianDeserializer(FramedSerializer):
@ -285,6 +285,7 @@ class CartesianDeserializer(FramedSerializer):
"""
def __init__(self, key_ser, val_ser):
FramedSerializer.__init__(self)
self.key_ser = key_ser
self.val_ser = val_ser
@ -293,7 +294,7 @@ class CartesianDeserializer(FramedSerializer):
val_stream = self.val_ser._load_stream_without_unbatching(stream)
key_is_batched = isinstance(self.key_ser, BatchedSerializer)
val_is_batched = isinstance(self.val_ser, BatchedSerializer)
for (keys, vals) in izip(key_stream, val_stream):
for (keys, vals) in zip(key_stream, val_stream):
keys = keys if key_is_batched else [keys]
vals = vals if val_is_batched else [vals]
yield (keys, vals)
@ -303,10 +304,6 @@ class CartesianDeserializer(FramedSerializer):
for pair in product(keys, vals):
yield pair
def __eq__(self, other):
return (isinstance(other, CartesianDeserializer) and
self.key_ser == other.key_ser and self.val_ser == other.val_ser)
def __repr__(self):
return "CartesianDeserializer(%s, %s)" % \
(str(self.key_ser), str(self.val_ser))
@ -318,22 +315,14 @@ class PairDeserializer(CartesianDeserializer):
Deserializes the JavaRDD zip() of two PythonRDDs.
"""
def __init__(self, key_ser, val_ser):
self.key_ser = key_ser
self.val_ser = val_ser
def load_stream(self, stream):
for (keys, vals) in self.prepare_keys_values(stream):
if len(keys) != len(vals):
raise ValueError("Can not deserialize RDD with different number of items"
" in pair: (%d, %d)" % (len(keys), len(vals)))
for pair in izip(keys, vals):
for pair in zip(keys, vals):
yield pair
def __eq__(self, other):
return (isinstance(other, PairDeserializer) and
self.key_ser == other.key_ser and self.val_ser == other.val_ser)
def __repr__(self):
return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser))
@ -382,8 +371,8 @@ def _hijack_namedtuple():
global _old_namedtuple # or it will put in closure
def _copy_func(f):
return types.FunctionType(f.func_code, f.func_globals, f.func_name,
f.func_defaults, f.func_closure)
return types.FunctionType(f.__code__, f.__globals__, f.__name__,
f.__defaults__, f.__closure__)
_old_namedtuple = _copy_func(collections.namedtuple)
@ -392,15 +381,15 @@ def _hijack_namedtuple():
return _hack_namedtuple(cls)
# replace namedtuple with new one
collections.namedtuple.func_globals["_old_namedtuple"] = _old_namedtuple
collections.namedtuple.func_globals["_hack_namedtuple"] = _hack_namedtuple
collections.namedtuple.func_code = namedtuple.func_code
collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple
collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple
collections.namedtuple.__code__ = namedtuple.__code__
collections.namedtuple.__hijack = 1
# hack the cls already generated by namedtuple
# those created in other module can be pickled as normal,
# so only hack those in __main__ module
for n, o in sys.modules["__main__"].__dict__.iteritems():
for n, o in sys.modules["__main__"].__dict__.items():
if (type(o) is type and o.__base__ is tuple
and hasattr(o, "_fields")
and "__reduce__" not in o.__dict__):
@ -413,7 +402,7 @@ _hijack_namedtuple()
class PickleSerializer(FramedSerializer):
"""
Serializes objects using Python's cPickle serializer:
Serializes objects using Python's pickle serializer:
http://docs.python.org/2/library/pickle.html
@ -422,10 +411,14 @@ class PickleSerializer(FramedSerializer):
"""
def dumps(self, obj):
return cPickle.dumps(obj, 2)
return pickle.dumps(obj, protocol)
def loads(self, obj):
return cPickle.loads(obj)
if sys.version >= '3':
def loads(self, obj, encoding="bytes"):
return pickle.loads(obj, encoding=encoding)
else:
def loads(self, obj, encoding=None):
return pickle.loads(obj)
class CloudPickleSerializer(PickleSerializer):
@ -454,7 +447,7 @@ class MarshalSerializer(FramedSerializer):
class AutoSerializer(FramedSerializer):
"""
Choose marshal or cPickle as serialization protocol automatically
Choose marshal or pickle as serialization protocol automatically
"""
def __init__(self):
@ -463,19 +456,19 @@ class AutoSerializer(FramedSerializer):
def dumps(self, obj):
if self._type is not None:
return 'P' + cPickle.dumps(obj, -1)
return b'P' + pickle.dumps(obj, -1)
try:
return 'M' + marshal.dumps(obj)
return b'M' + marshal.dumps(obj)
except Exception:
self._type = 'P'
return 'P' + cPickle.dumps(obj, -1)
self._type = b'P'
return b'P' + pickle.dumps(obj, -1)
def loads(self, obj):
_type = obj[0]
if _type == 'M':
if _type == b'M':
return marshal.loads(obj[1:])
elif _type == 'P':
return cPickle.loads(obj[1:])
elif _type == b'P':
return pickle.loads(obj[1:])
else:
raise ValueError("invalid serialization type: %s" % _type)
@ -495,8 +488,8 @@ class CompressedSerializer(FramedSerializer):
def loads(self, obj):
return self.serializer.loads(zlib.decompress(obj))
def __eq__(self, other):
return isinstance(other, CompressedSerializer) and self.serializer == other.serializer
def __repr__(self):
return "CompressedSerializer(%s)" % self.serializer
class UTF8Deserializer(Serializer):
@ -505,7 +498,7 @@ class UTF8Deserializer(Serializer):
Deserializes streams written by String.getBytes.
"""
def __init__(self, use_unicode=False):
def __init__(self, use_unicode=True):
self.use_unicode = use_unicode
def loads(self, stream):
@ -526,13 +519,13 @@ class UTF8Deserializer(Serializer):
except EOFError:
return
def __eq__(self, other):
return isinstance(other, UTF8Deserializer) and self.use_unicode == other.use_unicode
def __repr__(self):
return "UTF8Deserializer(%s)" % self.use_unicode
def read_long(stream):
length = stream.read(8)
if length == "":
if not length:
raise EOFError
return struct.unpack("!q", length)[0]
@ -547,7 +540,7 @@ def pack_long(value):
def read_int(stream):
length = stream.read(4)
if length == "":
if not length:
raise EOFError
return struct.unpack("!i", length)[0]
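Both framing readers drop the == "" comparison because a binary stream returns b"" at EOF, and b"" == "" is False on Python 3, so the old check would never fire. In isolation:

import io

chunk = io.BytesIO(b"").read(4)
print(chunk == "")       # False, even though the stream is exhausted
print(not chunk)         # True -- the emptiness test read_int/read_long now use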

View file

@ -21,13 +21,6 @@ An interactive shell.
This file is designed to be launched as a PYTHONSTARTUP script.
"""
import sys
if sys.version_info[0] != 2:
print("Error: Default Python used is Python%s" % sys.version_info.major)
print("\tSet env variable PYSPARK_PYTHON to Python2 binary and re-run it.")
sys.exit(1)
import atexit
import os
import platform
@ -53,9 +46,14 @@ atexit.register(lambda: sc.stop())
try:
# Try to access HiveConf, it will raise exception if Hive is not added
sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
sqlCtx = sqlContext = HiveContext(sc)
sqlContext = HiveContext(sc)
except py4j.protocol.Py4JError:
sqlCtx = sqlContext = SQLContext(sc)
sqlContext = SQLContext(sc)
except TypeError:
sqlContext = SQLContext(sc)
# for compatibility
sqlCtx = sqlContext
print("""Welcome to
____ __

View file

@ -78,8 +78,8 @@ def _get_local_dirs(sub):
# global stats
MemoryBytesSpilled = 0L
DiskBytesSpilled = 0L
MemoryBytesSpilled = 0
DiskBytesSpilled = 0
class Aggregator(object):
@ -126,7 +126,7 @@ class Merger(object):
""" Merge the combined items by mergeCombiner """
raise NotImplementedError
def iteritems(self):
def items(self):
""" Return the merged items ad iterator """
raise NotImplementedError
@ -156,9 +156,9 @@ class InMemoryMerger(Merger):
for k, v in iterator:
d[k] = comb(d[k], v) if k in d else v
def iteritems(self):
""" Return the merged items as iterator """
return self.data.iteritems()
def items(self):
""" Return the merged items ad iterator """
return iter(self.data.items())
def _compressed_serializer(self, serializer=None):
@ -208,15 +208,15 @@ class ExternalMerger(Merger):
>>> agg = SimpleAggregator(lambda x, y: x + y)
>>> merger = ExternalMerger(agg, 10)
>>> N = 10000
>>> merger.mergeValues(zip(xrange(N), xrange(N)))
>>> merger.mergeValues(zip(range(N), range(N)))
>>> assert merger.spills > 0
>>> sum(v for k,v in merger.iteritems())
>>> sum(v for k,v in merger.items())
49995000
>>> merger = ExternalMerger(agg, 10)
>>> merger.mergeCombiners(zip(xrange(N), xrange(N)))
>>> merger.mergeCombiners(zip(range(N), range(N)))
>>> assert merger.spills > 0
>>> sum(v for k,v in merger.iteritems())
>>> sum(v for k,v in merger.items())
49995000
"""
@ -335,10 +335,10 @@ class ExternalMerger(Merger):
# above limit at the first time.
# open all the files for writing
streams = [open(os.path.join(path, str(i)), 'w')
streams = [open(os.path.join(path, str(i)), 'wb')
for i in range(self.partitions)]
for k, v in self.data.iteritems():
for k, v in self.data.items():
h = self._partition(k)
# put one item in batch, make it compatible with load_stream
# it will increase the memory if dump them in batch
@ -354,9 +354,9 @@ class ExternalMerger(Merger):
else:
for i in range(self.partitions):
p = os.path.join(path, str(i))
with open(p, "w") as f:
with open(p, "wb") as f:
# dump items in batch
self.serializer.dump_stream(self.pdata[i].iteritems(), f)
self.serializer.dump_stream(iter(self.pdata[i].items()), f)
self.pdata[i].clear()
DiskBytesSpilled += os.path.getsize(p)
@ -364,10 +364,10 @@ class ExternalMerger(Merger):
gc.collect() # release the memory as much as possible
MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
def iteritems(self):
def items(self):
""" Return all merged items as iterator """
if not self.pdata and not self.spills:
return self.data.iteritems()
return iter(self.data.items())
return self._external_items()
def _external_items(self):
@ -398,7 +398,8 @@ class ExternalMerger(Merger):
path = self._get_spill_dir(j)
p = os.path.join(path, str(index))
# do not check memory during merging
self.mergeCombiners(self.serializer.load_stream(open(p)), 0)
with open(p, "rb") as f:
self.mergeCombiners(self.serializer.load_stream(f), 0)
# limit the total partitions
if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS
@ -408,7 +409,7 @@ class ExternalMerger(Merger):
gc.collect() # release the memory as much as possible
return self._recursive_merged_items(index)
return self.data.iteritems()
return self.data.items()
def _recursive_merged_items(self, index):
"""
@ -426,7 +427,8 @@ class ExternalMerger(Merger):
for j in range(self.spills):
path = self._get_spill_dir(j)
p = os.path.join(path, str(index))
m.mergeCombiners(self.serializer.load_stream(open(p)), 0)
with open(p, 'rb') as f:
m.mergeCombiners(self.serializer.load_stream(f), 0)
if get_used_memory() > limit:
m._spill()
@ -451,7 +453,7 @@ class ExternalSorter(object):
>>> sorter = ExternalSorter(1) # 1M
>>> import random
>>> l = range(1024)
>>> l = list(range(1024))
>>> random.shuffle(l)
>>> sorted(l) == list(sorter.sorted(l))
True
@ -499,9 +501,16 @@ class ExternalSorter(object):
# sorting them in place saves memory
current_chunk.sort(key=key, reverse=reverse)
path = self._get_path(len(chunks))
with open(path, 'w') as f:
with open(path, 'wb') as f:
self.serializer.dump_stream(current_chunk, f)
chunks.append(self.serializer.load_stream(open(path)))
def load(f):
for v in self.serializer.load_stream(f):
yield v
# close the file explicitly once all the items have been consumed,
# to avoid a ResourceWarning in Python 3
f.close()
chunks.append(load(open(path, 'rb')))
current_chunk = []
gc.collect()
limit = self._next_limit()
@ -527,7 +536,7 @@ class ExternalList(object):
ExternalList can hold many items which cannot all fit in memory at the same time.
>>> l = ExternalList(range(100))
>>> l = ExternalList(list(range(100)))
>>> len(l)
100
>>> l.append(10)
@ -555,11 +564,11 @@ class ExternalList(object):
def __getstate__(self):
if self._file is not None:
self._file.flush()
f = os.fdopen(os.dup(self._file.fileno()))
f.seek(0)
serialized = f.read()
with os.fdopen(os.dup(self._file.fileno()), "rb") as f:
f.seek(0)
serialized = f.read()
else:
serialized = ''
serialized = b''
return self.values, self.count, serialized
def __setstate__(self, item):
@ -575,7 +584,7 @@ class ExternalList(object):
if self._file is not None:
self._file.flush()
# read all items from disks first
with os.fdopen(os.dup(self._file.fileno()), 'r') as f:
with os.fdopen(os.dup(self._file.fileno()), 'rb') as f:
f.seek(0)
for v in self._ser.load_stream(f):
yield v
@ -598,11 +607,16 @@ class ExternalList(object):
d = dirs[id(self) % len(dirs)]
if not os.path.exists(d):
os.makedirs(d)
p = os.path.join(d, str(id))
self._file = open(p, "w+", 65536)
p = os.path.join(d, str(id(self)))
self._file = open(p, "wb+", 65536)
self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024)
os.unlink(p)
def __del__(self):
if self._file:
self._file.close()
self._file = None
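Two small fixes are bundled here: the spill-file name previously used str(id), i.e. the repr of the builtin id function, instead of str(id(self)), and a __del__ now closes the backing file so the file object is not finalized while still open (which avoids the ResourceWarning Python 3 can emit for unclosed files). A doctest-style check of the naming bug:

    >>> str(id)                       # what the old code produced for every instance
    '<built-in function id>'
    >>> str(id(object())).isdigit()   # the fix yields a distinct numeric name per object
    True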
def _spill(self):
""" dump the values into disk """
global MemoryBytesSpilled, DiskBytesSpilled
@ -651,33 +665,28 @@ class GroupByKey(object):
"""
Group a sorted iterator as [(k1, it1), (k2, it2), ...]
>>> k = [i/3 for i in range(6)]
>>> k = [i // 3 for i in range(6)]
>>> v = [[i] for i in range(6)]
>>> g = GroupByKey(iter(zip(k, v)))
>>> g = GroupByKey(zip(k, v))
>>> [(k, list(it)) for k, it in g]
[(0, [0, 1, 2]), (1, [3, 4, 5])]
"""
def __init__(self, iterator):
self.iterator = iter(iterator)
self.next_item = None
self.iterator = iterator
def __iter__(self):
return self
def next(self):
key, value = self.next_item if self.next_item else next(self.iterator)
values = ExternalListOfList([value])
try:
while True:
k, v = next(self.iterator)
if k != key:
self.next_item = (k, v)
break
key, values = None, None
for k, v in self.iterator:
if values is not None and k == key:
values.append(v)
except StopIteration:
self.next_item = None
return key, values
else:
if values is not None:
yield (key, values)
key = k
values = ExternalListOfList([v])
if values is not None:
yield (key, values)
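GroupByKey is rewritten above from a hand-rolled iterator with lookahead state into a plain generator: it walks the key-sorted input once and collects consecutive values with equal keys into an ExternalListOfList that can spill to disk. A minimal in-memory sketch of the same grouping logic, using itertools.groupby instead of the spillable list (illustration only):

    from itertools import groupby
    from operator import itemgetter

    def group_sorted(kv_iter):
        # assumes kv_iter is sorted by key, exactly like GroupByKey's input
        for key, pairs in groupby(kv_iter, key=itemgetter(0)):
            yield key, [v for _, v in pairs]

    list(group_sorted(iter([(0, 'a'), (0, 'b'), (1, 'c')])))
    # [(0, ['a', 'b']), (1, ['c'])]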
class ExternalGroupBy(ExternalMerger):
@ -744,7 +753,7 @@ class ExternalGroupBy(ExternalMerger):
# above limit at the first time.
# open all the files for writing
streams = [open(os.path.join(path, str(i)), 'w')
streams = [open(os.path.join(path, str(i)), 'wb')
for i in range(self.partitions)]
# If the number of keys is small, then the overhead of sort is small
@ -756,7 +765,7 @@ class ExternalGroupBy(ExternalMerger):
h = self._partition(k)
self.serializer.dump_stream([(k, self.data[k])], streams[h])
else:
for k, v in self.data.iteritems():
for k, v in self.data.items():
h = self._partition(k)
self.serializer.dump_stream([(k, v)], streams[h])
@ -771,14 +780,14 @@ class ExternalGroupBy(ExternalMerger):
else:
for i in range(self.partitions):
p = os.path.join(path, str(i))
with open(p, "w") as f:
with open(p, "wb") as f:
# dump items in batch
if self._sorted:
# sort by key only (stable)
sorted_items = sorted(self.pdata[i].iteritems(), key=operator.itemgetter(0))
sorted_items = sorted(self.pdata[i].items(), key=operator.itemgetter(0))
self.serializer.dump_stream(sorted_items, f)
else:
self.serializer.dump_stream(self.pdata[i].iteritems(), f)
self.serializer.dump_stream(self.pdata[i].items(), f)
self.pdata[i].clear()
DiskBytesSpilled += os.path.getsize(p)
@ -792,7 +801,7 @@ class ExternalGroupBy(ExternalMerger):
# if the memory cannot hold the whole partition,
# fall back to sort-based merge. Because of compression,
# the data on disk is much smaller than the memory it needs once loaded
if (size >> 20) >= self.memory_limit / 10:
if size >= self.memory_limit << 17: # * 1M / 8
return self._merge_sorted_items(index)
self.data = {}
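The threshold above is rewritten in pure integer arithmetic. Assuming, as elsewhere in this module, that memory_limit is in megabytes and size is in bytes, the two forms compare as follows (the new one is slightly stricter, 1/8 of the limit instead of 1/10):

    memory_limit = 512                                    # MB, illustrative value only
    old_threshold = (memory_limit / 10.0) * (1 << 20)     # ~53,687,091 bytes (~51.2 MB)
    new_threshold = memory_limit << 17                    # 512 * 2**20 / 8 = 67,108,864 bytes = 64 MB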
@ -800,15 +809,18 @@ class ExternalGroupBy(ExternalMerger):
path = self._get_spill_dir(j)
p = os.path.join(path, str(index))
# do not check memory during merging
self.mergeCombiners(self.serializer.load_stream(open(p)), 0)
return self.data.iteritems()
with open(p, "rb") as f:
self.mergeCombiners(self.serializer.load_stream(f), 0)
return self.data.items()
def _merge_sorted_items(self, index):
""" load a partition from disk, then sort and group by key """
def load_partition(j):
path = self._get_spill_dir(j)
p = os.path.join(path, str(index))
return self.serializer.load_stream(open(p, 'r', 65536))
with open(p, 'rb', 65536) as f:
for v in self.serializer.load_stream(f):
yield v
disk_items = [load_partition(j) for j in range(self.spills)]


@ -37,9 +37,22 @@ Important classes of Spark SQL and DataFrames:
- L{types}
List of data types available.
"""
from __future__ import absolute_import
# fix the module name conflict for Python 3+
import sys
from . import _types as types
modname = __name__ + '.types'
types.__name__ = modname
# update the __module__ for all objects, make them picklable
for v in types.__dict__.values():
if hasattr(v, "__module__") and v.__module__.endswith('._types'):
v.__module__ = modname
sys.modules[modname] = types
del modname, sys
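Per the comment above, a sibling module literally named types clashes under Python 3, so the implementation moves to _types and is re-exported under the public name; patching __module__ on every class keeps pickled rows and types resolvable as pyspark.sql.types.*. A small, self-contained illustration of why the __module__ patch matters (the Point class here is illustrative, not part of Spark):

    import pickle

    class Point(object):     # stands in for a class defined in _types
        pass

    # pickle records a class as "<__module__>.<name>" and re-imports that module
    # on load, so __module__ must name a module that is actually importable.
    data = pickle.dumps(Point())
    assert isinstance(pickle.loads(data), Point)
    # after the aliasing above, classes defined in pyspark.sql._types report
    # __module__ == 'pyspark.sql.types', which sys.modules now satisfies.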
from pyspark.sql.context import SQLContext, HiveContext
from pyspark.sql.types import Row
from pyspark.sql.context import SQLContext, HiveContext
from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD, DataFrameNaFunctions
__all__ = [


@ -15,6 +15,7 @@
# limitations under the License.
#
import sys
import decimal
import datetime
import keyword
@ -25,6 +26,9 @@ import weakref
from array import array
from operator import itemgetter
if sys.version >= "3":
long = int
unicode = str
__all__ = [
"DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
@ -410,7 +414,7 @@ class UserDefinedType(DataType):
split = pyUDT.rfind(".")
pyModule = pyUDT[:split]
pyClass = pyUDT[split+1:]
m = __import__(pyModule, globals(), locals(), [pyClass], -1)
m = __import__(pyModule, globals(), locals(), [pyClass])
UDT = getattr(m, pyClass)
return UDT()
@ -419,10 +423,9 @@ class UserDefinedType(DataType):
_all_primitive_types = dict((v.typeName(), v)
for v in globals().itervalues()
if type(v) is PrimitiveTypeSingleton and
v.__base__ == PrimitiveType)
for v in list(globals().values())
if (type(v) is type or type(v) is PrimitiveTypeSingleton)
and v.__base__ == PrimitiveType)
_all_complex_types = dict((v.typeName(), v)
for v in [ArrayType, MapType, StructType])
@ -486,10 +489,10 @@ _FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)")
def _parse_datatype_json_value(json_value):
if type(json_value) is unicode:
if not isinstance(json_value, dict):
if json_value in _all_primitive_types.keys():
return _all_primitive_types[json_value]()
elif json_value == u'decimal':
elif json_value == 'decimal':
return DecimalType()
elif _FIXED_DECIMAL.match(json_value):
m = _FIXED_DECIMAL.match(json_value)
@ -511,10 +514,8 @@ _type_mappings = {
type(None): NullType,
bool: BooleanType,
int: LongType,
long: LongType,
float: DoubleType,
str: StringType,
unicode: StringType,
bytearray: BinaryType,
decimal.Decimal: DecimalType,
datetime.date: DateType,
@ -522,6 +523,12 @@ _type_mappings = {
datetime.time: TimestampType,
}
if sys.version < "3":
_type_mappings.update({
unicode: StringType,
long: LongType,
})
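long and unicode do not exist on Python 3 (ints are unbounded and str is already unicode), so those two mappings are only registered when running under Python 2. A plain-Python check of the underlying facts, runnable on either version:

    import sys
    if sys.version_info[0] >= 3:
        assert isinstance(10 ** 30, int)          # no separate long type
        assert isinstance(u"caf\u00e9", str)      # str is unicode
    else:
        assert isinstance(10 ** 30, long)         # noqa: F821 -- Python 2 name
        assert isinstance(u"caf\u00e9", unicode)  # noqa: F821 -- Python 2 name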
def _infer_type(obj):
"""Infer the DataType from obj
@ -541,7 +548,7 @@ def _infer_type(obj):
return dataType()
if isinstance(obj, dict):
for key, value in obj.iteritems():
for key, value in obj.items():
if key is not None and value is not None:
return MapType(_infer_type(key), _infer_type(value), True)
else:
@ -565,10 +572,10 @@ def _infer_schema(row):
items = sorted(row.items())
elif isinstance(row, (tuple, list)):
if hasattr(row, "_fields"): # namedtuple
items = zip(row._fields, tuple(row))
elif hasattr(row, "__fields__"): # Row
if hasattr(row, "__fields__"): # Row
items = zip(row.__fields__, tuple(row))
elif hasattr(row, "_fields"): # namedtuple
items = zip(row._fields, tuple(row))
else:
names = ['_%d' % i for i in range(1, len(row) + 1)]
items = zip(names, row)
@ -647,7 +654,7 @@ def _python_to_sql_converter(dataType):
if isinstance(obj, dict):
return tuple(c(obj.get(n)) for n, c in zip(names, converters))
elif isinstance(obj, tuple):
if hasattr(obj, "_fields") or hasattr(obj, "__fields__"):
if hasattr(obj, "__fields__") or hasattr(obj, "_fields"):
return tuple(c(v) for c, v in zip(converters, obj))
elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs
d = dict(obj)
@ -733,12 +740,12 @@ def _create_converter(dataType):
if isinstance(dataType, ArrayType):
conv = _create_converter(dataType.elementType)
return lambda row: map(conv, row)
return lambda row: [conv(v) for v in row]
elif isinstance(dataType, MapType):
kconv = _create_converter(dataType.keyType)
vconv = _create_converter(dataType.valueType)
return lambda row: dict((kconv(k), vconv(v)) for k, v in row.iteritems())
return lambda row: dict((kconv(k), vconv(v)) for k, v in row.items())
elif isinstance(dataType, NullType):
return lambda x: None
@ -881,7 +888,7 @@ def _infer_schema_type(obj, dataType):
>>> _infer_schema_type(row, schema)
StructType...a,ArrayType...b,MapType(StringType,...c,LongType...
"""
if dataType is NullType():
if isinstance(dataType, NullType):
return _infer_type(obj)
if not obj:
@ -892,7 +899,7 @@ def _infer_schema_type(obj, dataType):
return ArrayType(eType, True)
elif isinstance(dataType, MapType):
k, v = obj.iteritems().next()
k, v = next(iter(obj.items()))
return MapType(_infer_schema_type(k, dataType.keyType),
_infer_schema_type(v, dataType.valueType))
@ -935,7 +942,7 @@ def _verify_type(obj, dataType):
>>> _verify_type(None, StructType([]))
>>> _verify_type("", StringType())
>>> _verify_type(0, LongType())
>>> _verify_type(range(3), ArrayType(ShortType()))
>>> _verify_type(list(range(3)), ArrayType(ShortType()))
>>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
@ -976,7 +983,7 @@ def _verify_type(obj, dataType):
_verify_type(i, dataType.elementType)
elif isinstance(dataType, MapType):
for k, v in obj.iteritems():
for k, v in obj.items():
_verify_type(k, dataType.keyType)
_verify_type(v, dataType.valueType)
@ -1213,6 +1220,8 @@ class Row(tuple):
return self[idx]
except IndexError:
raise AttributeError(item)
except ValueError:
raise AttributeError(item)
def __reduce__(self):
if hasattr(self, "__fields__"):


@ -15,14 +15,19 @@
# limitations under the License.
#
import sys
import warnings
import json
from itertools import imap
if sys.version >= '3':
basestring = unicode = str
else:
from itertools import imap as map
from py4j.protocol import Py4JError
from py4j.java_collections import MapConverter
from pyspark.rdd import RDD, _prepare_for_python_RDD
from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
_infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
@ -62,31 +67,27 @@ class SQLContext(object):
A SQLContext can be used to create :class:`DataFrame`, register :class:`DataFrame` as
tables, execute SQL over tables, cache tables, and read parquet files.
When created, :class:`SQLContext` adds a method called ``toDF`` to :class:`RDD`,
which can be used to convert an RDD into a DataFrame; it is a shorthand for
:func:`SQLContext.createDataFrame`.
:param sparkContext: The :class:`SparkContext` backing this SQLContext.
:param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new
SQLContext in the JVM, instead we make all calls to this object.
"""
@ignore_unicode_prefix
def __init__(self, sparkContext, sqlContext=None):
"""Creates a new SQLContext.
>>> from datetime import datetime
>>> sqlContext = SQLContext(sc)
>>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L,
>>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1,
... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
... time=datetime(2014, 8, 1, 14, 1, 5))])
>>> df = allTypes.toDF()
>>> df.registerTempTable("allTypes")
>>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
... 'from allTypes where b and i > 0').collect()
[Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)]
>>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time,
... x.row.a, x.list)).collect()
[(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
[Row(c0=2, c1=2.0, c2=False, c3=2, c4=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
>>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
[(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
"""
self._sc = sparkContext
self._jsc = self._sc._jsc
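The doctests in this file were written against Python 2 reprs, where strings print as u'...'; the ignore_unicode_prefix decorator imported above rewrites the docstring on Python 3 so the same expected output still matches. Roughly, such a decorator can be sketched like this (a simplification, not necessarily the exact pyspark.rdd implementation):

    import re
    import sys

    def ignore_unicode_prefix(f):
        # strip the u prefix from u'...' literals in the docstring on Python 3,
        # where reprs no longer carry it
        if sys.version_info[0] >= 3 and f.__doc__:
            f.__doc__ = re.sub(r"(\W|^)u(')", r"\1\2", f.__doc__)
        return f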
@ -122,6 +123,7 @@ class SQLContext(object):
"""Returns a :class:`UDFRegistration` for UDF registration."""
return UDFRegistration(self)
@ignore_unicode_prefix
def registerFunction(self, name, f, returnType=StringType()):
"""Registers a lambda function as a UDF so it can be used in SQL statements.
@ -147,7 +149,7 @@ class SQLContext(object):
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(c0=4)]
"""
func = lambda _, it: imap(lambda x: f(*x), it)
func = lambda _, it: map(lambda x: f(*x), it)
ser = AutoBatchedSerializer(PickleSerializer())
command = (func, None, ser, ser)
pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self)
@ -185,6 +187,7 @@ class SQLContext(object):
schema = rdd.map(_infer_schema).reduce(_merge_type)
return schema
@ignore_unicode_prefix
def inferSchema(self, rdd, samplingRatio=None):
"""::note: Deprecated in 1.3, use :func:`createDataFrame` instead.
"""
@ -195,6 +198,7 @@ class SQLContext(object):
return self.createDataFrame(rdd, None, samplingRatio)
@ignore_unicode_prefix
def applySchema(self, rdd, schema):
"""::note: Deprecated in 1.3, use :func:`createDataFrame` instead.
"""
@ -208,6 +212,7 @@ class SQLContext(object):
return self.createDataFrame(rdd, schema)
@ignore_unicode_prefix
def createDataFrame(self, data, schema=None, samplingRatio=None):
"""
Creates a :class:`DataFrame` from an :class:`RDD` of :class:`tuple`/:class:`list`,
@ -380,6 +385,7 @@ class SQLContext(object):
df = self._ssql_ctx.jsonFile(path, scala_datatype)
return DataFrame(df, self)
@ignore_unicode_prefix
def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
"""Loads an RDD storing one JSON object per string as a :class:`DataFrame`.
@ -477,6 +483,7 @@ class SQLContext(object):
joptions)
return DataFrame(df, self)
@ignore_unicode_prefix
def sql(self, sqlQuery):
"""Returns a :class:`DataFrame` representing the result of the given query.
@ -497,6 +504,7 @@ class SQLContext(object):
"""
return DataFrame(self._ssql_ctx.table(tableName), self)
@ignore_unicode_prefix
def tables(self, dbName=None):
"""Returns a :class:`DataFrame` containing names of tables in the given database.


@ -16,14 +16,19 @@
#
import sys
import itertools
import warnings
import random
if sys.version >= '3':
basestring = unicode = str
long = int
else:
from itertools import imap as map
from py4j.java_collections import ListConverter, MapConverter
from pyspark.context import SparkContext
from pyspark.rdd import RDD, _load_from_socket
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
@ -65,19 +70,20 @@ class DataFrame(object):
self._sc = sql_ctx and sql_ctx._sc
self.is_cached = False
self._schema = None # initialized lazily
self._lazy_rdd = None
@property
def rdd(self):
"""Returns the content as an :class:`pyspark.RDD` of :class:`Row`.
"""
if not hasattr(self, '_lazy_rdd'):
if self._lazy_rdd is None:
jrdd = self._jdf.javaToPython()
rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
schema = self.schema
def applySchema(it):
cls = _create_cls(schema)
return itertools.imap(cls, it)
return map(cls, it)
self._lazy_rdd = rdd.mapPartitions(applySchema)
@ -89,13 +95,14 @@ class DataFrame(object):
"""
return DataFrameNaFunctions(self)
def toJSON(self, use_unicode=False):
@ignore_unicode_prefix
def toJSON(self, use_unicode=True):
"""Converts a :class:`DataFrame` into a :class:`RDD` of string.
Each row is turned into a JSON document as one element in the returned RDD.
>>> df.toJSON().first()
'{"age":2,"name":"Alice"}'
u'{"age":2,"name":"Alice"}'
"""
rdd = self._jdf.toJSON()
return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
@ -228,7 +235,7 @@ class DataFrame(object):
|-- name: string (nullable = true)
<BLANKLINE>
"""
print (self._jdf.schema().treeString())
print(self._jdf.schema().treeString())
def explain(self, extended=False):
"""Prints the (logical and physical) plans to the console for debugging purpose.
@ -250,9 +257,9 @@ class DataFrame(object):
== RDD ==
"""
if extended:
print self._jdf.queryExecution().toString()
print(self._jdf.queryExecution().toString())
else:
print self._jdf.queryExecution().executedPlan().toString()
print(self._jdf.queryExecution().executedPlan().toString())
def isLocal(self):
"""Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally
@ -270,7 +277,7 @@ class DataFrame(object):
2 Alice
5 Bob
"""
print self._jdf.showString(n).encode('utf8', 'ignore')
print(self._jdf.showString(n))
def __repr__(self):
return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
@ -279,10 +286,11 @@ class DataFrame(object):
"""Returns the number of rows in this :class:`DataFrame`.
>>> df.count()
2L
2
"""
return self._jdf.count()
return int(self._jdf.count())
@ignore_unicode_prefix
def collect(self):
"""Returns all the records as a list of :class:`Row`.
@ -295,6 +303,7 @@ class DataFrame(object):
cls = _create_cls(self.schema)
return [cls(r) for r in rs]
@ignore_unicode_prefix
def limit(self, num):
"""Limits the result count to the number specified.
@ -306,6 +315,7 @@ class DataFrame(object):
jdf = self._jdf.limit(num)
return DataFrame(jdf, self.sql_ctx)
@ignore_unicode_prefix
def take(self, num):
"""Returns the first ``num`` rows as a :class:`list` of :class:`Row`.
@ -314,6 +324,7 @@ class DataFrame(object):
"""
return self.limit(num).collect()
@ignore_unicode_prefix
def map(self, f):
""" Returns a new :class:`RDD` by applying a the ``f`` function to each :class:`Row`.
@ -324,6 +335,7 @@ class DataFrame(object):
"""
return self.rdd.map(f)
@ignore_unicode_prefix
def flatMap(self, f):
""" Returns a new :class:`RDD` by first applying the ``f`` function to each :class:`Row`,
and then flattening the results.
@ -353,7 +365,7 @@ class DataFrame(object):
This is a shorthand for ``df.rdd.foreach()``.
>>> def f(person):
... print person.name
... print(person.name)
>>> df.foreach(f)
"""
return self.rdd.foreach(f)
@ -365,7 +377,7 @@ class DataFrame(object):
>>> def f(people):
... for person in people:
... print person.name
... print(person.name)
>>> df.foreachPartition(f)
"""
return self.rdd.foreachPartition(f)
@ -412,7 +424,7 @@ class DataFrame(object):
"""Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`.
>>> df.distinct().count()
2L
2
"""
return DataFrame(self._jdf.distinct(), self.sql_ctx)
@ -420,10 +432,10 @@ class DataFrame(object):
"""Returns a sampled subset of this :class:`DataFrame`.
>>> df.sample(False, 0.5, 97).count()
1L
1
"""
assert fraction >= 0.0, "Negative fraction value: %s" % fraction
seed = seed if seed is not None else random.randint(0, sys.maxint)
seed = seed if seed is not None else random.randint(0, sys.maxsize)
rdd = self._jdf.sample(withReplacement, fraction, long(seed))
return DataFrame(rdd, self.sql_ctx)
@ -437,6 +449,7 @@ class DataFrame(object):
return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields]
@property
@ignore_unicode_prefix
def columns(self):
"""Returns all column names as a list.
@ -445,6 +458,7 @@ class DataFrame(object):
"""
return [f.name for f in self.schema.fields]
@ignore_unicode_prefix
def join(self, other, joinExprs=None, joinType=None):
"""Joins with another :class:`DataFrame`, using the given join expression.
@ -470,6 +484,7 @@ class DataFrame(object):
jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType)
return DataFrame(jdf, self.sql_ctx)
@ignore_unicode_prefix
def sort(self, *cols):
"""Returns a new :class:`DataFrame` sorted by the specified column(s).
@ -513,6 +528,7 @@ class DataFrame(object):
jdf = self._jdf.describe(self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols))
return DataFrame(jdf, self.sql_ctx)
@ignore_unicode_prefix
def head(self, n=None):
"""
Returns the first ``n`` rows as a list of :class:`Row`,
@ -528,6 +544,7 @@ class DataFrame(object):
return rs[0] if rs else None
return self.take(n)
@ignore_unicode_prefix
def first(self):
"""Returns the first row as a :class:`Row`.
@ -536,6 +553,7 @@ class DataFrame(object):
"""
return self.head()
@ignore_unicode_prefix
def __getitem__(self, item):
"""Returns the column as a :class:`Column`.
@ -567,6 +585,7 @@ class DataFrame(object):
jc = self._jdf.apply(name)
return Column(jc)
@ignore_unicode_prefix
def select(self, *cols):
"""Projects a set of expressions and returns a new :class:`DataFrame`.
@ -598,6 +617,7 @@ class DataFrame(object):
jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr))
return DataFrame(jdf, self.sql_ctx)
@ignore_unicode_prefix
def filter(self, condition):
"""Filters rows using the given condition.
@ -626,6 +646,7 @@ class DataFrame(object):
where = filter
@ignore_unicode_prefix
def groupBy(self, *cols):
"""Groups the :class:`DataFrame` using the specified columns,
so we can run aggregation on them. See :class:`GroupedData`
@ -775,6 +796,7 @@ class DataFrame(object):
cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)
return DataFrame(self._jdf.na().fill(value, cols), self.sql_ctx)
@ignore_unicode_prefix
def withColumn(self, colName, col):
"""Returns a new :class:`DataFrame` by adding a column.
@ -786,6 +808,7 @@ class DataFrame(object):
"""
return self.select('*', col.alias(colName))
@ignore_unicode_prefix
def withColumnRenamed(self, existing, new):
"""REturns a new :class:`DataFrame` by renaming an existing column.
@ -852,6 +875,7 @@ class GroupedData(object):
self._jdf = jdf
self.sql_ctx = sql_ctx
@ignore_unicode_prefix
def agg(self, *exprs):
"""Compute aggregates and returns the result as a :class:`DataFrame`.
@ -1041,11 +1065,13 @@ class Column(object):
__sub__ = _bin_op("minus")
__mul__ = _bin_op("multiply")
__div__ = _bin_op("divide")
__truediv__ = _bin_op("divide")
__mod__ = _bin_op("mod")
__radd__ = _bin_op("plus")
__rsub__ = _reverse_op("minus")
__rmul__ = _bin_op("multiply")
__rdiv__ = _reverse_op("divide")
__rtruediv__ = _reverse_op("divide")
__rmod__ = _reverse_op("mod")
# logical operators
@ -1075,6 +1101,7 @@ class Column(object):
startswith = _bin_op("startsWith")
endswith = _bin_op("endsWith")
@ignore_unicode_prefix
def substr(self, startPos, length):
"""
Return a :class:`Column` which is a substring of the column
@ -1097,6 +1124,7 @@ class Column(object):
__getslice__ = substr
@ignore_unicode_prefix
def inSet(self, *cols):
""" A boolean expression that is evaluated to true if the value of this
expression is contained by the evaluated values of the arguments.
@ -1131,6 +1159,7 @@ class Column(object):
"""
return Column(getattr(self._jc, "as")(alias))
@ignore_unicode_prefix
def cast(self, dataType):
""" Convert the column into type `dataType`


@ -18,8 +18,10 @@
"""
A collection of built-in functions
"""
import sys
from itertools import imap
if sys.version < "3":
from itertools import imap as map
from py4j.java_collections import ListConverter
@ -116,7 +118,7 @@ class UserDefinedFunction(object):
def _create_judf(self):
f = self.func # put it in closure `func`
func = lambda _, it: imap(lambda x: f(*x), it)
func = lambda _, it: map(lambda x: f(*x), it)
ser = AutoBatchedSerializer(PickleSerializer())
command = (func, None, ser, ser)
sc = SparkContext._active_spark_context


@ -157,13 +157,13 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(4, res[0])
def test_udf_with_array_type(self):
d = [Row(l=range(3), d={"key": range(5)})]
d = [Row(l=list(range(3)), d={"key": list(range(5))})]
rdd = self.sc.parallelize(d)
self.sqlCtx.createDataFrame(rdd).registerTempTable("test")
self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
[(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
self.assertEqual(range(3), l1)
self.assertEqual(list(range(3)), l1)
self.assertEqual(1, l2)
def test_broadcast_in_udf(self):
@ -266,7 +266,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_apply_schema(self):
from datetime import date, datetime
rdd = self.sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0,
date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1),
{"a": 1}, (2,), [1, 2, 3], None)])
schema = StructType([
@ -309,7 +309,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_struct_in_map(self):
d = [Row(m={Row(i=1): Row(s="")})]
df = self.sc.parallelize(d).toDF()
k, v = df.head().m.items()[0]
k, v = list(df.head().m.items())[0]
self.assertEqual(1, k.i)
self.assertEqual("", v.s)
@ -554,6 +554,9 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
except py4j.protocol.Py4JError:
cls.sqlCtx = None
return
except TypeError:
cls.sqlCtx = None
return
os.unlink(cls.tempdir.name)
_scala_HiveContext =\
cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc())


@ -31,7 +31,7 @@ except ImportError:
class StatCounter(object):
def __init__(self, values=[]):
self.n = 0L # Running count of our values
self.n = 0 # Running count of our values
self.mu = 0.0 # Running mean of our values
self.m2 = 0.0 # Running variance numerator (sum of (x - mean)^2)
self.maxValue = float("-inf")
@ -87,7 +87,7 @@ class StatCounter(object):
return copy.deepcopy(self)
def count(self):
return self.n
return int(self.n)
def mean(self):
return self.mu


@ -14,6 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import print_function
import os
import sys
@ -157,7 +160,7 @@ class StreamingContext(object):
try:
jssc = gw.jvm.JavaStreamingContext(checkpointPath)
except Exception:
print >>sys.stderr, "failed to load StreamingContext from checkpoint"
print("failed to load StreamingContext from checkpoint", file=sys.stderr)
raise
jsc = jssc.sparkContext()


@ -15,11 +15,15 @@
# limitations under the License.
#
from itertools import chain, ifilter, imap
import sys
import operator
import time
from itertools import chain
from datetime import datetime
if sys.version < "3":
from itertools import imap as map, ifilter as filter
from py4j.protocol import Py4JJavaError
from pyspark import RDD
@ -76,7 +80,7 @@ class DStream(object):
Return a new DStream containing only the elements that satisfy the predicate.
"""
def func(iterator):
return ifilter(f, iterator)
return filter(f, iterator)
return self.mapPartitions(func, True)
def flatMap(self, f, preservesPartitioning=False):
@ -85,7 +89,7 @@ class DStream(object):
this DStream, and then flattening the results
"""
def func(s, iterator):
return chain.from_iterable(imap(f, iterator))
return chain.from_iterable(map(f, iterator))
return self.mapPartitionsWithIndex(func, preservesPartitioning)
def map(self, f, preservesPartitioning=False):
@ -93,7 +97,7 @@ class DStream(object):
Return a new DStream by applying a function to each element of DStream.
"""
def func(iterator):
return imap(f, iterator)
return map(f, iterator)
return self.mapPartitions(func, preservesPartitioning)
def mapPartitions(self, f, preservesPartitioning=False):
@ -150,7 +154,7 @@ class DStream(object):
"""
Apply a function to each RDD in this DStream.
"""
if func.func_code.co_argcount == 1:
if func.__code__.co_argcount == 1:
old_func = func
func = lambda t, rdd: old_func(rdd)
jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer)
@ -165,14 +169,14 @@ class DStream(object):
"""
def takeAndPrint(time, rdd):
taken = rdd.take(num + 1)
print "-------------------------------------------"
print "Time: %s" % time
print "-------------------------------------------"
print("-------------------------------------------")
print("Time: %s" % time)
print("-------------------------------------------")
for record in taken[:num]:
print record
print(record)
if len(taken) > num:
print "..."
print
print("...")
print()
self.foreachRDD(takeAndPrint)
@ -181,7 +185,7 @@ class DStream(object):
Return a new DStream by applying a map function to the value of
each key-value pair in this DStream without changing the key.
"""
map_values_fn = lambda (k, v): (k, f(v))
map_values_fn = lambda kv: (kv[0], f(kv[1]))
return self.map(map_values_fn, preservesPartitioning=True)
def flatMapValues(self, f):
@ -189,7 +193,7 @@ class DStream(object):
Return a new DStream by applying a flatMap function to the value
of each key-value pair in this DStream without changing the key.
"""
flat_map_fn = lambda (k, v): ((k, x) for x in f(v))
flat_map_fn = lambda kv: ((kv[0], x) for x in f(kv[1]))
return self.flatMap(flat_map_fn, preservesPartitioning=True)
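Python 3 removed tuple parameters in function signatures (PEP 3113), so lambda (k, v): ... is a SyntaxError there; throughout the streaming code the lambdas now take a single argument and index into it. The portable pattern in isolation:

    pairs = [("a", 1), ("b", 2)]
    # Python 2 only:  lambda (k, v): (k, v + 10)   -- SyntaxError on Python 3
    # portable form:
    print(list(map(lambda kv: (kv[0], kv[1] + 10), pairs)))   # [('a', 11), ('b', 12)]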
def glom(self):
@ -286,10 +290,10 @@ class DStream(object):
`func` can have one argument of `rdd`, or have two arguments of
(`time`, `rdd`)
"""
if func.func_code.co_argcount == 1:
if func.__code__.co_argcount == 1:
oldfunc = func
func = lambda t, rdd: oldfunc(rdd)
assert func.func_code.co_argcount == 2, "func should take one or two arguments"
assert func.__code__.co_argcount == 2, "func should take one or two arguments"
return TransformedDStream(self, func)
def transformWith(self, func, other, keepSerializer=False):
@ -300,10 +304,10 @@ class DStream(object):
`func` can have two arguments of (`rdd_a`, `rdd_b`) or have three
arguments of (`time`, `rdd_a`, `rdd_b`)
"""
if func.func_code.co_argcount == 2:
if func.__code__.co_argcount == 2:
oldfunc = func
func = lambda t, a, b: oldfunc(a, b)
assert func.func_code.co_argcount == 3, "func should take two or three arguments"
assert func.__code__.co_argcount == 3, "func should take two or three arguments"
jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer, other._jrdd_deserializer)
dstream = self._sc._jvm.PythonTransformed2DStream(self._jdstream.dstream(),
other._jdstream.dstream(), jfunc)
@ -460,7 +464,7 @@ class DStream(object):
keyed = self.map(lambda x: (1, x))
reduced = keyed.reduceByKeyAndWindow(reduceFunc, invReduceFunc,
windowDuration, slideDuration, 1)
return reduced.map(lambda (k, v): v)
return reduced.map(lambda kv: kv[1])
def countByWindow(self, windowDuration, slideDuration):
"""
@ -489,7 +493,7 @@ class DStream(object):
keyed = self.map(lambda x: (x, 1))
counted = keyed.reduceByKeyAndWindow(operator.add, operator.sub,
windowDuration, slideDuration, numPartitions)
return counted.filter(lambda (k, v): v > 0).count()
return counted.filter(lambda kv: kv[1] > 0).count()
def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None):
"""
@ -548,7 +552,8 @@ class DStream(object):
def invReduceFunc(t, a, b):
b = b.reduceByKey(func, numPartitions)
joined = a.leftOuterJoin(b, numPartitions)
return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1)
return joined.mapValues(lambda kv: invFunc(kv[0], kv[1])
if kv[1] is not None else kv[0])
jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer)
if invReduceFunc:
@ -579,9 +584,9 @@ class DStream(object):
g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None))
else:
g = a.cogroup(b.partitionBy(numPartitions), numPartitions)
g = g.mapValues(lambda (va, vb): (list(vb), list(va)[0] if len(va) else None))
state = g.mapValues(lambda (vs, s): updateFunc(vs, s))
return state.filter(lambda (k, v): v is not None)
g = g.mapValues(lambda ab: (list(ab[1]), list(ab[0])[0] if len(ab[0]) else None))
state = g.mapValues(lambda vs_s: updateFunc(vs_s[0], vs_s[1]))
return state.filter(lambda k_v: k_v[1] is not None)
jreduceFunc = TransformFunction(self._sc, reduceFunc,
self._sc.serializer, self._jrdd_deserializer)


@ -67,10 +67,10 @@ class KafkaUtils(object):
.loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
helper = helperClass.newInstance()
jstream = helper.createStream(ssc._jssc, jparam, jtopics, jlevel)
except Py4JJavaError, e:
except Py4JJavaError as e:
# TODO: use --jar once it also work on driver
if 'ClassNotFoundException' in str(e.java_exception):
print """
print("""
________________________________________________________________________________________________
Spark Streaming's Kafka libraries not found in class path. Try one of the following.
@ -88,8 +88,8 @@ ________________________________________________________________________________
________________________________________________________________________________________________
""" % (ssc.sparkContext.version, ssc.sparkContext.version)
""" % (ssc.sparkContext.version, ssc.sparkContext.version))
raise e
ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
stream = DStream(jstream, ssc, ser)
return stream.map(lambda (k, v): (keyDecoder(k), valueDecoder(v)))
return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))


@ -22,6 +22,7 @@ import operator
import unittest
import tempfile
import struct
from functools import reduce
from py4j.java_collections import MapConverter
@ -51,7 +52,7 @@ class PySparkStreamingTestCase(unittest.TestCase):
while len(result) < n and time.time() - start_time < self.timeout:
time.sleep(0.01)
if len(result) < n:
print "timeout after", self.timeout
print("timeout after", self.timeout)
def _take(self, dstream, n):
"""
@ -131,7 +132,7 @@ class BasicOperationTests(PySparkStreamingTestCase):
def func(dstream):
return dstream.map(str)
expected = map(lambda x: map(str, x), input)
expected = [list(map(str, x)) for x in input]
self._test_func(input, func, expected)
def test_flatMap(self):
@ -140,8 +141,8 @@ class BasicOperationTests(PySparkStreamingTestCase):
def func(dstream):
return dstream.flatMap(lambda x: (x, x * 2))
expected = map(lambda x: list(chain.from_iterable((map(lambda y: [y, y * 2], x)))),
input)
expected = [list(chain.from_iterable((map(lambda y: [y, y * 2], x))))
for x in input]
self._test_func(input, func, expected)
def test_filter(self):
@ -150,7 +151,7 @@ class BasicOperationTests(PySparkStreamingTestCase):
def func(dstream):
return dstream.filter(lambda x: x % 2 == 0)
expected = map(lambda x: filter(lambda y: y % 2 == 0, x), input)
expected = [[y for y in x if y % 2 == 0] for x in input]
self._test_func(input, func, expected)
def test_count(self):
@ -159,7 +160,7 @@ class BasicOperationTests(PySparkStreamingTestCase):
def func(dstream):
return dstream.count()
expected = map(lambda x: [len(x)], input)
expected = [[len(x)] for x in input]
self._test_func(input, func, expected)
def test_reduce(self):
@ -168,7 +169,7 @@ class BasicOperationTests(PySparkStreamingTestCase):
def func(dstream):
return dstream.reduce(operator.add)
expected = map(lambda x: [reduce(operator.add, x)], input)
expected = [[reduce(operator.add, x)] for x in input]
self._test_func(input, func, expected)
def test_reduceByKey(self):
@ -185,27 +186,27 @@ class BasicOperationTests(PySparkStreamingTestCase):
def test_mapValues(self):
"""Basic operation test for DStream.mapValues."""
input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)],
[("", 4), (1, 1), (2, 2), (3, 3)],
[(0, 4), (1, 1), (2, 2), (3, 3)],
[(1, 1), (2, 1), (3, 1), (4, 1)]]
def func(dstream):
return dstream.mapValues(lambda x: x + 10)
expected = [[("a", 12), ("b", 12), ("c", 11), ("d", 11)],
[("", 14), (1, 11), (2, 12), (3, 13)],
[(0, 14), (1, 11), (2, 12), (3, 13)],
[(1, 11), (2, 11), (3, 11), (4, 11)]]
self._test_func(input, func, expected, sort=True)
def test_flatMapValues(self):
"""Basic operation test for DStream.flatMapValues."""
input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)],
[("", 4), (1, 1), (2, 1), (3, 1)],
[(0, 4), (1, 1), (2, 1), (3, 1)],
[(1, 1), (2, 1), (3, 1), (4, 1)]]
def func(dstream):
return dstream.flatMapValues(lambda x: (x, x + 10))
expected = [[("a", 2), ("a", 12), ("b", 2), ("b", 12),
("c", 1), ("c", 11), ("d", 1), ("d", 11)],
[("", 4), ("", 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)],
[(0, 4), (0, 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)],
[(1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11), (4, 1), (4, 11)]]
self._test_func(input, func, expected)
@ -233,7 +234,7 @@ class BasicOperationTests(PySparkStreamingTestCase):
def test_countByValue(self):
"""Basic operation test for DStream.countByValue."""
input = [range(1, 5) * 2, range(5, 7) + range(5, 9), ["a", "a", "b", ""]]
input = [list(range(1, 5)) * 2, list(range(5, 7)) + list(range(5, 9)), ["a", "a", "b", ""]]
def func(dstream):
return dstream.countByValue()
@ -285,7 +286,7 @@ class BasicOperationTests(PySparkStreamingTestCase):
def func(d1, d2):
return d1.union(d2)
expected = [range(6), range(6), range(6)]
expected = [list(range(6)), list(range(6)), list(range(6))]
self._test_func(input1, func, expected, input2=input2)
def test_cogroup(self):
@ -424,7 +425,7 @@ class StreamingContextTests(PySparkStreamingTestCase):
duration = 0.1
def _add_input_stream(self):
inputs = map(lambda x: range(1, x), range(101))
inputs = [range(1, x) for x in range(101)]
stream = self.ssc.queueStream(inputs)
self._collect(stream, 1, block=False)
@ -441,7 +442,7 @@ class StreamingContextTests(PySparkStreamingTestCase):
self.ssc.stop()
def test_queue_stream(self):
input = [range(i + 1) for i in range(3)]
input = [list(range(i + 1)) for i in range(3)]
dstream = self.ssc.queueStream(input)
result = self._collect(dstream, 3)
self.assertEqual(input, result)
@ -457,13 +458,13 @@ class StreamingContextTests(PySparkStreamingTestCase):
with open(os.path.join(d, name), "w") as f:
f.writelines(["%d\n" % i for i in range(10)])
self.wait_for(result, 2)
self.assertEqual([range(10), range(10)], result)
self.assertEqual([list(range(10)), list(range(10))], result)
def test_binary_records_stream(self):
d = tempfile.mkdtemp()
self.ssc = StreamingContext(self.sc, self.duration)
dstream = self.ssc.binaryRecordsStream(d, 10).map(
lambda v: struct.unpack("10b", str(v)))
lambda v: struct.unpack("10b", bytes(v)))
result = self._collect(dstream, 2, block=False)
self.ssc.start()
for name in ('a', 'b'):
@ -471,10 +472,10 @@ class StreamingContextTests(PySparkStreamingTestCase):
with open(os.path.join(d, name), "wb") as f:
f.write(bytearray(range(10)))
self.wait_for(result, 2)
self.assertEqual([range(10), range(10)], map(lambda v: list(v[0]), result))
self.assertEqual([list(range(10)), list(range(10))], [list(v[0]) for v in result])
def test_union(self):
input = [range(i + 1) for i in range(3)]
input = [list(range(i + 1)) for i in range(3)]
dstream = self.ssc.queueStream(input)
dstream2 = self.ssc.queueStream(input)
dstream3 = self.ssc.union(dstream, dstream2)


@ -91,9 +91,9 @@ class TransformFunctionSerializer(object):
except Exception:
traceback.print_exc()
def loads(self, bytes):
def loads(self, data):
try:
f, deserializers = self.serializer.loads(str(bytes))
f, deserializers = self.serializer.loads(bytes(data))
return TransformFunction(self.ctx, f, *deserializers)
except Exception:
traceback.print_exc()
@ -116,7 +116,7 @@ def rddToFileName(prefix, suffix, timestamp):
"""
if isinstance(timestamp, datetime):
seconds = time.mktime(timestamp.timetuple())
timestamp = long(seconds * 1000) + timestamp.microsecond / 1000
timestamp = int(seconds * 1000) + timestamp.microsecond // 1000
if suffix is None:
return prefix + "-" + str(timestamp)
else:


@ -19,8 +19,8 @@
Unit tests for PySpark; additional tests are implemented as doctests in
individual modules.
"""
from array import array
from fileinput import input
from glob import glob
import os
import re
@ -45,6 +45,9 @@ if sys.version_info[:2] <= (2, 6):
sys.exit(1)
else:
import unittest
if sys.version_info[0] >= 3:
xrange = range
basestring = str
from pyspark.conf import SparkConf
@ -52,7 +55,9 @@ from pyspark.context import SparkContext
from pyspark.rdd import RDD
from pyspark.files import SparkFiles
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer
CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer, \
PairDeserializer, CartesianDeserializer, AutoBatchedSerializer, AutoSerializer, \
FlattenedValuesSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
from pyspark import shuffle
from pyspark.profiler import BasicProfiler
@ -81,7 +86,7 @@ class MergerTests(unittest.TestCase):
def setUp(self):
self.N = 1 << 12
self.l = [i for i in xrange(self.N)]
self.data = zip(self.l, self.l)
self.data = list(zip(self.l, self.l))
self.agg = Aggregator(lambda x: [x],
lambda x, y: x.append(y) or x,
lambda x, y: x.extend(y) or x)
@ -89,45 +94,45 @@ class MergerTests(unittest.TestCase):
def test_in_memory(self):
m = InMemoryMerger(self.agg)
m.mergeValues(self.data)
self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
self.assertEqual(sum(sum(v) for k, v in m.items()),
sum(xrange(self.N)))
m = InMemoryMerger(self.agg)
m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data))
self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
m.mergeCombiners(map(lambda x_y: (x_y[0], [x_y[1]]), self.data))
self.assertEqual(sum(sum(v) for k, v in m.items()),
sum(xrange(self.N)))
def test_small_dataset(self):
m = ExternalMerger(self.agg, 1000)
m.mergeValues(self.data)
self.assertEqual(m.spills, 0)
self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
self.assertEqual(sum(sum(v) for k, v in m.items()),
sum(xrange(self.N)))
m = ExternalMerger(self.agg, 1000)
m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data))
m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), self.data))
self.assertEqual(m.spills, 0)
self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
self.assertEqual(sum(sum(v) for k, v in m.items()),
sum(xrange(self.N)))
def test_medium_dataset(self):
m = ExternalMerger(self.agg, 30)
m = ExternalMerger(self.agg, 20)
m.mergeValues(self.data)
self.assertTrue(m.spills >= 1)
self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
self.assertEqual(sum(sum(v) for k, v in m.items()),
sum(xrange(self.N)))
m = ExternalMerger(self.agg, 10)
m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data * 3))
m.mergeCombiners(map(lambda x_y2: (x_y2[0], [x_y2[1]]), self.data * 3))
self.assertTrue(m.spills >= 1)
self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
self.assertEqual(sum(sum(v) for k, v in m.items()),
sum(xrange(self.N)) * 3)
def test_huge_dataset(self):
m = ExternalMerger(self.agg, 10, partitions=3)
m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10))
m = ExternalMerger(self.agg, 5, partitions=3)
m.mergeCombiners(map(lambda k_v: (k_v[0], [str(k_v[1])]), self.data * 10))
self.assertTrue(m.spills >= 1)
self.assertEqual(sum(len(v) for k, v in m.iteritems()),
self.assertEqual(sum(len(v) for k, v in m.items()),
self.N * 10)
m._cleanup()
@ -144,55 +149,55 @@ class MergerTests(unittest.TestCase):
self.assertEqual(1, len(list(gen_gs(1))))
self.assertEqual(2, len(list(gen_gs(2))))
self.assertEqual(100, len(list(gen_gs(100))))
self.assertEqual(range(1, 101), [k for k, _ in gen_gs(100)])
self.assertTrue(all(range(k) == list(vs) for k, vs in gen_gs(100)))
self.assertEqual(list(range(1, 101)), [k for k, _ in gen_gs(100)])
self.assertTrue(all(list(range(k)) == list(vs) for k, vs in gen_gs(100)))
for k, vs in gen_gs(50002, 10000):
self.assertEqual(k, len(vs))
self.assertEqual(range(k), list(vs))
self.assertEqual(list(range(k)), list(vs))
ser = PickleSerializer()
l = ser.loads(ser.dumps(list(gen_gs(50002, 30000))))
for k, vs in l:
self.assertEqual(k, len(vs))
self.assertEqual(range(k), list(vs))
self.assertEqual(list(range(k)), list(vs))
class SorterTests(unittest.TestCase):
def test_in_memory_sort(self):
l = range(1024)
l = list(range(1024))
random.shuffle(l)
sorter = ExternalSorter(1024)
self.assertEquals(sorted(l), list(sorter.sorted(l)))
self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
self.assertEquals(sorted(l, key=lambda x: -x, reverse=True),
list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
self.assertEqual(sorted(l), list(sorter.sorted(l)))
self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
self.assertEqual(sorted(l, key=lambda x: -x, reverse=True),
list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
def test_external_sort(self):
l = range(1024)
l = list(range(1024))
random.shuffle(l)
sorter = ExternalSorter(1)
self.assertEquals(sorted(l), list(sorter.sorted(l)))
self.assertEqual(sorted(l), list(sorter.sorted(l)))
self.assertGreater(shuffle.DiskBytesSpilled, 0)
last = shuffle.DiskBytesSpilled
self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
self.assertGreater(shuffle.DiskBytesSpilled, last)
last = shuffle.DiskBytesSpilled
self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
self.assertGreater(shuffle.DiskBytesSpilled, last)
last = shuffle.DiskBytesSpilled
self.assertEquals(sorted(l, key=lambda x: -x, reverse=True),
list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
self.assertEqual(sorted(l, key=lambda x: -x, reverse=True),
list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
self.assertGreater(shuffle.DiskBytesSpilled, last)
def test_external_sort_in_rdd(self):
conf = SparkConf().set("spark.python.worker.memory", "1m")
sc = SparkContext(conf=conf)
l = range(10240)
l = list(range(10240))
random.shuffle(l)
rdd = sc.parallelize(l, 10)
self.assertEquals(sorted(l), rdd.sortBy(lambda x: x).collect())
rdd = sc.parallelize(l, 2)
self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect())
sc.stop()
@ -200,11 +205,11 @@ class SerializationTestCase(unittest.TestCase):
def test_namedtuple(self):
from collections import namedtuple
from cPickle import dumps, loads
from pickle import dumps, loads
P = namedtuple("P", "x y")
p1 = P(1, 3)
p2 = loads(dumps(p1, 2))
self.assertEquals(p1, p2)
self.assertEqual(p1, p2)
def test_itemgetter(self):
from operator import itemgetter
@ -246,7 +251,7 @@ class SerializationTestCase(unittest.TestCase):
ser = CloudPickleSerializer()
out1 = sys.stderr
out2 = ser.loads(ser.dumps(out1))
self.assertEquals(out1, out2)
self.assertEqual(out1, out2)
def test_func_globals(self):
@ -263,19 +268,36 @@ class SerializationTestCase(unittest.TestCase):
def foo():
sys.exit(0)
self.assertTrue("exit" in foo.func_code.co_names)
self.assertTrue("exit" in foo.__code__.co_names)
ser.dumps(foo)
def test_compressed_serializer(self):
ser = CompressedSerializer(PickleSerializer())
from StringIO import StringIO
try:
from StringIO import StringIO
except ImportError:
from io import BytesIO as StringIO
io = StringIO()
ser.dump_stream(["abc", u"123", range(5)], io)
io.seek(0)
self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io)))
ser.dump_stream(range(1000), io)
io.seek(0)
self.assertEqual(["abc", u"123", range(5)] + range(1000), list(ser.load_stream(io)))
self.assertEqual(["abc", u"123", range(5)] + list(range(1000)), list(ser.load_stream(io)))
io.close()
def test_hash_serializer(self):
hash(NoOpSerializer())
hash(UTF8Deserializer())
hash(PickleSerializer())
hash(MarshalSerializer())
hash(AutoSerializer())
hash(BatchedSerializer(PickleSerializer()))
hash(AutoBatchedSerializer(MarshalSerializer()))
hash(PairDeserializer(NoOpSerializer(), UTF8Deserializer()))
hash(CartesianDeserializer(NoOpSerializer(), UTF8Deserializer()))
hash(CompressedSerializer(PickleSerializer()))
hash(FlattenedValuesSerializer(PickleSerializer()))
class PySparkTestCase(unittest.TestCase):
@ -340,7 +362,7 @@ class CheckpointTests(ReusedPySparkTestCase):
self.assertTrue(flatMappedRDD.getCheckpointFile() is not None)
recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(),
flatMappedRDD._jrdd_deserializer)
self.assertEquals([1, 2, 3, 4], recovered.collect())
self.assertEqual([1, 2, 3, 4], recovered.collect())
class AddFileTests(PySparkTestCase):
@ -356,8 +378,7 @@ class AddFileTests(PySparkTestCase):
def func(x):
from userlibrary import UserClass
return UserClass().hello()
self.assertRaises(Exception,
self.sc.parallelize(range(2)).map(func).first)
self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first)
log4j.LogManager.getRootLogger().setLevel(old_level)
# Add the file, so the job should now succeed:
@ -372,7 +393,7 @@ class AddFileTests(PySparkTestCase):
download_path = SparkFiles.get("hello.txt")
self.assertNotEqual(path, download_path)
with open(download_path) as test_file:
self.assertEquals("Hello World!\n", test_file.readline())
self.assertEqual("Hello World!\n", test_file.readline())
def test_add_py_file_locally(self):
# To ensure that we're actually testing addPyFile's effects, check that
@ -381,7 +402,7 @@ class AddFileTests(PySparkTestCase):
from userlibrary import UserClass
self.assertRaises(ImportError, func)
path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
self.sc.addFile(path)
self.sc.addPyFile(path)
from userlibrary import UserClass
self.assertEqual("Hello World!", UserClass().hello())
@ -391,7 +412,7 @@ class AddFileTests(PySparkTestCase):
def func():
from userlib import UserClass
self.assertRaises(ImportError, func)
path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1-py2.7.egg")
path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1.zip")
self.sc.addPyFile(path)
from userlib import UserClass
self.assertEqual("Hello World from inside a package!", UserClass().hello())
@ -427,8 +448,9 @@ class RDDTests(ReusedPySparkTestCase):
tempFile = tempfile.NamedTemporaryFile(delete=True)
tempFile.close()
data.saveAsTextFile(tempFile.name)
raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*")))
self.assertEqual(x, unicode(raw_contents.strip(), "utf-8"))
raw_contents = b''.join(open(p, 'rb').read()
for p in glob(tempFile.name + "/part-0000*"))
self.assertEqual(x, raw_contents.strip().decode("utf-8"))
def test_save_as_textfile_with_utf8(self):
x = u"\u00A1Hola, mundo!"
@ -436,19 +458,20 @@ class RDDTests(ReusedPySparkTestCase):
tempFile = tempfile.NamedTemporaryFile(delete=True)
tempFile.close()
data.saveAsTextFile(tempFile.name)
raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*")))
self.assertEqual(x, unicode(raw_contents.strip(), "utf-8"))
raw_contents = b''.join(open(p, 'rb').read()
for p in glob(tempFile.name + "/part-0000*"))
self.assertEqual(x, raw_contents.strip().decode('utf8'))
def test_transforming_cartesian_result(self):
# Regression test for SPARK-1034
rdd1 = self.sc.parallelize([1, 2])
rdd2 = self.sc.parallelize([3, 4])
cart = rdd1.cartesian(rdd2)
result = cart.map(lambda (x, y): x + y).collect()
result = cart.map(lambda x_y3: x_y3[0] + x_y3[1]).collect()
def test_transforming_pickle_file(self):
# Regression test for SPARK-2601
data = self.sc.parallelize(["Hello", "World!"])
data = self.sc.parallelize([u"Hello", u"World!"])
tempFile = tempfile.NamedTemporaryFile(delete=True)
tempFile.close()
data.saveAsPickleFile(tempFile.name)
@ -461,13 +484,13 @@ class RDDTests(ReusedPySparkTestCase):
a = self.sc.textFile(path)
result = a.cartesian(a).collect()
(x, y) = result[0]
self.assertEqual("Hello World!", x.strip())
self.assertEqual("Hello World!", y.strip())
self.assertEqual(u"Hello World!", x.strip())
self.assertEqual(u"Hello World!", y.strip())
def test_deleting_input_files(self):
# Regression test for SPARK-1025
tempFile = tempfile.NamedTemporaryFile(delete=False)
tempFile.write("Hello World!")
tempFile.write(b"Hello World!")
tempFile.close()
data = self.sc.textFile(tempFile.name)
filtered_data = data.filter(lambda x: True)
@ -510,21 +533,21 @@ class RDDTests(ReusedPySparkTestCase):
jon = Person(1, "Jon", "Doe")
jane = Person(2, "Jane", "Doe")
theDoes = self.sc.parallelize([jon, jane])
self.assertEquals([jon, jane], theDoes.collect())
self.assertEqual([jon, jane], theDoes.collect())
def test_large_broadcast(self):
N = 100000
data = [[float(i) for i in range(300)] for i in range(N)]
bdata = self.sc.broadcast(data) # 270MB
m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
self.assertEquals(N, m)
self.assertEqual(N, m)
def test_multiple_broadcasts(self):
N = 1 << 21
b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM
r = range(1 << 15)
r = list(range(1 << 15))
random.shuffle(r)
s = str(r)
s = str(r).encode()
checksum = hashlib.md5(s).hexdigest()
b2 = self.sc.broadcast(s)
r = list(set(self.sc.parallelize(range(10), 10).map(
@ -535,7 +558,7 @@ class RDDTests(ReusedPySparkTestCase):
self.assertEqual(checksum, csum)
random.shuffle(r)
s = str(r)
s = str(r).encode()
checksum = hashlib.md5(s).hexdigest()
b2 = self.sc.broadcast(s)
r = list(set(self.sc.parallelize(range(10), 10).map(
@ -549,7 +572,7 @@ class RDDTests(ReusedPySparkTestCase):
N = 1000000
data = [float(i) for i in xrange(N)]
rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data))
self.assertEquals(N, rdd.first())
self.assertEqual(N, rdd.first())
# regression test for SPARK-6886
self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count())
@ -590,15 +613,15 @@ class RDDTests(ReusedPySparkTestCase):
# same total number of items, but different distributions
a = self.sc.parallelize([2, 3], 2).flatMap(range)
b = self.sc.parallelize([3, 2], 2).flatMap(range)
self.assertEquals(a.count(), b.count())
self.assertEqual(a.count(), b.count())
self.assertRaises(Exception, lambda: a.zip(b).count())
def test_count_approx_distinct(self):
rdd = self.sc.parallelize(range(1000))
self.assertTrue(950 < rdd.countApproxDistinct(0.04) < 1050)
self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.04) < 1050)
self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.04) < 1050)
self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.04) < 1050)
self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050)
self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050)
self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050)
self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.03) < 1050)
rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7)
self.assertTrue(18 < rdd.countApproxDistinct() < 22)
@ -612,59 +635,59 @@ class RDDTests(ReusedPySparkTestCase):
def test_histogram(self):
# empty
rdd = self.sc.parallelize([])
self.assertEquals([0], rdd.histogram([0, 10])[1])
self.assertEquals([0, 0], rdd.histogram([0, 4, 10])[1])
self.assertEqual([0], rdd.histogram([0, 10])[1])
self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1])
self.assertRaises(ValueError, lambda: rdd.histogram(1))
# out of range
rdd = self.sc.parallelize([10.01, -0.01])
self.assertEquals([0], rdd.histogram([0, 10])[1])
self.assertEquals([0, 0], rdd.histogram((0, 4, 10))[1])
self.assertEqual([0], rdd.histogram([0, 10])[1])
self.assertEqual([0, 0], rdd.histogram((0, 4, 10))[1])
# in range with one bucket
rdd = self.sc.parallelize(range(1, 5))
self.assertEquals([4], rdd.histogram([0, 10])[1])
self.assertEquals([3, 1], rdd.histogram([0, 4, 10])[1])
self.assertEqual([4], rdd.histogram([0, 10])[1])
self.assertEqual([3, 1], rdd.histogram([0, 4, 10])[1])
# in range with one bucket exact match
self.assertEquals([4], rdd.histogram([1, 4])[1])
self.assertEqual([4], rdd.histogram([1, 4])[1])
# out of range with two buckets
rdd = self.sc.parallelize([10.01, -0.01])
self.assertEquals([0, 0], rdd.histogram([0, 5, 10])[1])
self.assertEqual([0, 0], rdd.histogram([0, 5, 10])[1])
# out of range with two uneven buckets
rdd = self.sc.parallelize([10.01, -0.01])
self.assertEquals([0, 0], rdd.histogram([0, 4, 10])[1])
self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1])
# in range with two buckets
rdd = self.sc.parallelize([1, 2, 3, 5, 6])
self.assertEquals([3, 2], rdd.histogram([0, 5, 10])[1])
self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1])
# in range with two bucket and None
rdd = self.sc.parallelize([1, 2, 3, 5, 6, None, float('nan')])
self.assertEquals([3, 2], rdd.histogram([0, 5, 10])[1])
self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1])
# in range with two uneven buckets
rdd = self.sc.parallelize([1, 2, 3, 5, 6])
self.assertEquals([3, 2], rdd.histogram([0, 5, 11])[1])
self.assertEqual([3, 2], rdd.histogram([0, 5, 11])[1])
# mixed range with two uneven buckets
rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01])
self.assertEquals([4, 3], rdd.histogram([0, 5, 11])[1])
self.assertEqual([4, 3], rdd.histogram([0, 5, 11])[1])
# mixed range with four uneven buckets
rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1])
self.assertEquals([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1])
self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1])
# mixed range with uneven buckets and NaN
rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0,
199.0, 200.0, 200.1, None, float('nan')])
self.assertEquals([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1])
self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1])
# out of range with infinite buckets
rdd = self.sc.parallelize([10.01, -0.01, float('nan'), float("inf")])
self.assertEquals([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1])
self.assertEqual([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1])
# invalid buckets
self.assertRaises(ValueError, lambda: rdd.histogram([]))
@ -674,25 +697,25 @@ class RDDTests(ReusedPySparkTestCase):
# without buckets
rdd = self.sc.parallelize(range(1, 5))
self.assertEquals(([1, 4], [4]), rdd.histogram(1))
self.assertEqual(([1, 4], [4]), rdd.histogram(1))
# without buckets single element
rdd = self.sc.parallelize([1])
self.assertEquals(([1, 1], [1]), rdd.histogram(1))
self.assertEqual(([1, 1], [1]), rdd.histogram(1))
# without bucket no range
rdd = self.sc.parallelize([1] * 4)
self.assertEquals(([1, 1], [4]), rdd.histogram(1))
self.assertEqual(([1, 1], [4]), rdd.histogram(1))
# without buckets basic two
rdd = self.sc.parallelize(range(1, 5))
self.assertEquals(([1, 2.5, 4], [2, 2]), rdd.histogram(2))
self.assertEqual(([1, 2.5, 4], [2, 2]), rdd.histogram(2))
# without buckets with more requested than elements
rdd = self.sc.parallelize([1, 2])
buckets = [1 + 0.2 * i for i in range(6)]
hist = [1, 0, 0, 0, 1]
self.assertEquals((buckets, hist), rdd.histogram(5))
self.assertEqual((buckets, hist), rdd.histogram(5))
# invalid RDDs
rdd = self.sc.parallelize([1, float('inf')])
@ -702,15 +725,8 @@ class RDDTests(ReusedPySparkTestCase):
# string
rdd = self.sc.parallelize(["ab", "ac", "b", "bd", "ef"], 2)
self.assertEquals([2, 2], rdd.histogram(["a", "b", "c"])[1])
self.assertEquals((["ab", "ef"], [5]), rdd.histogram(1))
self.assertRaises(TypeError, lambda: rdd.histogram(2))
# mixed RDD
rdd = self.sc.parallelize([1, 4, "ab", "ac", "b"], 2)
self.assertEquals([1, 1], rdd.histogram([0, 4, 10])[1])
self.assertEquals([2, 1], rdd.histogram(["a", "b", "c"])[1])
self.assertEquals(([1, "b"], [5]), rdd.histogram(1))
self.assertEqual([2, 2], rdd.histogram(["a", "b", "c"])[1])
self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1))
self.assertRaises(TypeError, lambda: rdd.histogram(2))
def test_repartitionAndSortWithinPartitions(self):
@ -718,31 +734,31 @@ class RDDTests(ReusedPySparkTestCase):
repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2)
partitions = repartitioned.glom().collect()
self.assertEquals(partitions[0], [(0, 5), (0, 8), (2, 6)])
self.assertEquals(partitions[1], [(1, 3), (3, 8), (3, 8)])
self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)])
self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)])
def test_distinct(self):
rdd = self.sc.parallelize((1, 2, 3)*10, 10)
self.assertEquals(rdd.getNumPartitions(), 10)
self.assertEquals(rdd.distinct().count(), 3)
self.assertEqual(rdd.getNumPartitions(), 10)
self.assertEqual(rdd.distinct().count(), 3)
result = rdd.distinct(5)
self.assertEquals(result.getNumPartitions(), 5)
self.assertEquals(result.count(), 3)
self.assertEqual(result.getNumPartitions(), 5)
self.assertEqual(result.count(), 3)
def test_external_group_by_key(self):
self.sc._conf.set("spark.python.worker.memory", "5m")
self.sc._conf.set("spark.python.worker.memory", "1m")
N = 200001
kv = self.sc.parallelize(range(N)).map(lambda x: (x % 3, x))
gkv = kv.groupByKey().cache()
self.assertEqual(3, gkv.count())
filtered = gkv.filter(lambda (k, vs): k == 1)
filtered = gkv.filter(lambda kv: kv[0] == 1)
self.assertEqual(1, filtered.count())
self.assertEqual([(1, N/3)], filtered.mapValues(len).collect())
self.assertEqual([(N/3, N/3)],
self.assertEqual([(1, N // 3)], filtered.mapValues(len).collect())
self.assertEqual([(N // 3, N // 3)],
filtered.values().map(lambda x: (len(x), len(list(x)))).collect())
result = filtered.collect()[0][1]
self.assertEqual(N/3, len(result))
self.assertTrue(isinstance(result.data, shuffle.ExternalList))
self.assertEqual(N // 3, len(result))
self.assertTrue(isinstance(result.data, shuffle.ExternalListOfList))
def test_sort_on_empty_rdd(self):
self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect())
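Two Python 3 incompatibilities are handled in test_external_group_by_key above: tuple parameter unpacking in lambdas was removed (PEP 3113), so lambda (k, vs): ... becomes lambda kv: kv[0] ..., and / is true division, so N / 3 yields a float where the test expects an integer. A brief standalone sketch:

pairs = [(0, "a"), (1, "b"), (1, "c")]

# Python 2 only:  filter(lambda (k, v): k == 1, pairs)
# Portable form -- index into the tuple instead of unpacking it:
ones = list(filter(lambda kv: kv[0] == 1, pairs))
print(ones)          # [(1, 'b'), (1, 'c')]

N = 200001
print(N / 3)         # 66667.0 on Python 3 (true division)
print(N // 3)        # 66667 on both versions (floor division)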
@ -767,7 +783,7 @@ class RDDTests(ReusedPySparkTestCase):
rdd = RDD(jrdd, self.sc, UTF8Deserializer())
self.assertEqual([u"a", None, u"b"], rdd.collect())
rdd = RDD(jrdd, self.sc, NoOpSerializer())
self.assertEqual(["a", None, "b"], rdd.collect())
self.assertEqual([b"a", None, b"b"], rdd.collect())
def test_multiple_python_java_RDD_conversions(self):
# Regression test for SPARK-5361
@ -813,14 +829,14 @@ class RDDTests(ReusedPySparkTestCase):
self.sc.setJobGroup("test3", "test", True)
d = sorted(parted.cogroup(parted).collect())
self.assertEqual(10, len(d))
self.assertEqual([[0], [0]], map(list, d[0][1]))
self.assertEqual([[0], [0]], list(map(list, d[0][1])))
jobId = tracker.getJobIdsForGroup("test3")[0]
self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds))
self.sc.setJobGroup("test4", "test", True)
d = sorted(parted.cogroup(rdd).collect())
self.assertEqual(10, len(d))
self.assertEqual([[0], [0]], map(list, d[0][1]))
self.assertEqual([[0], [0]], list(map(list, d[0][1])))
jobId = tracker.getJobIdsForGroup("test4")[0]
self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds))
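The list(map(...)) wrappers above, and the list(zip(...)) / list(range(...)) wrappers elsewhere in this file, are needed because those builtins return lazy iterators on Python 3 and an iterator never compares equal to a list. A quick illustration (not from the patch):

d = ([0], [0])

mapped = map(list, d)
print(mapped == [[0], [0]])           # True on Python 2, False on Python 3 (map object)
print(list(mapped) == [[0], [0]])     # True on both versions

print(list(zip(range(2), range(2))))  # [(0, 0), (1, 1)] on both versions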
@ -906,6 +922,7 @@ class InputFormatTests(ReusedPySparkTestCase):
ReusedPySparkTestCase.tearDownClass()
shutil.rmtree(cls.tempdir.name)
@unittest.skipIf(sys.version >= "3", "serialize array of byte")
def test_sequencefiles(self):
basepath = self.tempdir.name
ints = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfint/",
@ -954,15 +971,16 @@ class InputFormatTests(ReusedPySparkTestCase):
en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)]
self.assertEqual(nulls, en)
maps = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfmap/",
"org.apache.hadoop.io.IntWritable",
"org.apache.hadoop.io.MapWritable").collect())
maps = self.sc.sequenceFile(basepath + "/sftestdata/sfmap/",
"org.apache.hadoop.io.IntWritable",
"org.apache.hadoop.io.MapWritable").collect()
em = [(1, {}),
(1, {3.0: u'bb'}),
(2, {1.0: u'aa'}),
(2, {1.0: u'cc'}),
(3, {2.0: u'dd'})]
self.assertEqual(maps, em)
for v in maps:
self.assertTrue(v in em)
# arrays get pickled to tuples by default
tuples = sorted(self.sc.sequenceFile(
@ -1089,8 +1107,8 @@ class InputFormatTests(ReusedPySparkTestCase):
def test_binary_files(self):
path = os.path.join(self.tempdir.name, "binaryfiles")
os.mkdir(path)
data = "short binary data"
with open(os.path.join(path, "part-0000"), 'w') as f:
data = b"short binary data"
with open(os.path.join(path, "part-0000"), 'wb') as f:
f.write(data)
[(p, d)] = self.sc.binaryFiles(path).collect()
self.assertTrue(p.endswith("part-0000"))
@ -1103,7 +1121,7 @@ class InputFormatTests(ReusedPySparkTestCase):
for i in range(100):
f.write('%04d' % i)
result = self.sc.binaryRecords(path, 4).map(int).collect()
self.assertEqual(range(100), result)
self.assertEqual(list(range(100)), result)
class OutputFormatTests(ReusedPySparkTestCase):
@ -1115,6 +1133,7 @@ class OutputFormatTests(ReusedPySparkTestCase):
def tearDown(self):
shutil.rmtree(self.tempdir.name, ignore_errors=True)
@unittest.skipIf(sys.version >= "3", "serialize array of byte")
def test_sequencefiles(self):
basepath = self.tempdir.name
ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
@ -1155,8 +1174,9 @@ class OutputFormatTests(ReusedPySparkTestCase):
(2, {1.0: u'cc'}),
(3, {2.0: u'dd'})]
self.sc.parallelize(em).saveAsSequenceFile(basepath + "/sfmap/")
maps = sorted(self.sc.sequenceFile(basepath + "/sfmap/").collect())
self.assertEqual(maps, em)
maps = self.sc.sequenceFile(basepath + "/sfmap/").collect()
for v in maps:
self.assertTrue(v in em)
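Several sequence-file assertions switch from comparing a sorted() list against an expected list to a per-element membership check. On Python 3 dicts are unorderable, so sorting (key, dict) pairs raises TypeError as soon as two keys tie; checking each collected record against the expected set avoids both that and any dependence on collection order. A standalone sketch of the failure mode:

em = [(1, {3.0: u'bb'}), (1, {}), (2, {1.0: u'aa'})]

try:
    sorted(em)             # the two keys equal to 1 force a dict-to-dict comparison
except TypeError as e:
    print("Python 3:", e)  # unorderable types / '<' not supported between dict instances

# Order-independent check that works on both versions.
for v in [(1, {}), (2, {1.0: u'aa'})]:
    assert v in em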
def test_oldhadoop(self):
basepath = self.tempdir.name
@ -1168,12 +1188,13 @@ class OutputFormatTests(ReusedPySparkTestCase):
"org.apache.hadoop.mapred.SequenceFileOutputFormat",
"org.apache.hadoop.io.IntWritable",
"org.apache.hadoop.io.MapWritable")
result = sorted(self.sc.hadoopFile(
result = self.sc.hadoopFile(
basepath + "/oldhadoop/",
"org.apache.hadoop.mapred.SequenceFileInputFormat",
"org.apache.hadoop.io.IntWritable",
"org.apache.hadoop.io.MapWritable").collect())
self.assertEqual(result, dict_data)
"org.apache.hadoop.io.MapWritable").collect()
for v in result:
self.assertTrue(v in dict_data)
conf = {
"mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat",
@ -1183,12 +1204,13 @@ class OutputFormatTests(ReusedPySparkTestCase):
}
self.sc.parallelize(dict_data).saveAsHadoopDataset(conf)
input_conf = {"mapred.input.dir": basepath + "/olddataset/"}
old_dataset = sorted(self.sc.hadoopRDD(
result = self.sc.hadoopRDD(
"org.apache.hadoop.mapred.SequenceFileInputFormat",
"org.apache.hadoop.io.IntWritable",
"org.apache.hadoop.io.MapWritable",
conf=input_conf).collect())
self.assertEqual(old_dataset, dict_data)
conf=input_conf).collect()
for v in result:
self.assertTrue(v in dict_data)
def test_newhadoop(self):
basepath = self.tempdir.name
@ -1223,6 +1245,7 @@ class OutputFormatTests(ReusedPySparkTestCase):
conf=input_conf).collect())
self.assertEqual(new_dataset, data)
@unittest.skipIf(sys.version >= "3", "serialize of array")
def test_newhadoop_with_array(self):
basepath = self.tempdir.name
# use custom ArrayWritable types and converters to handle arrays
@ -1303,7 +1326,7 @@ class OutputFormatTests(ReusedPySparkTestCase):
basepath = self.tempdir.name
x = range(1, 5)
y = range(1001, 1005)
data = zip(x, y)
data = list(zip(x, y))
rdd = self.sc.parallelize(x).zip(self.sc.parallelize(y))
rdd.saveAsSequenceFile(basepath + "/reserialize/sequence")
result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect())
@ -1354,7 +1377,7 @@ class DaemonTests(unittest.TestCase):
sock = socket(AF_INET, SOCK_STREAM)
sock.connect(('127.0.0.1', port))
# send a split index of -1 to shutdown the worker
sock.send("\xFF\xFF\xFF\xFF")
sock.send(b"\xFF\xFF\xFF\xFF")
sock.close()
return True
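socket.send() likewise only accepts bytes on Python 3, hence the b"\xFF\xFF\xFF\xFF" literal above. A tiny sketch using a local socket pair (POSIX; socketpair is also available on Windows from Python 3.5):

import socket

a, b = socket.socketpair()
a.send(b"\xFF\xFF\xFF\xFF")   # a bytes literal works on Python 2 and 3
print(b.recv(4))              # b'\xff\xff\xff\xff'
a.close()
b.close()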
@ -1395,7 +1418,6 @@ class DaemonTests(unittest.TestCase):
class WorkerTests(PySparkTestCase):
def test_cancel_task(self):
temp = tempfile.NamedTemporaryFile(delete=True)
temp.close()
@ -1410,7 +1432,7 @@ class WorkerTests(PySparkTestCase):
# start job in background thread
def run():
self.sc.parallelize(range(1)).foreach(sleep)
self.sc.parallelize(range(1), 1).foreach(sleep)
import threading
t = threading.Thread(target=run)
t.daemon = True
@ -1419,7 +1441,8 @@ class WorkerTests(PySparkTestCase):
daemon_pid, worker_pid = 0, 0
while True:
if os.path.exists(path):
data = open(path).read().split(' ')
with open(path) as f:
data = f.read().split(' ')
daemon_pid, worker_pid = map(int, data)
break
time.sleep(0.1)
@ -1455,7 +1478,7 @@ class WorkerTests(PySparkTestCase):
def test_after_jvm_exception(self):
tempFile = tempfile.NamedTemporaryFile(delete=False)
tempFile.write("Hello World!")
tempFile.write(b"Hello World!")
tempFile.close()
data = self.sc.textFile(tempFile.name, 1)
filtered_data = data.filter(lambda x: True)
@ -1577,12 +1600,12 @@ class SparkSubmitTests(unittest.TestCase):
|from pyspark import SparkContext
|
|sc = SparkContext()
|print sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect()
|print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect())
""")
proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE)
out, err = proc.communicate()
self.assertEqual(0, proc.returncode)
self.assertIn("[2, 4, 6]", out)
self.assertIn("[2, 4, 6]", out.decode('utf-8'))
def test_script_with_local_functions(self):
"""Submit and test a single script file calling a global function"""
@ -1593,12 +1616,12 @@ class SparkSubmitTests(unittest.TestCase):
| return x * 3
|
|sc = SparkContext()
|print sc.parallelize([1, 2, 3]).map(foo).collect()
|print(sc.parallelize([1, 2, 3]).map(foo).collect())
""")
proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE)
out, err = proc.communicate()
self.assertEqual(0, proc.returncode)
self.assertIn("[3, 6, 9]", out)
self.assertIn("[3, 6, 9]", out.decode('utf-8'))
def test_module_dependency(self):
"""Submit and test a script with a dependency on another module"""
@ -1607,7 +1630,7 @@ class SparkSubmitTests(unittest.TestCase):
|from mylib import myfunc
|
|sc = SparkContext()
|print sc.parallelize([1, 2, 3]).map(myfunc).collect()
|print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
""")
zip = self.createFileInZip("mylib.py", """
|def myfunc(x):
@ -1617,7 +1640,7 @@ class SparkSubmitTests(unittest.TestCase):
stdout=subprocess.PIPE)
out, err = proc.communicate()
self.assertEqual(0, proc.returncode)
self.assertIn("[2, 3, 4]", out)
self.assertIn("[2, 3, 4]", out.decode('utf-8'))
def test_module_dependency_on_cluster(self):
"""Submit and test a script with a dependency on another module on a cluster"""
@ -1626,7 +1649,7 @@ class SparkSubmitTests(unittest.TestCase):
|from mylib import myfunc
|
|sc = SparkContext()
|print sc.parallelize([1, 2, 3]).map(myfunc).collect()
|print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
""")
zip = self.createFileInZip("mylib.py", """
|def myfunc(x):
@ -1637,7 +1660,7 @@ class SparkSubmitTests(unittest.TestCase):
stdout=subprocess.PIPE)
out, err = proc.communicate()
self.assertEqual(0, proc.returncode)
self.assertIn("[2, 3, 4]", out)
self.assertIn("[2, 3, 4]", out.decode('utf-8'))
def test_package_dependency(self):
"""Submit and test a script with a dependency on a Spark Package"""
@ -1646,14 +1669,14 @@ class SparkSubmitTests(unittest.TestCase):
|from mylib import myfunc
|
|sc = SparkContext()
|print sc.parallelize([1, 2, 3]).map(myfunc).collect()
|print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
""")
self.create_spark_package("a:mylib:0.1")
proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories",
"file:" + self.programDir, script], stdout=subprocess.PIPE)
out, err = proc.communicate()
self.assertEqual(0, proc.returncode)
self.assertIn("[2, 3, 4]", out)
self.assertIn("[2, 3, 4]", out.decode('utf-8'))
def test_package_dependency_on_cluster(self):
"""Submit and test a script with a dependency on a Spark Package on a cluster"""
@ -1662,7 +1685,7 @@ class SparkSubmitTests(unittest.TestCase):
|from mylib import myfunc
|
|sc = SparkContext()
|print sc.parallelize([1, 2, 3]).map(myfunc).collect()
|print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
""")
self.create_spark_package("a:mylib:0.1")
proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories",
@ -1670,7 +1693,7 @@ class SparkSubmitTests(unittest.TestCase):
"local-cluster[1,1,512]", script], stdout=subprocess.PIPE)
out, err = proc.communicate()
self.assertEqual(0, proc.returncode)
self.assertIn("[2, 3, 4]", out)
self.assertIn("[2, 3, 4]", out.decode('utf-8'))
def test_single_script_on_cluster(self):
"""Submit and test a single script on a cluster"""
@ -1681,7 +1704,7 @@ class SparkSubmitTests(unittest.TestCase):
| return x * 2
|
|sc = SparkContext()
|print sc.parallelize([1, 2, 3]).map(foo).collect()
|print(sc.parallelize([1, 2, 3]).map(foo).collect())
""")
# this will fail if you have different spark.executor.memory
# in conf/spark-defaults.conf
@ -1690,7 +1713,7 @@ class SparkSubmitTests(unittest.TestCase):
stdout=subprocess.PIPE)
out, err = proc.communicate()
self.assertEqual(0, proc.returncode)
self.assertIn("[2, 4, 6]", out)
self.assertIn("[2, 4, 6]", out.decode('utf-8'))
class ContextTests(unittest.TestCase):
@ -1765,7 +1788,7 @@ class SciPyTests(PySparkTestCase):
def test_serialize(self):
from scipy.special import gammaln
x = range(1, 5)
expected = map(gammaln, x)
expected = list(map(gammaln, x))
observed = self.sc.parallelize(x).map(gammaln).collect()
self.assertEqual(expected, observed)
@ -1786,11 +1809,11 @@ class NumPyTests(PySparkTestCase):
if __name__ == "__main__":
if not _have_scipy:
print "NOTE: Skipping SciPy tests as it does not seem to be installed"
print("NOTE: Skipping SciPy tests as it does not seem to be installed")
if not _have_numpy:
print "NOTE: Skipping NumPy tests as it does not seem to be installed"
print("NOTE: Skipping NumPy tests as it does not seem to be installed")
unittest.main()
if not _have_scipy:
print "NOTE: SciPy tests were skipped as it does not seem to be installed"
print("NOTE: SciPy tests were skipped as it does not seem to be installed")
if not _have_numpy:
print "NOTE: NumPy tests were skipped as it does not seem to be installed"
print("NOTE: NumPy tests were skipped as it does not seem to be installed")


@ -18,6 +18,7 @@
"""
Worker that receives input from Piped RDD.
"""
from __future__ import print_function
import os
import sys
import time
@ -37,9 +38,9 @@ utf8_deserializer = UTF8Deserializer()
def report_times(outfile, boot, init, finish):
write_int(SpecialLengths.TIMING_DATA, outfile)
write_long(1000 * boot, outfile)
write_long(1000 * init, outfile)
write_long(1000 * finish, outfile)
write_long(int(1000 * boot), outfile)
write_long(int(1000 * init), outfile)
write_long(int(1000 * finish), outfile)
def add_path(path):
@ -72,6 +73,9 @@ def main(infile, outfile):
for _ in range(num_python_includes):
filename = utf8_deserializer.loads(infile)
add_path(os.path.join(spark_files_dir, filename))
if sys.version > '3':
import importlib
importlib.invalidate_caches()
# fetch names and values of broadcast variables
num_broadcast_variables = read_int(infile)
@ -106,14 +110,14 @@ def main(infile, outfile):
except Exception:
try:
write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
write_with_length(traceback.format_exc(), outfile)
write_with_length(traceback.format_exc().encode("utf-8"), outfile)
except IOError:
# JVM close the socket
pass
except Exception:
# Write the error to stderr if it happened while serializing
print >> sys.stderr, "PySpark worker failed with exception:"
print >> sys.stderr, traceback.format_exc()
print("PySpark worker failed with exception:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
exit(-1)
finish_time = time.time()
report_times(outfile, boot_time, init_time, finish_time)
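The worker changes follow the same themes: write_long() is given int(1000 * t) because multiplying a time.time() float keeps it a float, the traceback text is encoded to utf-8 before being written to the JVM socket (which carries bytes), and on Python 3.3+ importlib.invalidate_caches() is called after new files are appended to sys.path so modules shipped after startup become importable. A sketch of the cache-invalidation part, using a hypothetical freshmod module written at runtime:

import importlib
import os
import sys
import tempfile

# Simulate a dependency delivered to the worker after the interpreter started.
mod_dir = tempfile.mkdtemp()
with open(os.path.join(mod_dir, "freshmod.py"), "w") as f:
    f.write("VALUE = 42\n")
sys.path.insert(0, mod_dir)

# Python 3.3+ caches directory listings in its path finders; without this call a
# module created after the directory was last scanned may not be found yet.
if sys.version_info >= (3, 3):
    importlib.invalidate_caches()

import freshmod
print(freshmod.VALUE)   # 42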


@ -66,7 +66,7 @@ function run_core_tests() {
function run_sql_tests() {
echo "Run sql tests ..."
run_test "pyspark/sql/types.py"
run_test "pyspark/sql/_types.py"
run_test "pyspark/sql/context.py"
run_test "pyspark/sql/dataframe.py"
run_test "pyspark/sql/functions.py"
@ -136,6 +136,19 @@ run_mllib_tests
run_ml_tests
run_streaming_tests
# Try to test with Python 3
if [ $(which python3.4) ]; then
export PYSPARK_PYTHON="python3.4"
echo "Testing with Python3.4 version:"
$PYSPARK_PYTHON --version
run_core_tests
run_sql_tests
run_mllib_tests
run_ml_tests
run_streaming_tests
fi
# Try to test with PyPy
if [ $(which pypy) ]; then
export PYSPARK_PYTHON="pypy"
