[SPARK-4897] [PySpark] Python 3 support
This PR update PySpark to support Python 3 (tested with 3.4). Known issue: unpickle array from Pyrolite is broken in Python 3, those tests are skipped. TODO: ec2/spark-ec2.py is not fully tested with python3. Author: Davies Liu <davies@databricks.com> Author: twneale <twneale@gmail.com> Author: Josh Rosen <joshrosen@databricks.com> Closes #5173 from davies/python3 and squashes the following commits: d7d6323 [Davies Liu] fix tests 6c52a98 [Davies Liu] fix mllib test 99e334f [Davies Liu] update timeout b716610 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3 cafd5ec [Davies Liu] adddress comments from @mengxr bf225d7 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3 179fc8d [Davies Liu] tuning flaky tests 8c8b957 [Davies Liu] fix ResourceWarning in Python 3 5c57c95 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3 4006829 [Davies Liu] fix test 2fc0066 [Davies Liu] add python3 path 71535e9 [Davies Liu] fix xrange and divide 5a55ab4 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3 125f12c [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3 ed498c8 [Davies Liu] fix compatibility with python 3 820e649 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3 e8ce8c9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3 ad7c374 [Davies Liu] fix mllib test and warning ef1fc2f [Davies Liu] fix tests 4eee14a [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3 20112ff [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3 59bb492 [Davies Liu] fix tests 1da268c [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3 ca0fdd3 [Davies Liu] fix code style 9563a15 [Davies Liu] add imap back for python 2 0b1ec04 [Davies Liu] make python examples work with Python 3 d2fd566 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3 a716d34 [Davies Liu] test with python 3.4 f1700e8 [Davies Liu] fix test in python3 671b1db [Davies Liu] fix test in python3 692ff47 [Davies Liu] fix flaky test 7b9699f [Davies Liu] invalidate import cache for Python 3.3+ 9c58497 [Davies Liu] fix kill worker 309bfbf [Davies Liu] keep compatibility 5707476 [Davies Liu] cleanup, fix hash of string in 3.3+ 8662d5b [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3 f53e1f0 [Davies Liu] fix tests 70b6b73 [Davies Liu] compile ec2/spark_ec2.py in python 3 a39167e [Davies Liu] support customize class in __main__ 814c77b [Davies Liu] run unittests with python 3 7f4476e [Davies Liu] mllib tests passed d737924 [Davies Liu] pass ml tests 375ea17 [Davies Liu] SQL tests pass 6cc42a9 [Davies Liu] rename 431a8de [Davies Liu] streaming tests pass 78901a7 [Davies Liu] fix hash of serializer in Python 3 24b2f2e [Davies Liu] pass all RDD tests 35f48fe [Davies Liu] run future again 1eebac2 [Davies Liu] fix conflict in ec2/spark_ec2.py 6e3c21d [Davies Liu] make cloudpickle work with Python3 2fb2db3 [Josh Rosen] Guard more changes behind sys.version; still doesn't run 1aa5e8f [twneale] Turned out `pickle.DictionaryType is dict` == True, so swapped it out 7354371 [twneale] buffer --> memoryview I'm not super sure if this a valid change, but the 2.7 docs recommend using memoryview over buffer where possible, so hoping it'll work. b69ccdf [twneale] Uses the pure python pickle._Pickler instead of c-extension _pickle.Pickler. It appears pyspark 2.7 uses the pure python pickler as well, so this shouldn't degrade pickling performance (?). f40d925 [twneale] xrange --> range e104215 [twneale] Replaces 2.7 types.InstsanceType with 3.4 `object`....could be horribly wrong depending on how types.InstanceType is used elsewhere in the package--see http://bugs.python.org/issue8206 79de9d0 [twneale] Replaces python2.7 `file` with 3.4 _io.TextIOWrapper 2adb42d [Josh Rosen] Fix up some import differences between Python 2 and 3 854be27 [Josh Rosen] Run `futurize` on Python code: 7c5b4ce [Josh Rosen] Remove Python 3 check in shell.py.
This commit is contained in:
parent
55f553a979
commit
04e44b37cc
|
@ -89,6 +89,7 @@ export PYTHONSTARTUP="$SPARK_HOME/python/pyspark/shell.py"
|
|||
if [[ -n "$SPARK_TESTING" ]]; then
|
||||
unset YARN_CONF_DIR
|
||||
unset HADOOP_CONF_DIR
|
||||
export PYTHONHASHSEED=0
|
||||
if [[ -n "$PYSPARK_DOC_TEST" ]]; then
|
||||
exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1
|
||||
else
|
||||
|
|
|
@ -19,6 +19,9 @@
|
|||
|
||||
SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)"
|
||||
|
||||
# disable randomized hash for string in Python 3.3+
|
||||
export PYTHONHASHSEED=0
|
||||
|
||||
# Only define a usage function if an upstream script hasn't done so.
|
||||
if ! type -t usage >/dev/null 2>&1; then
|
||||
usage() {
|
||||
|
|
|
@ -20,6 +20,9 @@ rem
|
|||
rem This is the entry point for running Spark submit. To avoid polluting the
|
||||
rem environment, it just launches a new cmd to do the real work.
|
||||
|
||||
rem disable randomized hash for string in Python 3.3+
|
||||
set PYTHONHASHSEED=0
|
||||
|
||||
set CLASS=org.apache.spark.deploy.SparkSubmit
|
||||
call %~dp0spark-class2.cmd %CLASS% %*
|
||||
set SPARK_ERROR_LEVEL=%ERRORLEVEL%
|
||||
|
|
|
@ -235,6 +235,8 @@ echo "========================================================================="
|
|||
|
||||
CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS
|
||||
|
||||
# add path for python 3 in jenkins
|
||||
export PATH="${PATH}:/home/anaonda/envs/py3k/bin"
|
||||
./python/run-tests
|
||||
|
||||
echo ""
|
||||
|
|
|
@ -47,7 +47,7 @@ COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}"
|
|||
# GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :(
|
||||
SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}"
|
||||
|
||||
TESTS_TIMEOUT="120m" # format: http://linux.die.net/man/1/timeout
|
||||
TESTS_TIMEOUT="150m" # format: http://linux.die.net/man/1/timeout
|
||||
|
||||
# Array to capture all tests to run on the pull request. These tests are held under the
|
||||
#+ dev/tests/ directory.
|
||||
|
|
262
ec2/spark_ec2.py
262
ec2/spark_ec2.py
|
@ -19,7 +19,7 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
from __future__ import with_statement
|
||||
from __future__ import with_statement, print_function
|
||||
|
||||
import hashlib
|
||||
import itertools
|
||||
|
@ -37,12 +37,17 @@ import tarfile
|
|||
import tempfile
|
||||
import textwrap
|
||||
import time
|
||||
import urllib2
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from optparse import OptionParser
|
||||
from sys import stderr
|
||||
|
||||
if sys.version < "3":
|
||||
from urllib2 import urlopen, Request, HTTPError
|
||||
else:
|
||||
from urllib.request import urlopen, Request
|
||||
from urllib.error import HTTPError
|
||||
|
||||
SPARK_EC2_VERSION = "1.2.1"
|
||||
SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
|
@ -88,10 +93,10 @@ def setup_external_libs(libs):
|
|||
SPARK_EC2_LIB_DIR = os.path.join(SPARK_EC2_DIR, "lib")
|
||||
|
||||
if not os.path.exists(SPARK_EC2_LIB_DIR):
|
||||
print "Downloading external libraries that spark-ec2 needs from PyPI to {path}...".format(
|
||||
print("Downloading external libraries that spark-ec2 needs from PyPI to {path}...".format(
|
||||
path=SPARK_EC2_LIB_DIR
|
||||
)
|
||||
print "This should be a one-time operation."
|
||||
))
|
||||
print("This should be a one-time operation.")
|
||||
os.mkdir(SPARK_EC2_LIB_DIR)
|
||||
|
||||
for lib in libs:
|
||||
|
@ -100,8 +105,8 @@ def setup_external_libs(libs):
|
|||
|
||||
if not os.path.isdir(lib_dir):
|
||||
tgz_file_path = os.path.join(SPARK_EC2_LIB_DIR, versioned_lib_name + ".tar.gz")
|
||||
print " - Downloading {lib}...".format(lib=lib["name"])
|
||||
download_stream = urllib2.urlopen(
|
||||
print(" - Downloading {lib}...".format(lib=lib["name"]))
|
||||
download_stream = urlopen(
|
||||
"{prefix}/{first_letter}/{lib_name}/{lib_name}-{lib_version}.tar.gz".format(
|
||||
prefix=PYPI_URL_PREFIX,
|
||||
first_letter=lib["name"][:1],
|
||||
|
@ -113,13 +118,13 @@ def setup_external_libs(libs):
|
|||
tgz_file.write(download_stream.read())
|
||||
with open(tgz_file_path) as tar:
|
||||
if hashlib.md5(tar.read()).hexdigest() != lib["md5"]:
|
||||
print >> stderr, "ERROR: Got wrong md5sum for {lib}.".format(lib=lib["name"])
|
||||
print("ERROR: Got wrong md5sum for {lib}.".format(lib=lib["name"]), file=stderr)
|
||||
sys.exit(1)
|
||||
tar = tarfile.open(tgz_file_path)
|
||||
tar.extractall(path=SPARK_EC2_LIB_DIR)
|
||||
tar.close()
|
||||
os.remove(tgz_file_path)
|
||||
print " - Finished downloading {lib}.".format(lib=lib["name"])
|
||||
print(" - Finished downloading {lib}.".format(lib=lib["name"]))
|
||||
sys.path.insert(1, lib_dir)
|
||||
|
||||
|
||||
|
@ -299,12 +304,12 @@ def parse_args():
|
|||
if home_dir is None or not os.path.isfile(home_dir + '/.boto'):
|
||||
if not os.path.isfile('/etc/boto.cfg'):
|
||||
if os.getenv('AWS_ACCESS_KEY_ID') is None:
|
||||
print >> stderr, ("ERROR: The environment variable AWS_ACCESS_KEY_ID " +
|
||||
"must be set")
|
||||
print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set",
|
||||
file=stderr)
|
||||
sys.exit(1)
|
||||
if os.getenv('AWS_SECRET_ACCESS_KEY') is None:
|
||||
print >> stderr, ("ERROR: The environment variable AWS_SECRET_ACCESS_KEY " +
|
||||
"must be set")
|
||||
print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set",
|
||||
file=stderr)
|
||||
sys.exit(1)
|
||||
return (opts, action, cluster_name)
|
||||
|
||||
|
@ -316,7 +321,7 @@ def get_or_make_group(conn, name, vpc_id):
|
|||
if len(group) > 0:
|
||||
return group[0]
|
||||
else:
|
||||
print "Creating security group " + name
|
||||
print("Creating security group " + name)
|
||||
return conn.create_security_group(name, "Spark EC2 group", vpc_id)
|
||||
|
||||
|
||||
|
@ -324,18 +329,19 @@ def get_validate_spark_version(version, repo):
|
|||
if "." in version:
|
||||
version = version.replace("v", "")
|
||||
if version not in VALID_SPARK_VERSIONS:
|
||||
print >> stderr, "Don't know about Spark version: {v}".format(v=version)
|
||||
print("Don't know about Spark version: {v}".format(v=version), file=stderr)
|
||||
sys.exit(1)
|
||||
return version
|
||||
else:
|
||||
github_commit_url = "{repo}/commit/{commit_hash}".format(repo=repo, commit_hash=version)
|
||||
request = urllib2.Request(github_commit_url)
|
||||
request = Request(github_commit_url)
|
||||
request.get_method = lambda: 'HEAD'
|
||||
try:
|
||||
response = urllib2.urlopen(request)
|
||||
except urllib2.HTTPError, e:
|
||||
print >> stderr, "Couldn't validate Spark commit: {url}".format(url=github_commit_url)
|
||||
print >> stderr, "Received HTTP response code of {code}.".format(code=e.code)
|
||||
response = urlopen(request)
|
||||
except HTTPError as e:
|
||||
print("Couldn't validate Spark commit: {url}".format(url=github_commit_url),
|
||||
file=stderr)
|
||||
print("Received HTTP response code of {code}.".format(code=e.code), file=stderr)
|
||||
sys.exit(1)
|
||||
return version
|
||||
|
||||
|
@ -394,8 +400,7 @@ def get_spark_ami(opts):
|
|||
instance_type = EC2_INSTANCE_TYPES[opts.instance_type]
|
||||
else:
|
||||
instance_type = "pvm"
|
||||
print >> stderr,\
|
||||
"Don't recognize %s, assuming type is pvm" % opts.instance_type
|
||||
print("Don't recognize %s, assuming type is pvm" % opts.instance_type, file=stderr)
|
||||
|
||||
# URL prefix from which to fetch AMI information
|
||||
ami_prefix = "{r}/{b}/ami-list".format(
|
||||
|
@ -404,10 +409,10 @@ def get_spark_ami(opts):
|
|||
|
||||
ami_path = "%s/%s/%s" % (ami_prefix, opts.region, instance_type)
|
||||
try:
|
||||
ami = urllib2.urlopen(ami_path).read().strip()
|
||||
print "Spark AMI: " + ami
|
||||
ami = urlopen(ami_path).read().strip()
|
||||
print("Spark AMI: " + ami)
|
||||
except:
|
||||
print >> stderr, "Could not resolve AMI at: " + ami_path
|
||||
print("Could not resolve AMI at: " + ami_path, file=stderr)
|
||||
sys.exit(1)
|
||||
|
||||
return ami
|
||||
|
@ -419,11 +424,11 @@ def get_spark_ami(opts):
|
|||
# Fails if there already instances running in the cluster's groups.
|
||||
def launch_cluster(conn, opts, cluster_name):
|
||||
if opts.identity_file is None:
|
||||
print >> stderr, "ERROR: Must provide an identity file (-i) for ssh connections."
|
||||
print("ERROR: Must provide an identity file (-i) for ssh connections.", file=stderr)
|
||||
sys.exit(1)
|
||||
|
||||
if opts.key_pair is None:
|
||||
print >> stderr, "ERROR: Must provide a key pair name (-k) to use on instances."
|
||||
print("ERROR: Must provide a key pair name (-k) to use on instances.", file=stderr)
|
||||
sys.exit(1)
|
||||
|
||||
user_data_content = None
|
||||
|
@ -431,7 +436,7 @@ def launch_cluster(conn, opts, cluster_name):
|
|||
with open(opts.user_data) as user_data_file:
|
||||
user_data_content = user_data_file.read()
|
||||
|
||||
print "Setting up security groups..."
|
||||
print("Setting up security groups...")
|
||||
master_group = get_or_make_group(conn, cluster_name + "-master", opts.vpc_id)
|
||||
slave_group = get_or_make_group(conn, cluster_name + "-slaves", opts.vpc_id)
|
||||
authorized_address = opts.authorized_address
|
||||
|
@ -497,8 +502,8 @@ def launch_cluster(conn, opts, cluster_name):
|
|||
existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name,
|
||||
die_on_error=False)
|
||||
if existing_slaves or (existing_masters and not opts.use_existing_master):
|
||||
print >> stderr, ("ERROR: There are already instances running in " +
|
||||
"group %s or %s" % (master_group.name, slave_group.name))
|
||||
print("ERROR: There are already instances running in group %s or %s" %
|
||||
(master_group.name, slave_group.name), file=stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Figure out Spark AMI
|
||||
|
@ -511,12 +516,12 @@ def launch_cluster(conn, opts, cluster_name):
|
|||
additional_group_ids = [sg.id
|
||||
for sg in conn.get_all_security_groups()
|
||||
if opts.additional_security_group in (sg.name, sg.id)]
|
||||
print "Launching instances..."
|
||||
print("Launching instances...")
|
||||
|
||||
try:
|
||||
image = conn.get_all_images(image_ids=[opts.ami])[0]
|
||||
except:
|
||||
print >> stderr, "Could not find AMI " + opts.ami
|
||||
print("Could not find AMI " + opts.ami, file=stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Create block device mapping so that we can add EBS volumes if asked to.
|
||||
|
@ -542,8 +547,8 @@ def launch_cluster(conn, opts, cluster_name):
|
|||
# Launch slaves
|
||||
if opts.spot_price is not None:
|
||||
# Launch spot instances with the requested price
|
||||
print ("Requesting %d slaves as spot instances with price $%.3f" %
|
||||
(opts.slaves, opts.spot_price))
|
||||
print("Requesting %d slaves as spot instances with price $%.3f" %
|
||||
(opts.slaves, opts.spot_price))
|
||||
zones = get_zones(conn, opts)
|
||||
num_zones = len(zones)
|
||||
i = 0
|
||||
|
@ -566,7 +571,7 @@ def launch_cluster(conn, opts, cluster_name):
|
|||
my_req_ids += [req.id for req in slave_reqs]
|
||||
i += 1
|
||||
|
||||
print "Waiting for spot instances to be granted..."
|
||||
print("Waiting for spot instances to be granted...")
|
||||
try:
|
||||
while True:
|
||||
time.sleep(10)
|
||||
|
@ -579,24 +584,24 @@ def launch_cluster(conn, opts, cluster_name):
|
|||
if i in id_to_req and id_to_req[i].state == "active":
|
||||
active_instance_ids.append(id_to_req[i].instance_id)
|
||||
if len(active_instance_ids) == opts.slaves:
|
||||
print "All %d slaves granted" % opts.slaves
|
||||
print("All %d slaves granted" % opts.slaves)
|
||||
reservations = conn.get_all_reservations(active_instance_ids)
|
||||
slave_nodes = []
|
||||
for r in reservations:
|
||||
slave_nodes += r.instances
|
||||
break
|
||||
else:
|
||||
print "%d of %d slaves granted, waiting longer" % (
|
||||
len(active_instance_ids), opts.slaves)
|
||||
print("%d of %d slaves granted, waiting longer" % (
|
||||
len(active_instance_ids), opts.slaves))
|
||||
except:
|
||||
print "Canceling spot instance requests"
|
||||
print("Canceling spot instance requests")
|
||||
conn.cancel_spot_instance_requests(my_req_ids)
|
||||
# Log a warning if any of these requests actually launched instances:
|
||||
(master_nodes, slave_nodes) = get_existing_cluster(
|
||||
conn, opts, cluster_name, die_on_error=False)
|
||||
running = len(master_nodes) + len(slave_nodes)
|
||||
if running:
|
||||
print >> stderr, ("WARNING: %d instances are still running" % running)
|
||||
print(("WARNING: %d instances are still running" % running), file=stderr)
|
||||
sys.exit(0)
|
||||
else:
|
||||
# Launch non-spot instances
|
||||
|
@ -618,16 +623,16 @@ def launch_cluster(conn, opts, cluster_name):
|
|||
placement_group=opts.placement_group,
|
||||
user_data=user_data_content)
|
||||
slave_nodes += slave_res.instances
|
||||
print "Launched {s} slave{plural_s} in {z}, regid = {r}".format(
|
||||
s=num_slaves_this_zone,
|
||||
plural_s=('' if num_slaves_this_zone == 1 else 's'),
|
||||
z=zone,
|
||||
r=slave_res.id)
|
||||
print("Launched {s} slave{plural_s} in {z}, regid = {r}".format(
|
||||
s=num_slaves_this_zone,
|
||||
plural_s=('' if num_slaves_this_zone == 1 else 's'),
|
||||
z=zone,
|
||||
r=slave_res.id))
|
||||
i += 1
|
||||
|
||||
# Launch or resume masters
|
||||
if existing_masters:
|
||||
print "Starting master..."
|
||||
print("Starting master...")
|
||||
for inst in existing_masters:
|
||||
if inst.state not in ["shutting-down", "terminated"]:
|
||||
inst.start()
|
||||
|
@ -650,10 +655,10 @@ def launch_cluster(conn, opts, cluster_name):
|
|||
user_data=user_data_content)
|
||||
|
||||
master_nodes = master_res.instances
|
||||
print "Launched master in %s, regid = %s" % (zone, master_res.id)
|
||||
print("Launched master in %s, regid = %s" % (zone, master_res.id))
|
||||
|
||||
# This wait time corresponds to SPARK-4983
|
||||
print "Waiting for AWS to propagate instance metadata..."
|
||||
print("Waiting for AWS to propagate instance metadata...")
|
||||
time.sleep(5)
|
||||
# Give the instances descriptive names
|
||||
for master in master_nodes:
|
||||
|
@ -674,8 +679,8 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
|
|||
Get the EC2 instances in an existing cluster if available.
|
||||
Returns a tuple of lists of EC2 instance objects for the masters and slaves.
|
||||
"""
|
||||
print "Searching for existing cluster {c} in region {r}...".format(
|
||||
c=cluster_name, r=opts.region)
|
||||
print("Searching for existing cluster {c} in region {r}...".format(
|
||||
c=cluster_name, r=opts.region))
|
||||
|
||||
def get_instances(group_names):
|
||||
"""
|
||||
|
@ -693,16 +698,15 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
|
|||
slave_instances = get_instances([cluster_name + "-slaves"])
|
||||
|
||||
if any((master_instances, slave_instances)):
|
||||
print "Found {m} master{plural_m}, {s} slave{plural_s}.".format(
|
||||
m=len(master_instances),
|
||||
plural_m=('' if len(master_instances) == 1 else 's'),
|
||||
s=len(slave_instances),
|
||||
plural_s=('' if len(slave_instances) == 1 else 's'))
|
||||
print("Found {m} master{plural_m}, {s} slave{plural_s}.".format(
|
||||
m=len(master_instances),
|
||||
plural_m=('' if len(master_instances) == 1 else 's'),
|
||||
s=len(slave_instances),
|
||||
plural_s=('' if len(slave_instances) == 1 else 's')))
|
||||
|
||||
if not master_instances and die_on_error:
|
||||
print >> sys.stderr, \
|
||||
"ERROR: Could not find a master for cluster {c} in region {r}.".format(
|
||||
c=cluster_name, r=opts.region)
|
||||
print("ERROR: Could not find a master for cluster {c} in region {r}.".format(
|
||||
c=cluster_name, r=opts.region), file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
return (master_instances, slave_instances)
|
||||
|
@ -713,7 +717,7 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
|
|||
def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
|
||||
master = get_dns_name(master_nodes[0], opts.private_ips)
|
||||
if deploy_ssh_key:
|
||||
print "Generating cluster's SSH key on master..."
|
||||
print("Generating cluster's SSH key on master...")
|
||||
key_setup = """
|
||||
[ -f ~/.ssh/id_rsa ] ||
|
||||
(ssh-keygen -q -t rsa -N '' -f ~/.ssh/id_rsa &&
|
||||
|
@ -721,10 +725,10 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
|
|||
"""
|
||||
ssh(master, opts, key_setup)
|
||||
dot_ssh_tar = ssh_read(master, opts, ['tar', 'c', '.ssh'])
|
||||
print "Transferring cluster's SSH key to slaves..."
|
||||
print("Transferring cluster's SSH key to slaves...")
|
||||
for slave in slave_nodes:
|
||||
slave_address = get_dns_name(slave, opts.private_ips)
|
||||
print slave_address
|
||||
print(slave_address)
|
||||
ssh_write(slave_address, opts, ['tar', 'x'], dot_ssh_tar)
|
||||
|
||||
modules = ['spark', 'ephemeral-hdfs', 'persistent-hdfs',
|
||||
|
@ -738,8 +742,8 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
|
|||
|
||||
# NOTE: We should clone the repository before running deploy_files to
|
||||
# prevent ec2-variables.sh from being overwritten
|
||||
print "Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format(
|
||||
r=opts.spark_ec2_git_repo, b=opts.spark_ec2_git_branch)
|
||||
print("Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format(
|
||||
r=opts.spark_ec2_git_repo, b=opts.spark_ec2_git_branch))
|
||||
ssh(
|
||||
host=master,
|
||||
opts=opts,
|
||||
|
@ -749,7 +753,7 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
|
|||
b=opts.spark_ec2_git_branch)
|
||||
)
|
||||
|
||||
print "Deploying files to master..."
|
||||
print("Deploying files to master...")
|
||||
deploy_files(
|
||||
conn=conn,
|
||||
root_dir=SPARK_EC2_DIR + "/" + "deploy.generic",
|
||||
|
@ -760,25 +764,25 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
|
|||
)
|
||||
|
||||
if opts.deploy_root_dir is not None:
|
||||
print "Deploying {s} to master...".format(s=opts.deploy_root_dir)
|
||||
print("Deploying {s} to master...".format(s=opts.deploy_root_dir))
|
||||
deploy_user_files(
|
||||
root_dir=opts.deploy_root_dir,
|
||||
opts=opts,
|
||||
master_nodes=master_nodes
|
||||
)
|
||||
|
||||
print "Running setup on master..."
|
||||
print("Running setup on master...")
|
||||
setup_spark_cluster(master, opts)
|
||||
print "Done!"
|
||||
print("Done!")
|
||||
|
||||
|
||||
def setup_spark_cluster(master, opts):
|
||||
ssh(master, opts, "chmod u+x spark-ec2/setup.sh")
|
||||
ssh(master, opts, "spark-ec2/setup.sh")
|
||||
print "Spark standalone cluster started at http://%s:8080" % master
|
||||
print("Spark standalone cluster started at http://%s:8080" % master)
|
||||
|
||||
if opts.ganglia:
|
||||
print "Ganglia started at http://%s:5080/ganglia" % master
|
||||
print("Ganglia started at http://%s:5080/ganglia" % master)
|
||||
|
||||
|
||||
def is_ssh_available(host, opts, print_ssh_output=True):
|
||||
|
@ -795,7 +799,7 @@ def is_ssh_available(host, opts, print_ssh_output=True):
|
|||
|
||||
if s.returncode != 0 and print_ssh_output:
|
||||
# extra leading newline is for spacing in wait_for_cluster_state()
|
||||
print textwrap.dedent("""\n
|
||||
print(textwrap.dedent("""\n
|
||||
Warning: SSH connection error. (This could be temporary.)
|
||||
Host: {h}
|
||||
SSH return code: {r}
|
||||
|
@ -804,7 +808,7 @@ def is_ssh_available(host, opts, print_ssh_output=True):
|
|||
h=host,
|
||||
r=s.returncode,
|
||||
o=cmd_output.strip()
|
||||
)
|
||||
))
|
||||
|
||||
return s.returncode == 0
|
||||
|
||||
|
@ -865,10 +869,10 @@ def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state):
|
|||
sys.stdout.write("\n")
|
||||
|
||||
end_time = datetime.now()
|
||||
print "Cluster is now in '{s}' state. Waited {t} seconds.".format(
|
||||
print("Cluster is now in '{s}' state. Waited {t} seconds.".format(
|
||||
s=cluster_state,
|
||||
t=(end_time - start_time).seconds
|
||||
)
|
||||
))
|
||||
|
||||
|
||||
# Get number of local disks available for a given EC2 instance type.
|
||||
|
@ -916,8 +920,8 @@ def get_num_disks(instance_type):
|
|||
if instance_type in disks_by_instance:
|
||||
return disks_by_instance[instance_type]
|
||||
else:
|
||||
print >> stderr, ("WARNING: Don't know number of disks on instance type %s; assuming 1"
|
||||
% instance_type)
|
||||
print("WARNING: Don't know number of disks on instance type %s; assuming 1"
|
||||
% instance_type, file=stderr)
|
||||
return 1
|
||||
|
||||
|
||||
|
@ -951,7 +955,7 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules):
|
|||
# Spark-only custom deploy
|
||||
spark_v = "%s|%s" % (opts.spark_git_repo, opts.spark_version)
|
||||
tachyon_v = ""
|
||||
print "Deploying Spark via git hash; Tachyon won't be set up"
|
||||
print("Deploying Spark via git hash; Tachyon won't be set up")
|
||||
modules = filter(lambda x: x != "tachyon", modules)
|
||||
|
||||
master_addresses = [get_dns_name(i, opts.private_ips) for i in master_nodes]
|
||||
|
@ -1067,8 +1071,8 @@ def ssh(host, opts, command):
|
|||
"--key-pair parameters and try again.".format(host))
|
||||
else:
|
||||
raise e
|
||||
print >> stderr, \
|
||||
"Error executing remote command, retrying after 30 seconds: {0}".format(e)
|
||||
print("Error executing remote command, retrying after 30 seconds: {0}".format(e),
|
||||
file=stderr)
|
||||
time.sleep(30)
|
||||
tries = tries + 1
|
||||
|
||||
|
@ -1107,8 +1111,8 @@ def ssh_write(host, opts, command, arguments):
|
|||
elif tries > 5:
|
||||
raise RuntimeError("ssh_write failed with error %s" % proc.returncode)
|
||||
else:
|
||||
print >> stderr, \
|
||||
"Error {0} while executing remote command, retrying after 30 seconds".format(status)
|
||||
print("Error {0} while executing remote command, retrying after 30 seconds".
|
||||
format(status), file=stderr)
|
||||
time.sleep(30)
|
||||
tries = tries + 1
|
||||
|
||||
|
@ -1162,42 +1166,41 @@ def real_main():
|
|||
|
||||
if opts.identity_file is not None:
|
||||
if not os.path.exists(opts.identity_file):
|
||||
print >> stderr,\
|
||||
"ERROR: The identity file '{f}' doesn't exist.".format(f=opts.identity_file)
|
||||
print("ERROR: The identity file '{f}' doesn't exist.".format(f=opts.identity_file),
|
||||
file=stderr)
|
||||
sys.exit(1)
|
||||
|
||||
file_mode = os.stat(opts.identity_file).st_mode
|
||||
if not (file_mode & S_IRUSR) or not oct(file_mode)[-2:] == '00':
|
||||
print >> stderr, "ERROR: The identity file must be accessible only by you."
|
||||
print >> stderr, 'You can fix this with: chmod 400 "{f}"'.format(f=opts.identity_file)
|
||||
print("ERROR: The identity file must be accessible only by you.", file=stderr)
|
||||
print('You can fix this with: chmod 400 "{f}"'.format(f=opts.identity_file),
|
||||
file=stderr)
|
||||
sys.exit(1)
|
||||
|
||||
if opts.instance_type not in EC2_INSTANCE_TYPES:
|
||||
print >> stderr, "Warning: Unrecognized EC2 instance type for instance-type: {t}".format(
|
||||
t=opts.instance_type)
|
||||
print("Warning: Unrecognized EC2 instance type for instance-type: {t}".format(
|
||||
t=opts.instance_type), file=stderr)
|
||||
|
||||
if opts.master_instance_type != "":
|
||||
if opts.master_instance_type not in EC2_INSTANCE_TYPES:
|
||||
print >> stderr, \
|
||||
"Warning: Unrecognized EC2 instance type for master-instance-type: {t}".format(
|
||||
t=opts.master_instance_type)
|
||||
print("Warning: Unrecognized EC2 instance type for master-instance-type: {t}".format(
|
||||
t=opts.master_instance_type), file=stderr)
|
||||
# Since we try instance types even if we can't resolve them, we check if they resolve first
|
||||
# and, if they do, see if they resolve to the same virtualization type.
|
||||
if opts.instance_type in EC2_INSTANCE_TYPES and \
|
||||
opts.master_instance_type in EC2_INSTANCE_TYPES:
|
||||
if EC2_INSTANCE_TYPES[opts.instance_type] != \
|
||||
EC2_INSTANCE_TYPES[opts.master_instance_type]:
|
||||
print >> stderr, \
|
||||
"Error: spark-ec2 currently does not support having a master and slaves " + \
|
||||
"with different AMI virtualization types."
|
||||
print >> stderr, "master instance virtualization type: {t}".format(
|
||||
t=EC2_INSTANCE_TYPES[opts.master_instance_type])
|
||||
print >> stderr, "slave instance virtualization type: {t}".format(
|
||||
t=EC2_INSTANCE_TYPES[opts.instance_type])
|
||||
print("Error: spark-ec2 currently does not support having a master and slaves "
|
||||
"with different AMI virtualization types.", file=stderr)
|
||||
print("master instance virtualization type: {t}".format(
|
||||
t=EC2_INSTANCE_TYPES[opts.master_instance_type]), file=stderr)
|
||||
print("slave instance virtualization type: {t}".format(
|
||||
t=EC2_INSTANCE_TYPES[opts.instance_type]), file=stderr)
|
||||
sys.exit(1)
|
||||
|
||||
if opts.ebs_vol_num > 8:
|
||||
print >> stderr, "ebs-vol-num cannot be greater than 8"
|
||||
print("ebs-vol-num cannot be greater than 8", file=stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Prevent breaking ami_prefix (/, .git and startswith checks)
|
||||
|
@ -1206,23 +1209,22 @@ def real_main():
|
|||
opts.spark_ec2_git_repo.endswith(".git") or \
|
||||
not opts.spark_ec2_git_repo.startswith("https://github.com") or \
|
||||
not opts.spark_ec2_git_repo.endswith("spark-ec2"):
|
||||
print >> stderr, "spark-ec2-git-repo must be a github repo and it must not have a " \
|
||||
"trailing / or .git. " \
|
||||
"Furthermore, we currently only support forks named spark-ec2."
|
||||
print("spark-ec2-git-repo must be a github repo and it must not have a trailing / or .git. "
|
||||
"Furthermore, we currently only support forks named spark-ec2.", file=stderr)
|
||||
sys.exit(1)
|
||||
|
||||
if not (opts.deploy_root_dir is None or
|
||||
(os.path.isabs(opts.deploy_root_dir) and
|
||||
os.path.isdir(opts.deploy_root_dir) and
|
||||
os.path.exists(opts.deploy_root_dir))):
|
||||
print >> stderr, "--deploy-root-dir must be an absolute path to a directory that exists " \
|
||||
"on the local file system"
|
||||
print("--deploy-root-dir must be an absolute path to a directory that exists "
|
||||
"on the local file system", file=stderr)
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
conn = ec2.connect_to_region(opts.region)
|
||||
except Exception as e:
|
||||
print >> stderr, (e)
|
||||
print((e), file=stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Select an AZ at random if it was not specified.
|
||||
|
@ -1231,7 +1233,7 @@ def real_main():
|
|||
|
||||
if action == "launch":
|
||||
if opts.slaves <= 0:
|
||||
print >> sys.stderr, "ERROR: You have to start at least 1 slave"
|
||||
print("ERROR: You have to start at least 1 slave", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
if opts.resume:
|
||||
(master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
|
||||
|
@ -1250,18 +1252,18 @@ def real_main():
|
|||
conn, opts, cluster_name, die_on_error=False)
|
||||
|
||||
if any(master_nodes + slave_nodes):
|
||||
print "The following instances will be terminated:"
|
||||
print("The following instances will be terminated:")
|
||||
for inst in master_nodes + slave_nodes:
|
||||
print "> %s" % get_dns_name(inst, opts.private_ips)
|
||||
print "ALL DATA ON ALL NODES WILL BE LOST!!"
|
||||
print("> %s" % get_dns_name(inst, opts.private_ips))
|
||||
print("ALL DATA ON ALL NODES WILL BE LOST!!")
|
||||
|
||||
msg = "Are you sure you want to destroy the cluster {c}? (y/N) ".format(c=cluster_name)
|
||||
response = raw_input(msg)
|
||||
if response == "y":
|
||||
print "Terminating master..."
|
||||
print("Terminating master...")
|
||||
for inst in master_nodes:
|
||||
inst.terminate()
|
||||
print "Terminating slaves..."
|
||||
print("Terminating slaves...")
|
||||
for inst in slave_nodes:
|
||||
inst.terminate()
|
||||
|
||||
|
@ -1274,16 +1276,16 @@ def real_main():
|
|||
cluster_instances=(master_nodes + slave_nodes),
|
||||
cluster_state='terminated'
|
||||
)
|
||||
print "Deleting security groups (this will take some time)..."
|
||||
print("Deleting security groups (this will take some time)...")
|
||||
attempt = 1
|
||||
while attempt <= 3:
|
||||
print "Attempt %d" % attempt
|
||||
print("Attempt %d" % attempt)
|
||||
groups = [g for g in conn.get_all_security_groups() if g.name in group_names]
|
||||
success = True
|
||||
# Delete individual rules in all groups before deleting groups to
|
||||
# remove dependencies between them
|
||||
for group in groups:
|
||||
print "Deleting rules in security group " + group.name
|
||||
print("Deleting rules in security group " + group.name)
|
||||
for rule in group.rules:
|
||||
for grant in rule.grants:
|
||||
success &= group.revoke(ip_protocol=rule.ip_protocol,
|
||||
|
@ -1298,10 +1300,10 @@ def real_main():
|
|||
try:
|
||||
# It is needed to use group_id to make it work with VPC
|
||||
conn.delete_security_group(group_id=group.id)
|
||||
print "Deleted security group %s" % group.name
|
||||
print("Deleted security group %s" % group.name)
|
||||
except boto.exception.EC2ResponseError:
|
||||
success = False
|
||||
print "Failed to delete security group %s" % group.name
|
||||
print("Failed to delete security group %s" % group.name)
|
||||
|
||||
# Unfortunately, group.revoke() returns True even if a rule was not
|
||||
# deleted, so this needs to be rerun if something fails
|
||||
|
@ -1311,17 +1313,16 @@ def real_main():
|
|||
attempt += 1
|
||||
|
||||
if not success:
|
||||
print "Failed to delete all security groups after 3 tries."
|
||||
print "Try re-running in a few minutes."
|
||||
print("Failed to delete all security groups after 3 tries.")
|
||||
print("Try re-running in a few minutes.")
|
||||
|
||||
elif action == "login":
|
||||
(master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
|
||||
if not master_nodes[0].public_dns_name and not opts.private_ips:
|
||||
print "Master has no public DNS name. Maybe you meant to specify " \
|
||||
"--private-ips?"
|
||||
print("Master has no public DNS name. Maybe you meant to specify --private-ips?")
|
||||
else:
|
||||
master = get_dns_name(master_nodes[0], opts.private_ips)
|
||||
print "Logging into master " + master + "..."
|
||||
print("Logging into master " + master + "...")
|
||||
proxy_opt = []
|
||||
if opts.proxy_port is not None:
|
||||
proxy_opt = ['-D', opts.proxy_port]
|
||||
|
@ -1336,19 +1337,18 @@ def real_main():
|
|||
if response == "y":
|
||||
(master_nodes, slave_nodes) = get_existing_cluster(
|
||||
conn, opts, cluster_name, die_on_error=False)
|
||||
print "Rebooting slaves..."
|
||||
print("Rebooting slaves...")
|
||||
for inst in slave_nodes:
|
||||
if inst.state not in ["shutting-down", "terminated"]:
|
||||
print "Rebooting " + inst.id
|
||||
print("Rebooting " + inst.id)
|
||||
inst.reboot()
|
||||
|
||||
elif action == "get-master":
|
||||
(master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
|
||||
if not master_nodes[0].public_dns_name and not opts.private_ips:
|
||||
print "Master has no public DNS name. Maybe you meant to specify " \
|
||||
"--private-ips?"
|
||||
print("Master has no public DNS name. Maybe you meant to specify --private-ips?")
|
||||
else:
|
||||
print get_dns_name(master_nodes[0], opts.private_ips)
|
||||
print(get_dns_name(master_nodes[0], opts.private_ips))
|
||||
|
||||
elif action == "stop":
|
||||
response = raw_input(
|
||||
|
@ -1361,11 +1361,11 @@ def real_main():
|
|||
if response == "y":
|
||||
(master_nodes, slave_nodes) = get_existing_cluster(
|
||||
conn, opts, cluster_name, die_on_error=False)
|
||||
print "Stopping master..."
|
||||
print("Stopping master...")
|
||||
for inst in master_nodes:
|
||||
if inst.state not in ["shutting-down", "terminated"]:
|
||||
inst.stop()
|
||||
print "Stopping slaves..."
|
||||
print("Stopping slaves...")
|
||||
for inst in slave_nodes:
|
||||
if inst.state not in ["shutting-down", "terminated"]:
|
||||
if inst.spot_instance_request_id:
|
||||
|
@ -1375,11 +1375,11 @@ def real_main():
|
|||
|
||||
elif action == "start":
|
||||
(master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
|
||||
print "Starting slaves..."
|
||||
print("Starting slaves...")
|
||||
for inst in slave_nodes:
|
||||
if inst.state not in ["shutting-down", "terminated"]:
|
||||
inst.start()
|
||||
print "Starting master..."
|
||||
print("Starting master...")
|
||||
for inst in master_nodes:
|
||||
if inst.state not in ["shutting-down", "terminated"]:
|
||||
inst.start()
|
||||
|
@ -1403,15 +1403,15 @@ def real_main():
|
|||
setup_cluster(conn, master_nodes, slave_nodes, opts, False)
|
||||
|
||||
else:
|
||||
print >> stderr, "Invalid action: %s" % action
|
||||
print("Invalid action: %s" % action, file=stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main():
|
||||
try:
|
||||
real_main()
|
||||
except UsageError, e:
|
||||
print >> stderr, "\nError:\n", e
|
||||
except UsageError as e:
|
||||
print("\nError:\n", e, file=stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
|
|
|
@ -21,7 +21,8 @@ ALS in pyspark.mllib.recommendation for more conventional use.
|
|||
|
||||
This example requires numpy (http://www.numpy.org/)
|
||||
"""
|
||||
from os.path import realpath
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
|
@ -57,9 +58,9 @@ if __name__ == "__main__":
|
|||
Usage: als [M] [U] [F] [iterations] [partitions]"
|
||||
"""
|
||||
|
||||
print >> sys.stderr, """WARN: This is a naive implementation of ALS and is given as an
|
||||
print("""WARN: This is a naive implementation of ALS and is given as an
|
||||
example. Please use the ALS method found in pyspark.mllib.recommendation for more
|
||||
conventional use."""
|
||||
conventional use.""", file=sys.stderr)
|
||||
|
||||
sc = SparkContext(appName="PythonALS")
|
||||
M = int(sys.argv[1]) if len(sys.argv) > 1 else 100
|
||||
|
@ -68,8 +69,8 @@ if __name__ == "__main__":
|
|||
ITERATIONS = int(sys.argv[4]) if len(sys.argv) > 4 else 5
|
||||
partitions = int(sys.argv[5]) if len(sys.argv) > 5 else 2
|
||||
|
||||
print "Running ALS with M=%d, U=%d, F=%d, iters=%d, partitions=%d\n" % \
|
||||
(M, U, F, ITERATIONS, partitions)
|
||||
print("Running ALS with M=%d, U=%d, F=%d, iters=%d, partitions=%d\n" %
|
||||
(M, U, F, ITERATIONS, partitions))
|
||||
|
||||
R = matrix(rand(M, F)) * matrix(rand(U, F).T)
|
||||
ms = matrix(rand(M, F))
|
||||
|
@ -95,7 +96,7 @@ if __name__ == "__main__":
|
|||
usb = sc.broadcast(us)
|
||||
|
||||
error = rmse(R, ms, us)
|
||||
print "Iteration %d:" % i
|
||||
print "\nRMSE: %5.4f\n" % error
|
||||
print("Iteration %d:" % i)
|
||||
print("\nRMSE: %5.4f\n" % error)
|
||||
|
||||
sc.stop()
|
||||
|
|
|
@ -15,9 +15,12 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
from pyspark import SparkContext
|
||||
from functools import reduce
|
||||
|
||||
"""
|
||||
Read data file users.avro in local Spark distro:
|
||||
|
@ -49,7 +52,7 @@ $ ./bin/spark-submit --driver-class-path /path/to/example/jar \
|
|||
"""
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 2 and len(sys.argv) != 3:
|
||||
print >> sys.stderr, """
|
||||
print("""
|
||||
Usage: avro_inputformat <data_file> [reader_schema_file]
|
||||
|
||||
Run with example jar:
|
||||
|
@ -57,7 +60,7 @@ if __name__ == "__main__":
|
|||
/path/to/examples/avro_inputformat.py <data_file> [reader_schema_file]
|
||||
Assumes you have Avro data stored in <data_file>. Reader schema can be optionally specified
|
||||
in [reader_schema_file].
|
||||
"""
|
||||
""", file=sys.stderr)
|
||||
exit(-1)
|
||||
|
||||
path = sys.argv[1]
|
||||
|
@ -77,6 +80,6 @@ if __name__ == "__main__":
|
|||
conf=conf)
|
||||
output = avro_rdd.map(lambda x: x[0]).collect()
|
||||
for k in output:
|
||||
print k
|
||||
print(k)
|
||||
|
||||
sc.stop()
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
from pyspark import SparkContext
|
||||
|
@ -47,14 +49,14 @@ cqlsh:test> SELECT * FROM users;
|
|||
"""
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 4:
|
||||
print >> sys.stderr, """
|
||||
print("""
|
||||
Usage: cassandra_inputformat <host> <keyspace> <cf>
|
||||
|
||||
Run with example jar:
|
||||
./bin/spark-submit --driver-class-path /path/to/example/jar \
|
||||
/path/to/examples/cassandra_inputformat.py <host> <keyspace> <cf>
|
||||
Assumes you have some data in Cassandra already, running on <host>, in <keyspace> and <cf>
|
||||
"""
|
||||
""", file=sys.stderr)
|
||||
exit(-1)
|
||||
|
||||
host = sys.argv[1]
|
||||
|
@ -77,6 +79,6 @@ if __name__ == "__main__":
|
|||
conf=conf)
|
||||
output = cass_rdd.collect()
|
||||
for (k, v) in output:
|
||||
print (k, v)
|
||||
print((k, v))
|
||||
|
||||
sc.stop()
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
from pyspark import SparkContext
|
||||
|
@ -46,7 +48,7 @@ cqlsh:test> SELECT * FROM users;
|
|||
"""
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 7:
|
||||
print >> sys.stderr, """
|
||||
print("""
|
||||
Usage: cassandra_outputformat <host> <keyspace> <cf> <user_id> <fname> <lname>
|
||||
|
||||
Run with example jar:
|
||||
|
@ -60,7 +62,7 @@ if __name__ == "__main__":
|
|||
... fname text,
|
||||
... lname text
|
||||
... );
|
||||
"""
|
||||
""", file=sys.stderr)
|
||||
exit(-1)
|
||||
|
||||
host = sys.argv[1]
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
from pyspark import SparkContext
|
||||
|
@ -47,14 +49,14 @@ ROW COLUMN+CELL
|
|||
"""
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 3:
|
||||
print >> sys.stderr, """
|
||||
print("""
|
||||
Usage: hbase_inputformat <host> <table>
|
||||
|
||||
Run with example jar:
|
||||
./bin/spark-submit --driver-class-path /path/to/example/jar \
|
||||
/path/to/examples/hbase_inputformat.py <host> <table>
|
||||
Assumes you have some data in HBase already, running on <host>, in <table>
|
||||
"""
|
||||
""", file=sys.stderr)
|
||||
exit(-1)
|
||||
|
||||
host = sys.argv[1]
|
||||
|
@ -74,6 +76,6 @@ if __name__ == "__main__":
|
|||
conf=conf)
|
||||
output = hbase_rdd.collect()
|
||||
for (k, v) in output:
|
||||
print (k, v)
|
||||
print((k, v))
|
||||
|
||||
sc.stop()
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
from pyspark import SparkContext
|
||||
|
@ -40,7 +42,7 @@ ROW COLUMN+CELL
|
|||
"""
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 7:
|
||||
print >> sys.stderr, """
|
||||
print("""
|
||||
Usage: hbase_outputformat <host> <table> <row> <family> <qualifier> <value>
|
||||
|
||||
Run with example jar:
|
||||
|
@ -48,7 +50,7 @@ if __name__ == "__main__":
|
|||
/path/to/examples/hbase_outputformat.py <args>
|
||||
Assumes you have created <table> with column family <family> in HBase
|
||||
running on <host> already
|
||||
"""
|
||||
""", file=sys.stderr)
|
||||
exit(-1)
|
||||
|
||||
host = sys.argv[1]
|
||||
|
|
|
@ -22,6 +22,7 @@ examples/src/main/python/mllib/kmeans.py.
|
|||
|
||||
This example requires NumPy (http://www.numpy.org/).
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
|
@ -47,12 +48,12 @@ def closestPoint(p, centers):
|
|||
if __name__ == "__main__":
|
||||
|
||||
if len(sys.argv) != 4:
|
||||
print >> sys.stderr, "Usage: kmeans <file> <k> <convergeDist>"
|
||||
print("Usage: kmeans <file> <k> <convergeDist>", file=sys.stderr)
|
||||
exit(-1)
|
||||
|
||||
print >> sys.stderr, """WARN: This is a naive implementation of KMeans Clustering and is given
|
||||
print("""WARN: This is a naive implementation of KMeans Clustering and is given
|
||||
as an example! Please refer to examples/src/main/python/mllib/kmeans.py for an example on
|
||||
how to use MLlib's KMeans implementation."""
|
||||
how to use MLlib's KMeans implementation.""", file=sys.stderr)
|
||||
|
||||
sc = SparkContext(appName="PythonKMeans")
|
||||
lines = sc.textFile(sys.argv[1])
|
||||
|
@ -69,13 +70,13 @@ if __name__ == "__main__":
|
|||
pointStats = closest.reduceByKey(
|
||||
lambda (x1, y1), (x2, y2): (x1 + x2, y1 + y2))
|
||||
newPoints = pointStats.map(
|
||||
lambda (x, (y, z)): (x, y / z)).collect()
|
||||
lambda xy: (xy[0], xy[1][0] / xy[1][1])).collect()
|
||||
|
||||
tempDist = sum(np.sum((kPoints[x] - y) ** 2) for (x, y) in newPoints)
|
||||
|
||||
for (x, y) in newPoints:
|
||||
kPoints[x] = y
|
||||
|
||||
print "Final centers: " + str(kPoints)
|
||||
print("Final centers: " + str(kPoints))
|
||||
|
||||
sc.stop()
|
||||
|
|
|
@ -22,10 +22,8 @@ to act on batches of input data using efficient matrix operations.
|
|||
In practice, one may prefer to use the LogisticRegression algorithm in
|
||||
MLlib, as shown in examples/src/main/python/mllib/logistic_regression.py.
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import namedtuple
|
||||
from math import exp
|
||||
from os.path import realpath
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
|
@ -42,19 +40,19 @@ D = 10 # Number of dimensions
|
|||
def readPointBatch(iterator):
|
||||
strs = list(iterator)
|
||||
matrix = np.zeros((len(strs), D + 1))
|
||||
for i in xrange(len(strs)):
|
||||
matrix[i] = np.fromstring(strs[i].replace(',', ' '), dtype=np.float32, sep=' ')
|
||||
for i, s in enumerate(strs):
|
||||
matrix[i] = np.fromstring(s.replace(',', ' '), dtype=np.float32, sep=' ')
|
||||
return [matrix]
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
if len(sys.argv) != 3:
|
||||
print >> sys.stderr, "Usage: logistic_regression <file> <iterations>"
|
||||
print("Usage: logistic_regression <file> <iterations>", file=sys.stderr)
|
||||
exit(-1)
|
||||
|
||||
print >> sys.stderr, """WARN: This is a naive implementation of Logistic Regression and is
|
||||
print("""WARN: This is a naive implementation of Logistic Regression and is
|
||||
given as an example! Please refer to examples/src/main/python/mllib/logistic_regression.py
|
||||
to see how MLlib's implementation is used."""
|
||||
to see how MLlib's implementation is used.""", file=sys.stderr)
|
||||
|
||||
sc = SparkContext(appName="PythonLR")
|
||||
points = sc.textFile(sys.argv[1]).mapPartitions(readPointBatch).cache()
|
||||
|
@ -62,7 +60,7 @@ if __name__ == "__main__":
|
|||
|
||||
# Initialize w to a random value
|
||||
w = 2 * np.random.ranf(size=D) - 1
|
||||
print "Initial w: " + str(w)
|
||||
print("Initial w: " + str(w))
|
||||
|
||||
# Compute logistic regression gradient for a matrix of data points
|
||||
def gradient(matrix, w):
|
||||
|
@ -76,9 +74,9 @@ if __name__ == "__main__":
|
|||
return x
|
||||
|
||||
for i in range(iterations):
|
||||
print "On iteration %i" % (i + 1)
|
||||
print("On iteration %i" % (i + 1))
|
||||
w -= points.map(lambda m: gradient(m, w)).reduce(add)
|
||||
|
||||
print "Final w: " + str(w)
|
||||
print("Final w: " + str(w))
|
||||
|
||||
sc.stop()
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
from pyspark import SparkContext
|
||||
from pyspark.ml import Pipeline
|
||||
from pyspark.ml.classification import LogisticRegression
|
||||
|
@ -37,10 +39,10 @@ if __name__ == "__main__":
|
|||
|
||||
# Prepare training documents, which are labeled.
|
||||
LabeledDocument = Row("id", "text", "label")
|
||||
training = sc.parallelize([(0L, "a b c d e spark", 1.0),
|
||||
(1L, "b d", 0.0),
|
||||
(2L, "spark f g h", 1.0),
|
||||
(3L, "hadoop mapreduce", 0.0)]) \
|
||||
training = sc.parallelize([(0, "a b c d e spark", 1.0),
|
||||
(1, "b d", 0.0),
|
||||
(2, "spark f g h", 1.0),
|
||||
(3, "hadoop mapreduce", 0.0)]) \
|
||||
.map(lambda x: LabeledDocument(*x)).toDF()
|
||||
|
||||
# Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr.
|
||||
|
@ -54,16 +56,16 @@ if __name__ == "__main__":
|
|||
|
||||
# Prepare test documents, which are unlabeled.
|
||||
Document = Row("id", "text")
|
||||
test = sc.parallelize([(4L, "spark i j k"),
|
||||
(5L, "l m n"),
|
||||
(6L, "mapreduce spark"),
|
||||
(7L, "apache hadoop")]) \
|
||||
test = sc.parallelize([(4, "spark i j k"),
|
||||
(5, "l m n"),
|
||||
(6, "mapreduce spark"),
|
||||
(7, "apache hadoop")]) \
|
||||
.map(lambda x: Document(*x)).toDF()
|
||||
|
||||
# Make predictions on test documents and print columns of interest.
|
||||
prediction = model.transform(test)
|
||||
selected = prediction.select("id", "text", "prediction")
|
||||
for row in selected.collect():
|
||||
print row
|
||||
print(row)
|
||||
|
||||
sc.stop()
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
"""
|
||||
Correlations using MLlib.
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
|
@ -29,7 +30,7 @@ from pyspark.mllib.util import MLUtils
|
|||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) not in [1, 2]:
|
||||
print >> sys.stderr, "Usage: correlations (<file>)"
|
||||
print("Usage: correlations (<file>)", file=sys.stderr)
|
||||
exit(-1)
|
||||
sc = SparkContext(appName="PythonCorrelations")
|
||||
if len(sys.argv) == 2:
|
||||
|
@ -41,20 +42,20 @@ if __name__ == "__main__":
|
|||
points = MLUtils.loadLibSVMFile(sc, filepath)\
|
||||
.map(lambda lp: LabeledPoint(lp.label, lp.features.toArray()))
|
||||
|
||||
print
|
||||
print 'Summary of data file: ' + filepath
|
||||
print '%d data points' % points.count()
|
||||
print()
|
||||
print('Summary of data file: ' + filepath)
|
||||
print('%d data points' % points.count())
|
||||
|
||||
# Statistics (correlations)
|
||||
print
|
||||
print 'Correlation (%s) between label and each feature' % corrType
|
||||
print 'Feature\tCorrelation'
|
||||
print()
|
||||
print('Correlation (%s) between label and each feature' % corrType)
|
||||
print('Feature\tCorrelation')
|
||||
numFeatures = points.take(1)[0].features.size
|
||||
labelRDD = points.map(lambda lp: lp.label)
|
||||
for i in range(numFeatures):
|
||||
featureRDD = points.map(lambda lp: lp.features[i])
|
||||
corr = Statistics.corr(labelRDD, featureRDD, corrType)
|
||||
print '%d\t%g' % (i, corr)
|
||||
print
|
||||
print('%d\t%g' % (i, corr))
|
||||
print()
|
||||
|
||||
sc.stop()
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
An example of how to use DataFrame as a dataset for ML. Run with::
|
||||
bin/spark-submit examples/src/main/python/mllib/dataset_example.py
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
@ -32,16 +33,16 @@ from pyspark.mllib.stat import Statistics
|
|||
|
||||
|
||||
def summarize(dataset):
|
||||
print "schema: %s" % dataset.schema().json()
|
||||
print("schema: %s" % dataset.schema().json())
|
||||
labels = dataset.map(lambda r: r.label)
|
||||
print "label average: %f" % labels.mean()
|
||||
print("label average: %f" % labels.mean())
|
||||
features = dataset.map(lambda r: r.features)
|
||||
summary = Statistics.colStats(features)
|
||||
print "features average: %r" % summary.mean()
|
||||
print("features average: %r" % summary.mean())
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) > 2:
|
||||
print >> sys.stderr, "Usage: dataset_example.py <libsvm file>"
|
||||
print("Usage: dataset_example.py <libsvm file>", file=sys.stderr)
|
||||
exit(-1)
|
||||
sc = SparkContext(appName="DatasetExample")
|
||||
sqlContext = SQLContext(sc)
|
||||
|
@ -54,9 +55,9 @@ if __name__ == "__main__":
|
|||
summarize(dataset0)
|
||||
tempdir = tempfile.NamedTemporaryFile(delete=False).name
|
||||
os.unlink(tempdir)
|
||||
print "Save dataset as a Parquet file to %s." % tempdir
|
||||
print("Save dataset as a Parquet file to %s." % tempdir)
|
||||
dataset0.saveAsParquetFile(tempdir)
|
||||
print "Load it back and summarize it again."
|
||||
print("Load it back and summarize it again.")
|
||||
dataset1 = sqlContext.parquetFile(tempdir).setName("dataset1").cache()
|
||||
summarize(dataset1)
|
||||
shutil.rmtree(tempdir)
|
||||
|
|
|
@ -20,6 +20,7 @@ Decision tree classification and regression using MLlib.
|
|||
|
||||
This example requires NumPy (http://www.numpy.org/).
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy
|
||||
import os
|
||||
|
@ -83,18 +84,17 @@ def reindexClassLabels(data):
|
|||
numClasses = len(classCounts)
|
||||
# origToNewLabels: class --> index in 0,...,numClasses-1
|
||||
if (numClasses < 2):
|
||||
print >> sys.stderr, \
|
||||
"Dataset for classification should have at least 2 classes." + \
|
||||
" The given dataset had only %d classes." % numClasses
|
||||
print("Dataset for classification should have at least 2 classes."
|
||||
" The given dataset had only %d classes." % numClasses, file=sys.stderr)
|
||||
exit(1)
|
||||
origToNewLabels = dict([(sortedClasses[i], i) for i in range(0, numClasses)])
|
||||
|
||||
print "numClasses = %d" % numClasses
|
||||
print "Per-class example fractions, counts:"
|
||||
print "Class\tFrac\tCount"
|
||||
print("numClasses = %d" % numClasses)
|
||||
print("Per-class example fractions, counts:")
|
||||
print("Class\tFrac\tCount")
|
||||
for c in sortedClasses:
|
||||
frac = classCounts[c] / (numExamples + 0.0)
|
||||
print "%g\t%g\t%d" % (c, frac, classCounts[c])
|
||||
print("%g\t%g\t%d" % (c, frac, classCounts[c]))
|
||||
|
||||
if (sortedClasses[0] == 0 and sortedClasses[-1] == numClasses - 1):
|
||||
return (data, origToNewLabels)
|
||||
|
@ -105,8 +105,7 @@ def reindexClassLabels(data):
|
|||
|
||||
|
||||
def usage():
|
||||
print >> sys.stderr, \
|
||||
"Usage: decision_tree_runner [libsvm format data filepath]"
|
||||
print("Usage: decision_tree_runner [libsvm format data filepath]", file=sys.stderr)
|
||||
exit(1)
|
||||
|
||||
|
||||
|
@ -133,13 +132,13 @@ if __name__ == "__main__":
|
|||
model = DecisionTree.trainClassifier(reindexedData, numClasses=numClasses,
|
||||
categoricalFeaturesInfo=categoricalFeaturesInfo)
|
||||
# Print learned tree and stats.
|
||||
print "Trained DecisionTree for classification:"
|
||||
print " Model numNodes: %d" % model.numNodes()
|
||||
print " Model depth: %d" % model.depth()
|
||||
print " Training accuracy: %g" % getAccuracy(model, reindexedData)
|
||||
print("Trained DecisionTree for classification:")
|
||||
print(" Model numNodes: %d" % model.numNodes())
|
||||
print(" Model depth: %d" % model.depth())
|
||||
print(" Training accuracy: %g" % getAccuracy(model, reindexedData))
|
||||
if model.numNodes() < 20:
|
||||
print model.toDebugString()
|
||||
print(model.toDebugString())
|
||||
else:
|
||||
print model
|
||||
print(model)
|
||||
|
||||
sc.stop()
|
||||
|
|
|
@ -18,7 +18,8 @@
|
|||
"""
|
||||
A Gaussian Mixture Model clustering program using MLlib.
|
||||
"""
|
||||
import sys
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
@ -59,7 +60,7 @@ if __name__ == "__main__":
|
|||
model = GaussianMixture.train(data, args.k, args.convergenceTol,
|
||||
args.maxIterations, args.seed)
|
||||
for i in range(args.k):
|
||||
print ("weight = ", model.weights[i], "mu = ", model.gaussians[i].mu,
|
||||
"sigma = ", model.gaussians[i].sigma.toArray())
|
||||
print ("Cluster labels (first 100): ", model.predict(data).take(100))
|
||||
print(("weight = ", model.weights[i], "mu = ", model.gaussians[i].mu,
|
||||
"sigma = ", model.gaussians[i].sigma.toArray()))
|
||||
print(("Cluster labels (first 100): ", model.predict(data).take(100)))
|
||||
sc.stop()
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
"""
|
||||
Gradient boosted Trees classification and regression using MLlib.
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
|
@ -34,7 +35,7 @@ def testClassification(trainingData, testData):
|
|||
# Evaluate model on test instances and compute test error
|
||||
predictions = model.predict(testData.map(lambda x: x.features))
|
||||
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
|
||||
testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() \
|
||||
testErr = labelsAndPredictions.filter(lambda v_p: v_p[0] != v_p[1]).count() \
|
||||
/ float(testData.count())
|
||||
print('Test Error = ' + str(testErr))
|
||||
print('Learned classification ensemble model:')
|
||||
|
@ -49,7 +50,7 @@ def testRegression(trainingData, testData):
|
|||
# Evaluate model on test instances and compute test error
|
||||
predictions = model.predict(testData.map(lambda x: x.features))
|
||||
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
|
||||
testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() \
|
||||
testMSE = labelsAndPredictions.map(lambda vp: (vp[0] - vp[1]) * (vp[0] - vp[1])).sum() \
|
||||
/ float(testData.count())
|
||||
print('Test Mean Squared Error = ' + str(testMSE))
|
||||
print('Learned regression ensemble model:')
|
||||
|
@ -58,7 +59,7 @@ def testRegression(trainingData, testData):
|
|||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) > 1:
|
||||
print >> sys.stderr, "Usage: gradient_boosted_trees"
|
||||
print("Usage: gradient_boosted_trees", file=sys.stderr)
|
||||
exit(1)
|
||||
sc = SparkContext(appName="PythonGradientBoostedTrees")
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ A K-means clustering program using MLlib.
|
|||
|
||||
This example requires NumPy (http://www.numpy.org/).
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
|
@ -34,12 +35,12 @@ def parseVector(line):
|
|||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 3:
|
||||
print >> sys.stderr, "Usage: kmeans <file> <k>"
|
||||
print("Usage: kmeans <file> <k>", file=sys.stderr)
|
||||
exit(-1)
|
||||
sc = SparkContext(appName="KMeans")
|
||||
lines = sc.textFile(sys.argv[1])
|
||||
data = lines.map(parseVector)
|
||||
k = int(sys.argv[2])
|
||||
model = KMeans.train(data, k)
|
||||
print "Final centers: " + str(model.clusterCenters)
|
||||
print("Final centers: " + str(model.clusterCenters))
|
||||
sc.stop()
|
||||
|
|
|
@ -20,11 +20,10 @@ Logistic regression using MLlib.
|
|||
|
||||
This example requires NumPy (http://www.numpy.org/).
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
from math import exp
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
from pyspark import SparkContext
|
||||
from pyspark.mllib.regression import LabeledPoint
|
||||
from pyspark.mllib.classification import LogisticRegressionWithSGD
|
||||
|
@ -42,12 +41,12 @@ def parsePoint(line):
|
|||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 3:
|
||||
print >> sys.stderr, "Usage: logistic_regression <file> <iterations>"
|
||||
print("Usage: logistic_regression <file> <iterations>", file=sys.stderr)
|
||||
exit(-1)
|
||||
sc = SparkContext(appName="PythonLR")
|
||||
points = sc.textFile(sys.argv[1]).map(parsePoint)
|
||||
iterations = int(sys.argv[2])
|
||||
model = LogisticRegressionWithSGD.train(points, iterations)
|
||||
print "Final weights: " + str(model.weights)
|
||||
print "Final intercept: " + str(model.intercept)
|
||||
print("Final weights: " + str(model.weights))
|
||||
print("Final intercept: " + str(model.intercept))
|
||||
sc.stop()
|
||||
|
|
|
@ -22,6 +22,7 @@ Note: This example illustrates binary classification.
|
|||
For information on multiclass classification, please refer to the decision_tree_runner.py
|
||||
example.
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
|
@ -43,7 +44,7 @@ def testClassification(trainingData, testData):
|
|||
# Evaluate model on test instances and compute test error
|
||||
predictions = model.predict(testData.map(lambda x: x.features))
|
||||
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
|
||||
testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count()\
|
||||
testErr = labelsAndPredictions.filter(lambda v_p: v_p[0] != v_p[1]).count()\
|
||||
/ float(testData.count())
|
||||
print('Test Error = ' + str(testErr))
|
||||
print('Learned classification forest model:')
|
||||
|
@ -62,8 +63,8 @@ def testRegression(trainingData, testData):
|
|||
# Evaluate model on test instances and compute test error
|
||||
predictions = model.predict(testData.map(lambda x: x.features))
|
||||
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
|
||||
testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum()\
|
||||
/ float(testData.count())
|
||||
testMSE = labelsAndPredictions.map(lambda v_p1: (v_p1[0] - v_p1[1]) * (v_p1[0] - v_p1[1]))\
|
||||
.sum() / float(testData.count())
|
||||
print('Test Mean Squared Error = ' + str(testMSE))
|
||||
print('Learned regression forest model:')
|
||||
print(model.toDebugString())
|
||||
|
@ -71,7 +72,7 @@ def testRegression(trainingData, testData):
|
|||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) > 1:
|
||||
print >> sys.stderr, "Usage: random_forest_example"
|
||||
print("Usage: random_forest_example", file=sys.stderr)
|
||||
exit(1)
|
||||
sc = SparkContext(appName="PythonRandomForestExample")
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
"""
|
||||
Randomly generated RDDs.
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
|
@ -27,7 +28,7 @@ from pyspark.mllib.random import RandomRDDs
|
|||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) not in [1, 2]:
|
||||
print >> sys.stderr, "Usage: random_rdd_generation"
|
||||
print("Usage: random_rdd_generation", file=sys.stderr)
|
||||
exit(-1)
|
||||
|
||||
sc = SparkContext(appName="PythonRandomRDDGeneration")
|
||||
|
@ -37,19 +38,19 @@ if __name__ == "__main__":
|
|||
|
||||
# Example: RandomRDDs.normalRDD
|
||||
normalRDD = RandomRDDs.normalRDD(sc, numExamples)
|
||||
print 'Generated RDD of %d examples sampled from the standard normal distribution'\
|
||||
% normalRDD.count()
|
||||
print ' First 5 samples:'
|
||||
print('Generated RDD of %d examples sampled from the standard normal distribution'
|
||||
% normalRDD.count())
|
||||
print(' First 5 samples:')
|
||||
for sample in normalRDD.take(5):
|
||||
print ' ' + str(sample)
|
||||
print
|
||||
print(' ' + str(sample))
|
||||
print()
|
||||
|
||||
# Example: RandomRDDs.normalVectorRDD
|
||||
normalVectorRDD = RandomRDDs.normalVectorRDD(sc, numRows=numExamples, numCols=2)
|
||||
print 'Generated RDD of %d examples of length-2 vectors.' % normalVectorRDD.count()
|
||||
print ' First 5 samples:'
|
||||
print('Generated RDD of %d examples of length-2 vectors.' % normalVectorRDD.count())
|
||||
print(' First 5 samples:')
|
||||
for sample in normalVectorRDD.take(5):
|
||||
print ' ' + str(sample)
|
||||
print
|
||||
print(' ' + str(sample))
|
||||
print()
|
||||
|
||||
sc.stop()
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
"""
|
||||
Randomly sampled RDDs.
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
|
@ -27,7 +28,7 @@ from pyspark.mllib.util import MLUtils
|
|||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) not in [1, 2]:
|
||||
print >> sys.stderr, "Usage: sampled_rdds <libsvm data file>"
|
||||
print("Usage: sampled_rdds <libsvm data file>", file=sys.stderr)
|
||||
exit(-1)
|
||||
if len(sys.argv) == 2:
|
||||
datapath = sys.argv[1]
|
||||
|
@ -41,24 +42,24 @@ if __name__ == "__main__":
|
|||
examples = MLUtils.loadLibSVMFile(sc, datapath)
|
||||
numExamples = examples.count()
|
||||
if numExamples == 0:
|
||||
print >> sys.stderr, "Error: Data file had no samples to load."
|
||||
print("Error: Data file had no samples to load.", file=sys.stderr)
|
||||
exit(1)
|
||||
print 'Loaded data with %d examples from file: %s' % (numExamples, datapath)
|
||||
print('Loaded data with %d examples from file: %s' % (numExamples, datapath))
|
||||
|
||||
# Example: RDD.sample() and RDD.takeSample()
|
||||
expectedSampleSize = int(numExamples * fraction)
|
||||
print 'Sampling RDD using fraction %g. Expected sample size = %d.' \
|
||||
% (fraction, expectedSampleSize)
|
||||
print('Sampling RDD using fraction %g. Expected sample size = %d.'
|
||||
% (fraction, expectedSampleSize))
|
||||
sampledRDD = examples.sample(withReplacement=True, fraction=fraction)
|
||||
print ' RDD.sample(): sample has %d examples' % sampledRDD.count()
|
||||
print(' RDD.sample(): sample has %d examples' % sampledRDD.count())
|
||||
sampledArray = examples.takeSample(withReplacement=True, num=expectedSampleSize)
|
||||
print ' RDD.takeSample(): sample has %d examples' % len(sampledArray)
|
||||
print(' RDD.takeSample(): sample has %d examples' % len(sampledArray))
|
||||
|
||||
print
|
||||
print()
|
||||
|
||||
# Example: RDD.sampleByKey()
|
||||
keyedRDD = examples.map(lambda lp: (int(lp.label), lp.features))
|
||||
print ' Keyed data using label (Int) as key ==> Orig'
|
||||
print(' Keyed data using label (Int) as key ==> Orig')
|
||||
# Count examples per label in original data.
|
||||
keyCountsA = keyedRDD.countByKey()
|
||||
|
||||
|
@ -69,18 +70,18 @@ if __name__ == "__main__":
|
|||
sampledByKeyRDD = keyedRDD.sampleByKey(withReplacement=True, fractions=fractions)
|
||||
keyCountsB = sampledByKeyRDD.countByKey()
|
||||
sizeB = sum(keyCountsB.values())
|
||||
print ' Sampled %d examples using approximate stratified sampling (by label). ==> Sample' \
|
||||
% sizeB
|
||||
print(' Sampled %d examples using approximate stratified sampling (by label). ==> Sample'
|
||||
% sizeB)
|
||||
|
||||
# Compare samples
|
||||
print ' \tFractions of examples with key'
|
||||
print 'Key\tOrig\tSample'
|
||||
print(' \tFractions of examples with key')
|
||||
print('Key\tOrig\tSample')
|
||||
for k in sorted(keyCountsA.keys()):
|
||||
fracA = keyCountsA[k] / float(numExamples)
|
||||
if sizeB != 0:
|
||||
fracB = keyCountsB.get(k, 0) / float(sizeB)
|
||||
else:
|
||||
fracB = 0
|
||||
print '%d\t%g\t%g' % (k, fracA, fracB)
|
||||
print('%d\t%g\t%g' % (k, fracA, fracB))
|
||||
|
||||
sc.stop()
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
# grep -o -E '\w+(\W+\w+){0,15}' text8 > text8_lines
|
||||
# This was done so that the example can be run in local mode
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
|
@ -34,7 +35,7 @@ USAGE = ("bin/spark-submit --driver-memory 4g "
|
|||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
print USAGE
|
||||
print(USAGE)
|
||||
sys.exit("Argument for file not provided")
|
||||
file_path = sys.argv[1]
|
||||
sc = SparkContext(appName='Word2Vec')
|
||||
|
@ -46,5 +47,5 @@ if __name__ == "__main__":
|
|||
synonyms = model.findSynonyms('china', 40)
|
||||
|
||||
for word, cosine_distance in synonyms:
|
||||
print "{}: {}".format(word, cosine_distance)
|
||||
print("{}: {}".format(word, cosine_distance))
|
||||
sc.stop()
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
This is an example implementation of PageRank. For more conventional use,
|
||||
Please refer to PageRank implementation provided by graphx
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
import re
|
||||
import sys
|
||||
|
@ -42,11 +43,12 @@ def parseNeighbors(urls):
|
|||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 3:
|
||||
print >> sys.stderr, "Usage: pagerank <file> <iterations>"
|
||||
print("Usage: pagerank <file> <iterations>", file=sys.stderr)
|
||||
exit(-1)
|
||||
|
||||
print >> sys.stderr, """WARN: This is a naive implementation of PageRank and is
|
||||
given as an example! Please refer to PageRank implementation provided by graphx"""
|
||||
print("""WARN: This is a naive implementation of PageRank and is
|
||||
given as an example! Please refer to PageRank implementation provided by graphx""",
|
||||
file=sys.stderr)
|
||||
|
||||
# Initialize the spark context.
|
||||
sc = SparkContext(appName="PythonPageRank")
|
||||
|
@ -62,19 +64,19 @@ if __name__ == "__main__":
|
|||
links = lines.map(lambda urls: parseNeighbors(urls)).distinct().groupByKey().cache()
|
||||
|
||||
# Loads all URLs with other URL(s) link to from input file and initialize ranks of them to one.
|
||||
ranks = links.map(lambda (url, neighbors): (url, 1.0))
|
||||
ranks = links.map(lambda url_neighbors: (url_neighbors[0], 1.0))
|
||||
|
||||
# Calculates and updates URL ranks continuously using PageRank algorithm.
|
||||
for iteration in xrange(int(sys.argv[2])):
|
||||
for iteration in range(int(sys.argv[2])):
|
||||
# Calculates URL contributions to the rank of other URLs.
|
||||
contribs = links.join(ranks).flatMap(
|
||||
lambda (url, (urls, rank)): computeContribs(urls, rank))
|
||||
lambda url_urls_rank: computeContribs(url_urls_rank[1][0], url_urls_rank[1][1]))
|
||||
|
||||
# Re-calculates URL ranks based on neighbor contributions.
|
||||
ranks = contribs.reduceByKey(add).mapValues(lambda rank: rank * 0.85 + 0.15)
|
||||
|
||||
# Collects all URL ranks and dump them to console.
|
||||
for (link, rank) in ranks.collect():
|
||||
print "%s has rank: %s." % (link, rank)
|
||||
print("%s has rank: %s." % (link, rank))
|
||||
|
||||
sc.stop()
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from __future__ import print_function
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
|
@ -35,14 +36,14 @@ $ ./bin/spark-submit --driver-class-path /path/to/example/jar \\
|
|||
"""
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 2:
|
||||
print >> sys.stderr, """
|
||||
print("""
|
||||
Usage: parquet_inputformat.py <data_file>
|
||||
|
||||
Run with example jar:
|
||||
./bin/spark-submit --driver-class-path /path/to/example/jar \\
|
||||
/path/to/examples/parquet_inputformat.py <data_file>
|
||||
Assumes you have Parquet data stored in <data_file>.
|
||||
"""
|
||||
""", file=sys.stderr)
|
||||
exit(-1)
|
||||
|
||||
path = sys.argv[1]
|
||||
|
@ -56,6 +57,6 @@ if __name__ == "__main__":
|
|||
valueConverter='org.apache.spark.examples.pythonconverters.IndexedRecordToJavaConverter')
|
||||
output = parquet_rdd.map(lambda x: x[1]).collect()
|
||||
for k in output:
|
||||
print k
|
||||
print(k)
|
||||
|
||||
sc.stop()
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from __future__ import print_function
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
|
@ -35,7 +36,7 @@ if __name__ == "__main__":
|
|||
y = random() * 2 - 1
|
||||
return 1 if x ** 2 + y ** 2 < 1 else 0
|
||||
|
||||
count = sc.parallelize(xrange(1, n + 1), partitions).map(f).reduce(add)
|
||||
print "Pi is roughly %f" % (4.0 * count / n)
|
||||
count = sc.parallelize(range(1, n + 1), partitions).map(f).reduce(add)
|
||||
print("Pi is roughly %f" % (4.0 * count / n))
|
||||
|
||||
sc.stop()
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
from pyspark import SparkContext
|
||||
|
@ -22,7 +24,7 @@ from pyspark import SparkContext
|
|||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 2:
|
||||
print >> sys.stderr, "Usage: sort <file>"
|
||||
print("Usage: sort <file>", file=sys.stderr)
|
||||
exit(-1)
|
||||
sc = SparkContext(appName="PythonSort")
|
||||
lines = sc.textFile(sys.argv[1], 1)
|
||||
|
@ -33,6 +35,6 @@ if __name__ == "__main__":
|
|||
# In reality, we wouldn't want to collect all the data to the driver node.
|
||||
output = sortedCount.collect()
|
||||
for (num, unitcount) in output:
|
||||
print num
|
||||
print(num)
|
||||
|
||||
sc.stop()
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from pyspark import SparkContext
|
||||
|
@ -68,6 +70,6 @@ if __name__ == "__main__":
|
|||
teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
|
||||
|
||||
for each in teenagers.collect():
|
||||
print each[0]
|
||||
print(each[0])
|
||||
|
||||
sc.stop()
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
import threading
|
||||
import Queue
|
||||
|
@ -52,15 +54,15 @@ def main():
|
|||
ids = status.getJobIdsForGroup()
|
||||
for id in ids:
|
||||
job = status.getJobInfo(id)
|
||||
print "Job", id, "status: ", job.status
|
||||
print("Job", id, "status: ", job.status)
|
||||
for sid in job.stageIds:
|
||||
info = status.getStageInfo(sid)
|
||||
if info:
|
||||
print "Stage %d: %d tasks total (%d active, %d complete)" % \
|
||||
(sid, info.numTasks, info.numActiveTasks, info.numCompletedTasks)
|
||||
print("Stage %d: %d tasks total (%d active, %d complete)" %
|
||||
(sid, info.numTasks, info.numActiveTasks, info.numCompletedTasks))
|
||||
time.sleep(1)
|
||||
|
||||
print "Job results are:", result.get()
|
||||
print("Job results are:", result.get())
|
||||
sc.stop()
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
|
||||
Then create a text file in `localdir` and the words in the file will get counted.
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
|
@ -33,7 +34,7 @@ from pyspark.streaming import StreamingContext
|
|||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 2:
|
||||
print >> sys.stderr, "Usage: hdfs_wordcount.py <directory>"
|
||||
print("Usage: hdfs_wordcount.py <directory>", file=sys.stderr)
|
||||
exit(-1)
|
||||
|
||||
sc = SparkContext(appName="PythonStreamingHDFSWordCount")
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
spark-streaming-kafka-assembly-*.jar examples/src/main/python/streaming/kafka_wordcount.py \
|
||||
localhost:2181 test`
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
|
@ -36,7 +37,7 @@ from pyspark.streaming.kafka import KafkaUtils
|
|||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 3:
|
||||
print >> sys.stderr, "Usage: kafka_wordcount.py <zk> <topic>"
|
||||
print("Usage: kafka_wordcount.py <zk> <topic>", file=sys.stderr)
|
||||
exit(-1)
|
||||
|
||||
sc = SparkContext(appName="PythonStreamingKafkaWordCount")
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
and then run the example
|
||||
`$ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999`
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
|
@ -33,7 +34,7 @@ from pyspark.streaming import StreamingContext
|
|||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 3:
|
||||
print >> sys.stderr, "Usage: network_wordcount.py <hostname> <port>"
|
||||
print("Usage: network_wordcount.py <hostname> <port>", file=sys.stderr)
|
||||
exit(-1)
|
||||
sc = SparkContext(appName="PythonStreamingNetworkWordCount")
|
||||
ssc = StreamingContext(sc, 1)
|
||||
|
|
|
@ -35,6 +35,7 @@
|
|||
checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from
|
||||
the checkpoint data.
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
@ -46,7 +47,7 @@ from pyspark.streaming import StreamingContext
|
|||
def createContext(host, port, outputPath):
|
||||
# If you do not see this printed, that means the StreamingContext has been loaded
|
||||
# from the new checkpoint
|
||||
print "Creating new context"
|
||||
print("Creating new context")
|
||||
if os.path.exists(outputPath):
|
||||
os.remove(outputPath)
|
||||
sc = SparkContext(appName="PythonStreamingRecoverableNetworkWordCount")
|
||||
|
@ -60,8 +61,8 @@ def createContext(host, port, outputPath):
|
|||
|
||||
def echo(time, rdd):
|
||||
counts = "Counts at time %s %s" % (time, rdd.collect())
|
||||
print counts
|
||||
print "Appending to " + os.path.abspath(outputPath)
|
||||
print(counts)
|
||||
print("Appending to " + os.path.abspath(outputPath))
|
||||
with open(outputPath, 'a') as f:
|
||||
f.write(counts + "\n")
|
||||
|
||||
|
@ -70,8 +71,8 @@ def createContext(host, port, outputPath):
|
|||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 5:
|
||||
print >> sys.stderr, "Usage: recoverable_network_wordcount.py <hostname> <port> "\
|
||||
"<checkpoint-directory> <output-file>"
|
||||
print("Usage: recoverable_network_wordcount.py <hostname> <port> "
|
||||
"<checkpoint-directory> <output-file>", file=sys.stderr)
|
||||
exit(-1)
|
||||
host, port, checkpoint, output = sys.argv[1:]
|
||||
ssc = StreamingContext.getOrCreate(checkpoint,
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
and then run the example
|
||||
`$ bin/spark-submit examples/src/main/python/streaming/sql_network_wordcount.py localhost 9999`
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
@ -44,7 +45,7 @@ def getSqlContextInstance(sparkContext):
|
|||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 3:
|
||||
print >> sys.stderr, "Usage: sql_network_wordcount.py <hostname> <port> "
|
||||
print("Usage: sql_network_wordcount.py <hostname> <port> ", file=sys.stderr)
|
||||
exit(-1)
|
||||
host, port = sys.argv[1:]
|
||||
sc = SparkContext(appName="PythonSqlNetworkWordCount")
|
||||
|
@ -57,7 +58,7 @@ if __name__ == "__main__":
|
|||
|
||||
# Convert RDDs of the words DStream to DataFrame and run SQL query
|
||||
def process(time, rdd):
|
||||
print "========= %s =========" % str(time)
|
||||
print("========= %s =========" % str(time))
|
||||
|
||||
try:
|
||||
# Get the singleton instance of SQLContext
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
`$ bin/spark-submit examples/src/main/python/streaming/stateful_network_wordcount.py \
|
||||
localhost 9999`
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
|
@ -37,7 +38,7 @@ from pyspark.streaming import StreamingContext
|
|||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 3:
|
||||
print >> sys.stderr, "Usage: stateful_network_wordcount.py <hostname> <port>"
|
||||
print("Usage: stateful_network_wordcount.py <hostname> <port>", file=sys.stderr)
|
||||
exit(-1)
|
||||
sc = SparkContext(appName="PythonStreamingStatefulNetworkWordCount")
|
||||
ssc = StreamingContext(sc, 1)
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
from random import Random
|
||||
|
||||
|
@ -49,20 +51,20 @@ if __name__ == "__main__":
|
|||
# the graph to obtain the path (x, z).
|
||||
|
||||
# Because join() joins on keys, the edges are stored in reversed order.
|
||||
edges = tc.map(lambda (x, y): (y, x))
|
||||
edges = tc.map(lambda x_y: (x_y[1], x_y[0]))
|
||||
|
||||
oldCount = 0L
|
||||
oldCount = 0
|
||||
nextCount = tc.count()
|
||||
while True:
|
||||
oldCount = nextCount
|
||||
# Perform the join, obtaining an RDD of (y, (z, x)) pairs,
|
||||
# then project the result to obtain the new (x, z) paths.
|
||||
new_edges = tc.join(edges).map(lambda (_, (a, b)): (b, a))
|
||||
new_edges = tc.join(edges).map(lambda __a_b: (__a_b[1][1], __a_b[1][0]))
|
||||
tc = tc.union(new_edges).distinct().cache()
|
||||
nextCount = tc.count()
|
||||
if nextCount == oldCount:
|
||||
break
|
||||
|
||||
print "TC has %i edges" % tc.count()
|
||||
print("TC has %i edges" % tc.count())
|
||||
|
||||
sc.stop()
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
from operator import add
|
||||
|
||||
|
@ -23,7 +25,7 @@ from pyspark import SparkContext
|
|||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 2:
|
||||
print >> sys.stderr, "Usage: wordcount <file>"
|
||||
print("Usage: wordcount <file>", file=sys.stderr)
|
||||
exit(-1)
|
||||
sc = SparkContext(appName="PythonWordCount")
|
||||
lines = sc.textFile(sys.argv[1], 1)
|
||||
|
@ -32,6 +34,6 @@ if __name__ == "__main__":
|
|||
.reduceByKey(add)
|
||||
output = counts.collect()
|
||||
for (word, count) in output:
|
||||
print "%s: %i" % (word, count)
|
||||
print("%s: %i" % (word, count))
|
||||
|
||||
sc.stop()
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.mllib.api.python
|
||||
|
||||
import org.apache.spark.api.java.JavaRDD
|
||||
import org.apache.spark.mllib.linalg.Vectors
|
||||
import org.apache.spark.mllib.recommendation.{MatrixFactorizationModel, Rating}
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
|
@ -31,10 +32,14 @@ private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorization
|
|||
predict(SerDe.asTupleRDD(userAndProducts.rdd))
|
||||
|
||||
def getUserFeatures: RDD[Array[Any]] = {
|
||||
SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]])
|
||||
SerDe.fromTuple2RDD(userFeatures.map {
|
||||
case (user, feature) => (user, Vectors.dense(feature))
|
||||
}.asInstanceOf[RDD[(Any, Any)]])
|
||||
}
|
||||
|
||||
def getProductFeatures: RDD[Array[Any]] = {
|
||||
SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]])
|
||||
SerDe.fromTuple2RDD(productFeatures.map {
|
||||
case (product, feature) => (product, Vectors.dense(feature))
|
||||
}.asInstanceOf[RDD[(Any, Any)]])
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,7 +28,6 @@ import scala.reflect.ClassTag
|
|||
|
||||
import net.razorvine.pickle._
|
||||
|
||||
import org.apache.spark.annotation.DeveloperApi
|
||||
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
|
||||
import org.apache.spark.api.python.SerDeUtil
|
||||
import org.apache.spark.mllib.classification._
|
||||
|
@ -40,15 +39,15 @@ import org.apache.spark.mllib.optimization._
|
|||
import org.apache.spark.mllib.random.{RandomRDDs => RG}
|
||||
import org.apache.spark.mllib.recommendation._
|
||||
import org.apache.spark.mllib.regression._
|
||||
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
|
||||
import org.apache.spark.mllib.stat.correlation.CorrelationNames
|
||||
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
|
||||
import org.apache.spark.mllib.stat.test.ChiSqTestResult
|
||||
import org.apache.spark.mllib.tree.{GradientBoostedTrees, RandomForest, DecisionTree}
|
||||
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo, Strategy}
|
||||
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
|
||||
import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy}
|
||||
import org.apache.spark.mllib.tree.impurity._
|
||||
import org.apache.spark.mllib.tree.loss.Losses
|
||||
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel, RandomForestModel, DecisionTreeModel}
|
||||
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel, RandomForestModel}
|
||||
import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest}
|
||||
import org.apache.spark.mllib.util.MLUtils
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
|
@ -279,7 +278,7 @@ private[python] class PythonMLLibAPI extends Serializable {
|
|||
data: JavaRDD[LabeledPoint],
|
||||
lambda: Double): JList[Object] = {
|
||||
val model = NaiveBayes.train(data.rdd, lambda)
|
||||
List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta).
|
||||
List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta.map(Vectors.dense)).
|
||||
map(_.asInstanceOf[Object]).asJava
|
||||
}
|
||||
|
||||
|
@ -335,7 +334,7 @@ private[python] class PythonMLLibAPI extends Serializable {
|
|||
mu += model.gaussians(i).mu
|
||||
sigma += model.gaussians(i).sigma
|
||||
}
|
||||
List(wt.toArray, mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
|
||||
List(Vectors.dense(wt.toArray), mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
|
||||
} finally {
|
||||
data.rdd.unpersist(blocking = false)
|
||||
}
|
||||
|
@ -346,20 +345,20 @@ private[python] class PythonMLLibAPI extends Serializable {
|
|||
*/
|
||||
def predictSoftGMM(
|
||||
data: JavaRDD[Vector],
|
||||
wt: Object,
|
||||
wt: Vector,
|
||||
mu: Array[Object],
|
||||
si: Array[Object]): RDD[Array[Double]] = {
|
||||
si: Array[Object]): RDD[Vector] = {
|
||||
|
||||
val weight = wt.asInstanceOf[Array[Double]]
|
||||
val weight = wt.toArray
|
||||
val mean = mu.map(_.asInstanceOf[DenseVector])
|
||||
val sigma = si.map(_.asInstanceOf[DenseMatrix])
|
||||
val gaussians = Array.tabulate(weight.length){
|
||||
i => new MultivariateGaussian(mean(i), sigma(i))
|
||||
}
|
||||
val model = new GaussianMixtureModel(weight, gaussians)
|
||||
model.predictSoft(data)
|
||||
model.predictSoft(data).map(Vectors.dense)
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Java stub for Python mllib ALS.train(). This stub returns a handle
|
||||
* to the Java object instead of the content of the Java object. Extra care
|
||||
|
@ -936,6 +935,14 @@ private[spark] object SerDe extends Serializable {
|
|||
out.write(code)
|
||||
}
|
||||
|
||||
protected def getBytes(obj: Object): Array[Byte] = {
|
||||
if (obj.getClass.isArray) {
|
||||
obj.asInstanceOf[Array[Byte]]
|
||||
} else {
|
||||
obj.asInstanceOf[String].getBytes(LATIN1)
|
||||
}
|
||||
}
|
||||
|
||||
private[python] def saveState(obj: Object, out: OutputStream, pickler: Pickler)
|
||||
}
|
||||
|
||||
|
@ -961,7 +968,7 @@ private[spark] object SerDe extends Serializable {
|
|||
if (args.length != 1) {
|
||||
throw new PickleException("should be 1")
|
||||
}
|
||||
val bytes = args(0).asInstanceOf[String].getBytes(LATIN1)
|
||||
val bytes = getBytes(args(0))
|
||||
val bb = ByteBuffer.wrap(bytes, 0, bytes.length)
|
||||
bb.order(ByteOrder.nativeOrder())
|
||||
val db = bb.asDoubleBuffer()
|
||||
|
@ -994,7 +1001,7 @@ private[spark] object SerDe extends Serializable {
|
|||
if (args.length != 3) {
|
||||
throw new PickleException("should be 3")
|
||||
}
|
||||
val bytes = args(2).asInstanceOf[String].getBytes(LATIN1)
|
||||
val bytes = getBytes(args(2))
|
||||
val n = bytes.length / 8
|
||||
val values = new Array[Double](n)
|
||||
val order = ByteOrder.nativeOrder()
|
||||
|
@ -1031,8 +1038,8 @@ private[spark] object SerDe extends Serializable {
|
|||
throw new PickleException("should be 3")
|
||||
}
|
||||
val size = args(0).asInstanceOf[Int]
|
||||
val indiceBytes = args(1).asInstanceOf[String].getBytes(LATIN1)
|
||||
val valueBytes = args(2).asInstanceOf[String].getBytes(LATIN1)
|
||||
val indiceBytes = getBytes(args(1))
|
||||
val valueBytes = getBytes(args(2))
|
||||
val n = indiceBytes.length / 4
|
||||
val indices = new Array[Int](n)
|
||||
val values = new Array[Double](n)
|
||||
|
|
|
@ -54,7 +54,7 @@
|
|||
... def zero(self, value):
|
||||
... return [0.0] * len(value)
|
||||
... def addInPlace(self, val1, val2):
|
||||
... for i in xrange(len(val1)):
|
||||
... for i in range(len(val1)):
|
||||
... val1[i] += val2[i]
|
||||
... return val1
|
||||
>>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam())
|
||||
|
@ -86,9 +86,13 @@ Traceback (most recent call last):
|
|||
Exception:...
|
||||
"""
|
||||
|
||||
import sys
|
||||
import select
|
||||
import struct
|
||||
import SocketServer
|
||||
if sys.version < '3':
|
||||
import SocketServer
|
||||
else:
|
||||
import socketserver as SocketServer
|
||||
import threading
|
||||
from pyspark.cloudpickle import CloudPickler
|
||||
from pyspark.serializers import read_int, PickleSerializer
|
||||
|
@ -247,6 +251,7 @@ class AccumulatorServer(SocketServer.TCPServer):
|
|||
def shutdown(self):
|
||||
self.server_shutdown = True
|
||||
SocketServer.TCPServer.shutdown(self)
|
||||
self.server_close()
|
||||
|
||||
|
||||
def _start_update_server():
|
||||
|
|
|
@ -16,10 +16,15 @@
|
|||
#
|
||||
|
||||
import os
|
||||
import cPickle
|
||||
import sys
|
||||
import gc
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
if sys.version < '3':
|
||||
import cPickle as pickle
|
||||
else:
|
||||
import pickle
|
||||
unicode = str
|
||||
|
||||
__all__ = ['Broadcast']
|
||||
|
||||
|
@ -70,33 +75,19 @@ class Broadcast(object):
|
|||
self._path = path
|
||||
|
||||
def dump(self, value, f):
|
||||
if isinstance(value, basestring):
|
||||
if isinstance(value, unicode):
|
||||
f.write('U')
|
||||
value = value.encode('utf8')
|
||||
else:
|
||||
f.write('S')
|
||||
f.write(value)
|
||||
else:
|
||||
f.write('P')
|
||||
cPickle.dump(value, f, 2)
|
||||
pickle.dump(value, f, 2)
|
||||
f.close()
|
||||
return f.name
|
||||
|
||||
def load(self, path):
|
||||
with open(path, 'rb', 1 << 20) as f:
|
||||
flag = f.read(1)
|
||||
data = f.read()
|
||||
if flag == 'P':
|
||||
# cPickle.loads() may create lots of objects, disable GC
|
||||
# temporary for better performance
|
||||
gc.disable()
|
||||
try:
|
||||
return cPickle.loads(data)
|
||||
finally:
|
||||
gc.enable()
|
||||
else:
|
||||
return data.decode('utf8') if flag == 'U' else data
|
||||
# pickle.load() may create lots of objects, disable GC
|
||||
# temporary for better performance
|
||||
gc.disable()
|
||||
try:
|
||||
return pickle.load(f)
|
||||
finally:
|
||||
gc.enable()
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
|
|
|
@ -40,164 +40,126 @@ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
|||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
"""
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import operator
|
||||
import os
|
||||
import io
|
||||
import pickle
|
||||
import struct
|
||||
import sys
|
||||
import types
|
||||
from functools import partial
|
||||
import itertools
|
||||
from copy_reg import _extension_registry, _inverted_registry, _extension_cache
|
||||
import new
|
||||
import dis
|
||||
import traceback
|
||||
import platform
|
||||
|
||||
PyImp = platform.python_implementation()
|
||||
|
||||
|
||||
import logging
|
||||
cloudLog = logging.getLogger("Cloud.Transport")
|
||||
if sys.version < '3':
|
||||
from pickle import Pickler
|
||||
try:
|
||||
from cStringIO import StringIO
|
||||
except ImportError:
|
||||
from StringIO import StringIO
|
||||
PY3 = False
|
||||
else:
|
||||
types.ClassType = type
|
||||
from pickle import _Pickler as Pickler
|
||||
from io import BytesIO as StringIO
|
||||
PY3 = True
|
||||
|
||||
#relevant opcodes
|
||||
STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL'))
|
||||
DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL'))
|
||||
LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL'))
|
||||
STORE_GLOBAL = dis.opname.index('STORE_GLOBAL')
|
||||
DELETE_GLOBAL = dis.opname.index('DELETE_GLOBAL')
|
||||
LOAD_GLOBAL = dis.opname.index('LOAD_GLOBAL')
|
||||
GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL]
|
||||
HAVE_ARGUMENT = dis.HAVE_ARGUMENT
|
||||
EXTENDED_ARG = dis.EXTENDED_ARG
|
||||
|
||||
HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT)
|
||||
EXTENDED_ARG = chr(dis.EXTENDED_ARG)
|
||||
|
||||
if PyImp == "PyPy":
|
||||
# register builtin type in `new`
|
||||
new.method = types.MethodType
|
||||
|
||||
try:
|
||||
from cStringIO import StringIO
|
||||
except ImportError:
|
||||
from StringIO import StringIO
|
||||
|
||||
# These helper functions were copied from PiCloud's util module.
|
||||
def islambda(func):
|
||||
return getattr(func,'func_name') == '<lambda>'
|
||||
|
||||
def xrange_params(xrangeobj):
|
||||
"""Returns a 3 element tuple describing the xrange start, step, and len
|
||||
respectively
|
||||
|
||||
Note: Only guarentees that elements of xrange are the same. parameters may
|
||||
be different.
|
||||
e.g. xrange(1,1) is interpretted as xrange(0,0); both behave the same
|
||||
though w/ iteration
|
||||
"""
|
||||
|
||||
xrange_len = len(xrangeobj)
|
||||
if not xrange_len: #empty
|
||||
return (0,1,0)
|
||||
start = xrangeobj[0]
|
||||
if xrange_len == 1: #one element
|
||||
return start, 1, 1
|
||||
return (start, xrangeobj[1] - xrangeobj[0], xrange_len)
|
||||
|
||||
#debug variables intended for developer use:
|
||||
printSerialization = False
|
||||
printMemoization = False
|
||||
|
||||
useForcedImports = True #Should I use forced imports for tracking?
|
||||
return getattr(func,'__name__') == '<lambda>'
|
||||
|
||||
|
||||
_BUILTIN_TYPE_NAMES = {}
|
||||
for k, v in types.__dict__.items():
|
||||
if type(v) is type:
|
||||
_BUILTIN_TYPE_NAMES[v] = k
|
||||
|
||||
class CloudPickler(pickle.Pickler):
|
||||
|
||||
dispatch = pickle.Pickler.dispatch.copy()
|
||||
savedForceImports = False
|
||||
savedDjangoEnv = False #hack tro transport django environment
|
||||
def _builtin_type(name):
|
||||
return getattr(types, name)
|
||||
|
||||
def __init__(self, file, protocol=None, min_size_to_save= 0):
|
||||
pickle.Pickler.__init__(self,file,protocol)
|
||||
self.modules = set() #set of modules needed to depickle
|
||||
self.globals_ref = {} # map ids to dictionary. used to ensure that functions can share global env
|
||||
|
||||
class CloudPickler(Pickler):
|
||||
|
||||
dispatch = Pickler.dispatch.copy()
|
||||
|
||||
def __init__(self, file, protocol=None):
|
||||
Pickler.__init__(self, file, protocol)
|
||||
# set of modules to unpickle
|
||||
self.modules = set()
|
||||
# map ids to dictionary. used to ensure that functions can share global env
|
||||
self.globals_ref = {}
|
||||
|
||||
def dump(self, obj):
|
||||
# note: not thread safe
|
||||
# minimal side-effects, so not fixing
|
||||
recurse_limit = 3000
|
||||
base_recurse = sys.getrecursionlimit()
|
||||
if base_recurse < recurse_limit:
|
||||
sys.setrecursionlimit(recurse_limit)
|
||||
self.inject_addons()
|
||||
try:
|
||||
return pickle.Pickler.dump(self, obj)
|
||||
except RuntimeError, e:
|
||||
return Pickler.dump(self, obj)
|
||||
except RuntimeError as e:
|
||||
if 'recursion' in e.args[0]:
|
||||
msg = """Could not pickle object as excessively deep recursion required.
|
||||
Try _fast_serialization=2 or contact PiCloud support"""
|
||||
msg = """Could not pickle object as excessively deep recursion required."""
|
||||
raise pickle.PicklingError(msg)
|
||||
finally:
|
||||
new_recurse = sys.getrecursionlimit()
|
||||
if new_recurse == recurse_limit:
|
||||
sys.setrecursionlimit(base_recurse)
|
||||
|
||||
def save_memoryview(self, obj):
|
||||
"""Fallback to save_string"""
|
||||
Pickler.save_string(self, str(obj))
|
||||
|
||||
def save_buffer(self, obj):
|
||||
"""Fallback to save_string"""
|
||||
pickle.Pickler.save_string(self,str(obj))
|
||||
dispatch[buffer] = save_buffer
|
||||
Pickler.save_string(self,str(obj))
|
||||
if PY3:
|
||||
dispatch[memoryview] = save_memoryview
|
||||
else:
|
||||
dispatch[buffer] = save_buffer
|
||||
|
||||
#block broken objects
|
||||
def save_unsupported(self, obj, pack=None):
|
||||
def save_unsupported(self, obj):
|
||||
raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj))
|
||||
dispatch[types.GeneratorType] = save_unsupported
|
||||
|
||||
#python2.6+ supports slice pickling. some py2.5 extensions might as well. We just test it
|
||||
try:
|
||||
slice(0,1).__reduce__()
|
||||
except TypeError: #can't pickle -
|
||||
dispatch[slice] = save_unsupported
|
||||
|
||||
#itertools objects do not pickle!
|
||||
# itertools objects do not pickle!
|
||||
for v in itertools.__dict__.values():
|
||||
if type(v) is type:
|
||||
dispatch[v] = save_unsupported
|
||||
|
||||
|
||||
def save_dict(self, obj):
|
||||
"""hack fix
|
||||
If the dict is a global, deal with it in a special way
|
||||
"""
|
||||
#print 'saving', obj
|
||||
if obj is __builtins__:
|
||||
self.save_reduce(_get_module_builtins, (), obj=obj)
|
||||
else:
|
||||
pickle.Pickler.save_dict(self, obj)
|
||||
dispatch[pickle.DictionaryType] = save_dict
|
||||
|
||||
|
||||
def save_module(self, obj, pack=struct.pack):
|
||||
def save_module(self, obj):
|
||||
"""
|
||||
Save a module as an import
|
||||
"""
|
||||
#print 'try save import', obj.__name__
|
||||
self.modules.add(obj)
|
||||
self.save_reduce(subimport,(obj.__name__,), obj=obj)
|
||||
dispatch[types.ModuleType] = save_module #new type
|
||||
self.save_reduce(subimport, (obj.__name__,), obj=obj)
|
||||
dispatch[types.ModuleType] = save_module
|
||||
|
||||
def save_codeobject(self, obj, pack=struct.pack):
|
||||
def save_codeobject(self, obj):
|
||||
"""
|
||||
Save a code object
|
||||
"""
|
||||
#print 'try to save codeobj: ', obj
|
||||
args = (
|
||||
obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code,
|
||||
obj.co_consts, obj.co_names, obj.co_varnames, obj.co_filename, obj.co_name,
|
||||
obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars
|
||||
)
|
||||
if PY3:
|
||||
args = (
|
||||
obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize,
|
||||
obj.co_flags, obj.co_code, obj.co_consts, obj.co_names, obj.co_varnames,
|
||||
obj.co_filename, obj.co_name, obj.co_firstlineno, obj.co_lnotab, obj.co_freevars,
|
||||
obj.co_cellvars
|
||||
)
|
||||
else:
|
||||
args = (
|
||||
obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code,
|
||||
obj.co_consts, obj.co_names, obj.co_varnames, obj.co_filename, obj.co_name,
|
||||
obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars
|
||||
)
|
||||
self.save_reduce(types.CodeType, args, obj=obj)
|
||||
dispatch[types.CodeType] = save_codeobject #new type
|
||||
dispatch[types.CodeType] = save_codeobject
|
||||
|
||||
def save_function(self, obj, name=None, pack=struct.pack):
|
||||
def save_function(self, obj, name=None):
|
||||
""" Registered with the dispatch to handle all function types.
|
||||
|
||||
Determines what kind of function obj is (e.g. lambda, defined at
|
||||
|
@ -205,12 +167,14 @@ class CloudPickler(pickle.Pickler):
|
|||
"""
|
||||
write = self.write
|
||||
|
||||
name = obj.__name__
|
||||
if name is None:
|
||||
name = obj.__name__
|
||||
modname = pickle.whichmodule(obj, name)
|
||||
#print 'which gives %s %s %s' % (modname, obj, name)
|
||||
# print('which gives %s %s %s' % (modname, obj, name))
|
||||
try:
|
||||
themodule = sys.modules[modname]
|
||||
except KeyError: # eval'd items such as namedtuple give invalid items for their function __module__
|
||||
except KeyError:
|
||||
# eval'd items such as namedtuple give invalid items for their function __module__
|
||||
modname = '__main__'
|
||||
|
||||
if modname == '__main__':
|
||||
|
@ -221,37 +185,18 @@ class CloudPickler(pickle.Pickler):
|
|||
if getattr(themodule, name, None) is obj:
|
||||
return self.save_global(obj, name)
|
||||
|
||||
if not self.savedDjangoEnv:
|
||||
#hack for django - if we detect the settings module, we transport it
|
||||
django_settings = os.environ.get('DJANGO_SETTINGS_MODULE', '')
|
||||
if django_settings:
|
||||
django_mod = sys.modules.get(django_settings)
|
||||
if django_mod:
|
||||
cloudLog.debug('Transporting django settings %s during save of %s', django_mod, name)
|
||||
self.savedDjangoEnv = True
|
||||
self.modules.add(django_mod)
|
||||
write(pickle.MARK)
|
||||
self.save_reduce(django_settings_load, (django_mod.__name__,), obj=django_mod)
|
||||
write(pickle.POP_MARK)
|
||||
|
||||
|
||||
# if func is lambda, def'ed at prompt, is in main, or is nested, then
|
||||
# we'll pickle the actual function object rather than simply saving a
|
||||
# reference (as is done in default pickler), via save_function_tuple.
|
||||
if islambda(obj) or obj.func_code.co_filename == '<stdin>' or themodule is None:
|
||||
#Force server to import modules that have been imported in main
|
||||
modList = None
|
||||
if themodule is None and not self.savedForceImports:
|
||||
mainmod = sys.modules['__main__']
|
||||
if useForcedImports and hasattr(mainmod,'___pyc_forcedImports__'):
|
||||
modList = list(mainmod.___pyc_forcedImports__)
|
||||
self.savedForceImports = True
|
||||
self.save_function_tuple(obj, modList)
|
||||
if islambda(obj) or obj.__code__.co_filename == '<stdin>' or themodule is None:
|
||||
#print("save global", islambda(obj), obj.__code__.co_filename, modname, themodule)
|
||||
self.save_function_tuple(obj)
|
||||
return
|
||||
else: # func is nested
|
||||
else:
|
||||
# func is nested
|
||||
klass = getattr(themodule, name, None)
|
||||
if klass is None or klass is not obj:
|
||||
self.save_function_tuple(obj, [themodule])
|
||||
self.save_function_tuple(obj)
|
||||
return
|
||||
|
||||
if obj.__dict__:
|
||||
|
@ -266,7 +211,7 @@ class CloudPickler(pickle.Pickler):
|
|||
self.memoize(obj)
|
||||
dispatch[types.FunctionType] = save_function
|
||||
|
||||
def save_function_tuple(self, func, forced_imports):
|
||||
def save_function_tuple(self, func):
|
||||
""" Pickles an actual func object.
|
||||
|
||||
A func comprises: code, globals, defaults, closure, and dict. We
|
||||
|
@ -281,19 +226,6 @@ class CloudPickler(pickle.Pickler):
|
|||
save = self.save
|
||||
write = self.write
|
||||
|
||||
# save the modules (if any)
|
||||
if forced_imports:
|
||||
write(pickle.MARK)
|
||||
save(_modules_to_main)
|
||||
#print 'forced imports are', forced_imports
|
||||
|
||||
forced_names = map(lambda m: m.__name__, forced_imports)
|
||||
save((forced_names,))
|
||||
|
||||
#save((forced_imports,))
|
||||
write(pickle.REDUCE)
|
||||
write(pickle.POP_MARK)
|
||||
|
||||
code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func)
|
||||
|
||||
save(_fill_function) # skeleton function updater
|
||||
|
@ -318,6 +250,8 @@ class CloudPickler(pickle.Pickler):
|
|||
Find all globals names read or written to by codeblock co
|
||||
"""
|
||||
code = co.co_code
|
||||
if not PY3:
|
||||
code = [ord(c) for c in code]
|
||||
names = co.co_names
|
||||
out_names = set()
|
||||
|
||||
|
@ -327,18 +261,18 @@ class CloudPickler(pickle.Pickler):
|
|||
while i < n:
|
||||
op = code[i]
|
||||
|
||||
i = i+1
|
||||
i += 1
|
||||
if op >= HAVE_ARGUMENT:
|
||||
oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg
|
||||
oparg = code[i] + code[i+1] * 256 + extended_arg
|
||||
extended_arg = 0
|
||||
i = i+2
|
||||
i += 2
|
||||
if op == EXTENDED_ARG:
|
||||
extended_arg = oparg*65536L
|
||||
extended_arg = oparg*65536
|
||||
if op in GLOBAL_OPS:
|
||||
out_names.add(names[oparg])
|
||||
#print 'extracted', out_names, ' from ', names
|
||||
|
||||
if co.co_consts: # see if nested function have any global refs
|
||||
# see if nested function have any global refs
|
||||
if co.co_consts:
|
||||
for const in co.co_consts:
|
||||
if type(const) is types.CodeType:
|
||||
out_names |= CloudPickler.extract_code_globals(const)
|
||||
|
@ -350,46 +284,28 @@ class CloudPickler(pickle.Pickler):
|
|||
Turn the function into a tuple of data necessary to recreate it:
|
||||
code, globals, defaults, closure, dict
|
||||
"""
|
||||
code = func.func_code
|
||||
code = func.__code__
|
||||
|
||||
# extract all global ref's
|
||||
func_global_refs = CloudPickler.extract_code_globals(code)
|
||||
func_global_refs = self.extract_code_globals(code)
|
||||
|
||||
# process all variables referenced by global environment
|
||||
f_globals = {}
|
||||
for var in func_global_refs:
|
||||
#Some names, such as class functions are not global - we don't need them
|
||||
if func.func_globals.has_key(var):
|
||||
f_globals[var] = func.func_globals[var]
|
||||
if var in func.__globals__:
|
||||
f_globals[var] = func.__globals__[var]
|
||||
|
||||
# defaults requires no processing
|
||||
defaults = func.func_defaults
|
||||
|
||||
def get_contents(cell):
|
||||
try:
|
||||
return cell.cell_contents
|
||||
except ValueError, e: #cell is empty error on not yet assigned
|
||||
raise pickle.PicklingError('Function to be pickled has free variables that are referenced before assignment in enclosing scope')
|
||||
|
||||
defaults = func.__defaults__
|
||||
|
||||
# process closure
|
||||
if func.func_closure:
|
||||
closure = map(get_contents, func.func_closure)
|
||||
else:
|
||||
closure = []
|
||||
closure = [c.cell_contents for c in func.__closure__] if func.__closure__ else []
|
||||
|
||||
# save the dict
|
||||
dct = func.func_dict
|
||||
dct = func.__dict__
|
||||
|
||||
if printSerialization:
|
||||
outvars = ['code: ' + str(code) ]
|
||||
outvars.append('globals: ' + str(f_globals))
|
||||
outvars.append('defaults: ' + str(defaults))
|
||||
outvars.append('closure: ' + str(closure))
|
||||
print 'function ', func, 'is extracted to: ', ', '.join(outvars)
|
||||
|
||||
base_globals = self.globals_ref.get(id(func.func_globals), {})
|
||||
self.globals_ref[id(func.func_globals)] = base_globals
|
||||
base_globals = self.globals_ref.get(id(func.__globals__), {})
|
||||
self.globals_ref[id(func.__globals__)] = base_globals
|
||||
|
||||
return (code, f_globals, defaults, closure, dct, base_globals)
|
||||
|
||||
|
@ -400,8 +316,9 @@ class CloudPickler(pickle.Pickler):
|
|||
dispatch[types.BuiltinFunctionType] = save_builtin_function
|
||||
|
||||
def save_global(self, obj, name=None, pack=struct.pack):
|
||||
write = self.write
|
||||
memo = self.memo
|
||||
if obj.__module__ == "__builtin__" or obj.__module__ == "builtins":
|
||||
if obj in _BUILTIN_TYPE_NAMES:
|
||||
return self.save_reduce(_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
|
||||
|
||||
if name is None:
|
||||
name = obj.__name__
|
||||
|
@ -410,98 +327,57 @@ class CloudPickler(pickle.Pickler):
|
|||
if modname is None:
|
||||
modname = pickle.whichmodule(obj, name)
|
||||
|
||||
try:
|
||||
__import__(modname)
|
||||
themodule = sys.modules[modname]
|
||||
except (ImportError, KeyError, AttributeError): #should never occur
|
||||
raise pickle.PicklingError(
|
||||
"Can't pickle %r: Module %s cannot be found" %
|
||||
(obj, modname))
|
||||
|
||||
if modname == '__main__':
|
||||
themodule = None
|
||||
|
||||
if themodule:
|
||||
else:
|
||||
__import__(modname)
|
||||
themodule = sys.modules[modname]
|
||||
self.modules.add(themodule)
|
||||
|
||||
sendRef = True
|
||||
typ = type(obj)
|
||||
#print 'saving', obj, typ
|
||||
try:
|
||||
try: #Deal with case when getattribute fails with exceptions
|
||||
klass = getattr(themodule, name)
|
||||
except (AttributeError):
|
||||
if modname == '__builtin__': #new.* are misrepeported
|
||||
modname = 'new'
|
||||
__import__(modname)
|
||||
themodule = sys.modules[modname]
|
||||
try:
|
||||
klass = getattr(themodule, name)
|
||||
except AttributeError, a:
|
||||
# print themodule, name, obj, type(obj)
|
||||
raise pickle.PicklingError("Can't pickle builtin %s" % obj)
|
||||
else:
|
||||
raise
|
||||
if hasattr(themodule, name) and getattr(themodule, name) is obj:
|
||||
return Pickler.save_global(self, obj, name)
|
||||
|
||||
except (ImportError, KeyError, AttributeError):
|
||||
if typ == types.TypeType or typ == types.ClassType:
|
||||
sendRef = False
|
||||
else: #we can't deal with this
|
||||
raise
|
||||
else:
|
||||
if klass is not obj and (typ == types.TypeType or typ == types.ClassType):
|
||||
sendRef = False
|
||||
if not sendRef:
|
||||
#note: Third party types might crash this - add better checks!
|
||||
d = dict(obj.__dict__) #copy dict proxy to a dict
|
||||
if not isinstance(d.get('__dict__', None), property): # don't extract dict that are properties
|
||||
d.pop('__dict__',None)
|
||||
d.pop('__weakref__',None)
|
||||
typ = type(obj)
|
||||
if typ is not obj and isinstance(obj, (type, types.ClassType)):
|
||||
d = dict(obj.__dict__) # copy dict proxy to a dict
|
||||
if not isinstance(d.get('__dict__', None), property):
|
||||
# don't extract dict that are properties
|
||||
d.pop('__dict__', None)
|
||||
d.pop('__weakref__', None)
|
||||
|
||||
# hack as __new__ is stored differently in the __dict__
|
||||
new_override = d.get('__new__', None)
|
||||
if new_override:
|
||||
d['__new__'] = obj.__new__
|
||||
|
||||
self.save_reduce(type(obj),(obj.__name__,obj.__bases__,
|
||||
d),obj=obj)
|
||||
#print 'internal reduce dask %s %s' % (obj, d)
|
||||
return
|
||||
self.save_reduce(typ, (obj.__name__, obj.__bases__, d), obj=obj)
|
||||
else:
|
||||
raise pickle.PicklingError("Can't pickle %r" % obj)
|
||||
|
||||
if self.proto >= 2:
|
||||
code = _extension_registry.get((modname, name))
|
||||
if code:
|
||||
assert code > 0
|
||||
if code <= 0xff:
|
||||
write(pickle.EXT1 + chr(code))
|
||||
elif code <= 0xffff:
|
||||
write("%c%c%c" % (pickle.EXT2, code&0xff, code>>8))
|
||||
else:
|
||||
write(pickle.EXT4 + pack("<i", code))
|
||||
return
|
||||
|
||||
write(pickle.GLOBAL + modname + '\n' + name + '\n')
|
||||
self.memoize(obj)
|
||||
dispatch[type] = save_global
|
||||
dispatch[types.ClassType] = save_global
|
||||
dispatch[types.TypeType] = save_global
|
||||
|
||||
def save_instancemethod(self, obj):
|
||||
#Memoization rarely is ever useful due to python bounding
|
||||
self.save_reduce(types.MethodType, (obj.im_func, obj.im_self,obj.im_class), obj=obj)
|
||||
# Memoization rarely is ever useful due to python bounding
|
||||
if PY3:
|
||||
self.save_reduce(types.MethodType, (obj.__func__, obj.__self__), obj=obj)
|
||||
else:
|
||||
self.save_reduce(types.MethodType, (obj.__func__, obj.__self__, obj.__self__.__class__),
|
||||
obj=obj)
|
||||
dispatch[types.MethodType] = save_instancemethod
|
||||
|
||||
def save_inst_logic(self, obj):
|
||||
def save_inst(self, obj):
|
||||
"""Inner logic to save instance. Based off pickle.save_inst
|
||||
Supports __transient__"""
|
||||
cls = obj.__class__
|
||||
|
||||
memo = self.memo
|
||||
memo = self.memo
|
||||
write = self.write
|
||||
save = self.save
|
||||
save = self.save
|
||||
|
||||
if hasattr(obj, '__getinitargs__'):
|
||||
args = obj.__getinitargs__()
|
||||
len(args) # XXX Assert it's a sequence
|
||||
len(args) # XXX Assert it's a sequence
|
||||
pickle._keep_alive(args, memo)
|
||||
else:
|
||||
args = ()
|
||||
|
@ -537,15 +413,8 @@ class CloudPickler(pickle.Pickler):
|
|||
save(stuff)
|
||||
write(pickle.BUILD)
|
||||
|
||||
|
||||
def save_inst(self, obj):
|
||||
# Hack to detect PIL Image instances without importing Imaging
|
||||
# PIL can be loaded with multiple names, so we don't check sys.modules for it
|
||||
if hasattr(obj,'im') and hasattr(obj,'palette') and 'Image' in obj.__module__:
|
||||
self.save_image(obj)
|
||||
else:
|
||||
self.save_inst_logic(obj)
|
||||
dispatch[types.InstanceType] = save_inst
|
||||
if not PY3:
|
||||
dispatch[types.InstanceType] = save_inst
|
||||
|
||||
def save_property(self, obj):
|
||||
# properties not correctly saved in python
|
||||
|
@ -592,7 +461,7 @@ class CloudPickler(pickle.Pickler):
|
|||
"""Modified to support __transient__ on new objects
|
||||
Change only affects protocol level 2 (which is always used by PiCloud"""
|
||||
# Assert that args is a tuple or None
|
||||
if not isinstance(args, types.TupleType):
|
||||
if not isinstance(args, tuple):
|
||||
raise pickle.PicklingError("args from reduce() should be a tuple")
|
||||
|
||||
# Assert that func is callable
|
||||
|
@ -646,35 +515,23 @@ class CloudPickler(pickle.Pickler):
|
|||
self._batch_setitems(dictitems)
|
||||
|
||||
if state is not None:
|
||||
#print 'obj %s has state %s' % (obj, state)
|
||||
save(state)
|
||||
write(pickle.BUILD)
|
||||
|
||||
|
||||
def save_xrange(self, obj):
|
||||
"""Save an xrange object in python 2.5
|
||||
Python 2.6 supports this natively
|
||||
"""
|
||||
range_params = xrange_params(obj)
|
||||
self.save_reduce(_build_xrange,range_params)
|
||||
|
||||
#python2.6+ supports xrange pickling. some py2.5 extensions might as well. We just test it
|
||||
try:
|
||||
xrange(0).__reduce__()
|
||||
except TypeError: #can't pickle -- use PiCloud pickler
|
||||
dispatch[xrange] = save_xrange
|
||||
|
||||
def save_partial(self, obj):
|
||||
"""Partial objects do not serialize correctly in python2.x -- this fixes the bugs"""
|
||||
self.save_reduce(_genpartial, (obj.func, obj.args, obj.keywords))
|
||||
|
||||
if sys.version_info < (2,7): #2.7 supports partial pickling
|
||||
if sys.version_info < (2,7): # 2.7 supports partial pickling
|
||||
dispatch[partial] = save_partial
|
||||
|
||||
|
||||
def save_file(self, obj):
|
||||
"""Save a file"""
|
||||
import StringIO as pystringIO #we can't use cStringIO as it lacks the name attribute
|
||||
try:
|
||||
import StringIO as pystringIO #we can't use cStringIO as it lacks the name attribute
|
||||
except ImportError:
|
||||
import io as pystringIO
|
||||
|
||||
if not hasattr(obj, 'name') or not hasattr(obj, 'mode'):
|
||||
raise pickle.PicklingError("Cannot pickle files that do not map to an actual file")
|
||||
|
@ -720,10 +577,14 @@ class CloudPickler(pickle.Pickler):
|
|||
retval.seek(curloc)
|
||||
|
||||
retval.name = name
|
||||
self.save(retval) #save stringIO
|
||||
self.save(retval)
|
||||
self.memoize(obj)
|
||||
|
||||
dispatch[file] = save_file
|
||||
if PY3:
|
||||
dispatch[io.TextIOWrapper] = save_file
|
||||
else:
|
||||
dispatch[file] = save_file
|
||||
|
||||
"""Special functions for Add-on libraries"""
|
||||
|
||||
def inject_numpy(self):
|
||||
|
@ -732,76 +593,20 @@ class CloudPickler(pickle.Pickler):
|
|||
return
|
||||
self.dispatch[numpy.ufunc] = self.__class__.save_ufunc
|
||||
|
||||
numpy_tst_mods = ['numpy', 'scipy.special']
|
||||
def save_ufunc(self, obj):
|
||||
"""Hack function for saving numpy ufunc objects"""
|
||||
name = obj.__name__
|
||||
for tst_mod_name in self.numpy_tst_mods:
|
||||
numpy_tst_mods = ['numpy', 'scipy.special']
|
||||
for tst_mod_name in numpy_tst_mods:
|
||||
tst_mod = sys.modules.get(tst_mod_name, None)
|
||||
if tst_mod:
|
||||
if name in tst_mod.__dict__:
|
||||
self.save_reduce(_getobject, (tst_mod_name, name))
|
||||
return
|
||||
raise pickle.PicklingError('cannot save %s. Cannot resolve what module it is defined in' % str(obj))
|
||||
|
||||
def inject_timeseries(self):
|
||||
"""Handle bugs with pickling scikits timeseries"""
|
||||
tseries = sys.modules.get('scikits.timeseries.tseries')
|
||||
if not tseries or not hasattr(tseries, 'Timeseries'):
|
||||
return
|
||||
self.dispatch[tseries.Timeseries] = self.__class__.save_timeseries
|
||||
|
||||
def save_timeseries(self, obj):
|
||||
import scikits.timeseries.tseries as ts
|
||||
|
||||
func, reduce_args, state = obj.__reduce__()
|
||||
if func != ts._tsreconstruct:
|
||||
raise pickle.PicklingError('timeseries using unexpected reconstruction function %s' % str(func))
|
||||
state = (1,
|
||||
obj.shape,
|
||||
obj.dtype,
|
||||
obj.flags.fnc,
|
||||
obj._data.tostring(),
|
||||
ts.getmaskarray(obj).tostring(),
|
||||
obj._fill_value,
|
||||
obj._dates.shape,
|
||||
obj._dates.__array__().tostring(),
|
||||
obj._dates.dtype, #added -- preserve type
|
||||
obj.freq,
|
||||
obj._optinfo,
|
||||
)
|
||||
return self.save_reduce(_genTimeSeries, (reduce_args, state))
|
||||
|
||||
def inject_email(self):
|
||||
"""Block email LazyImporters from being saved"""
|
||||
email = sys.modules.get('email')
|
||||
if not email:
|
||||
return
|
||||
self.dispatch[email.LazyImporter] = self.__class__.save_unsupported
|
||||
if tst_mod and name in tst_mod.__dict__:
|
||||
return self.save_reduce(_getobject, (tst_mod_name, name))
|
||||
raise pickle.PicklingError('cannot save %s. Cannot resolve what module it is defined in'
|
||||
% str(obj))
|
||||
|
||||
def inject_addons(self):
|
||||
"""Plug in system. Register additional pickling functions if modules already loaded"""
|
||||
self.inject_numpy()
|
||||
self.inject_timeseries()
|
||||
self.inject_email()
|
||||
|
||||
"""Python Imaging Library"""
|
||||
def save_image(self, obj):
|
||||
if not obj.im and obj.fp and 'r' in obj.fp.mode and obj.fp.name \
|
||||
and not obj.fp.closed and (not hasattr(obj, 'isatty') or not obj.isatty()):
|
||||
#if image not loaded yet -- lazy load
|
||||
self.save_reduce(_lazyloadImage,(obj.fp,), obj=obj)
|
||||
else:
|
||||
#image is loaded - just transmit it over
|
||||
self.save_reduce(_generateImage, (obj.size, obj.mode, obj.tostring()), obj=obj)
|
||||
|
||||
"""
|
||||
def memoize(self, obj):
|
||||
pickle.Pickler.memoize(self, obj)
|
||||
if printMemoization:
|
||||
print 'memoizing ' + str(obj)
|
||||
"""
|
||||
|
||||
|
||||
|
||||
# Shorthands for legacy support
|
||||
|
@ -809,14 +614,13 @@ class CloudPickler(pickle.Pickler):
|
|||
def dump(obj, file, protocol=2):
|
||||
CloudPickler(file, protocol).dump(obj)
|
||||
|
||||
|
||||
def dumps(obj, protocol=2):
|
||||
file = StringIO()
|
||||
|
||||
cp = CloudPickler(file,protocol)
|
||||
cp.dump(obj)
|
||||
|
||||
#print 'cloud dumped', str(obj), str(cp.modules)
|
||||
|
||||
return file.getvalue()
|
||||
|
||||
|
||||
|
@ -825,25 +629,6 @@ def subimport(name):
|
|||
__import__(name)
|
||||
return sys.modules[name]
|
||||
|
||||
#hack to load django settings:
|
||||
def django_settings_load(name):
|
||||
modified_env = False
|
||||
|
||||
if 'DJANGO_SETTINGS_MODULE' not in os.environ:
|
||||
os.environ['DJANGO_SETTINGS_MODULE'] = name # must set name first due to circular deps
|
||||
modified_env = True
|
||||
try:
|
||||
module = subimport(name)
|
||||
except Exception, i:
|
||||
print >> sys.stderr, 'Cloud not import django settings %s:' % (name)
|
||||
print_exec(sys.stderr)
|
||||
if modified_env:
|
||||
del os.environ['DJANGO_SETTINGS_MODULE']
|
||||
else:
|
||||
#add project directory to sys,path:
|
||||
if hasattr(module,'__file__'):
|
||||
dirname = os.path.split(module.__file__)[0] + '/'
|
||||
sys.path.append(dirname)
|
||||
|
||||
# restores function attributes
|
||||
def _restore_attr(obj, attr):
|
||||
|
@ -851,13 +636,16 @@ def _restore_attr(obj, attr):
|
|||
setattr(obj, key, val)
|
||||
return obj
|
||||
|
||||
|
||||
def _get_module_builtins():
|
||||
return pickle.__builtins__
|
||||
|
||||
|
||||
def print_exec(stream):
|
||||
ei = sys.exc_info()
|
||||
traceback.print_exception(ei[0], ei[1], ei[2], None, stream)
|
||||
|
||||
|
||||
def _modules_to_main(modList):
|
||||
"""Force every module in modList to be placed into main"""
|
||||
if not modList:
|
||||
|
@ -868,22 +656,16 @@ def _modules_to_main(modList):
|
|||
if type(modname) is str:
|
||||
try:
|
||||
mod = __import__(modname)
|
||||
except Exception, i: #catch all...
|
||||
sys.stderr.write('warning: could not import %s\n. Your function may unexpectedly error due to this import failing; \
|
||||
A version mismatch is likely. Specific error was:\n' % modname)
|
||||
except Exception as e:
|
||||
sys.stderr.write('warning: could not import %s\n. '
|
||||
'Your function may unexpectedly error due to this import failing;'
|
||||
'A version mismatch is likely. Specific error was:\n' % modname)
|
||||
print_exec(sys.stderr)
|
||||
else:
|
||||
setattr(main,mod.__name__, mod)
|
||||
else:
|
||||
#REVERSE COMPATIBILITY FOR CLOUD CLIENT 1.5 (WITH EPD)
|
||||
#In old version actual module was sent
|
||||
setattr(main,modname.__name__, modname)
|
||||
setattr(main, mod.__name__, mod)
|
||||
|
||||
|
||||
#object generators:
|
||||
def _build_xrange(start, step, len):
|
||||
"""Built xrange explicitly"""
|
||||
return xrange(start, start + step*len, step)
|
||||
|
||||
def _genpartial(func, args, kwds):
|
||||
if not args:
|
||||
args = ()
|
||||
|
@ -891,22 +673,26 @@ def _genpartial(func, args, kwds):
|
|||
kwds = {}
|
||||
return partial(func, *args, **kwds)
|
||||
|
||||
|
||||
def _fill_function(func, globals, defaults, dict):
|
||||
""" Fills in the rest of function data into the skeleton function object
|
||||
that were created via _make_skel_func().
|
||||
"""
|
||||
func.func_globals.update(globals)
|
||||
func.func_defaults = defaults
|
||||
func.func_dict = dict
|
||||
func.__globals__.update(globals)
|
||||
func.__defaults__ = defaults
|
||||
func.__dict__ = dict
|
||||
|
||||
return func
|
||||
|
||||
|
||||
def _make_cell(value):
|
||||
return (lambda: value).func_closure[0]
|
||||
return (lambda: value).__closure__[0]
|
||||
|
||||
|
||||
def _reconstruct_closure(values):
|
||||
return tuple([_make_cell(v) for v in values])
|
||||
|
||||
|
||||
def _make_skel_func(code, closures, base_globals = None):
|
||||
""" Creates a skeleton function object that contains just the provided
|
||||
code and the correct number of cells in func_closure. All other
|
||||
|
@ -928,40 +714,3 @@ Note: These can never be renamed due to client compatibility issues"""
|
|||
def _getobject(modname, attribute):
|
||||
mod = __import__(modname, fromlist=[attribute])
|
||||
return mod.__dict__[attribute]
|
||||
|
||||
def _generateImage(size, mode, str_rep):
|
||||
"""Generate image from string representation"""
|
||||
import Image
|
||||
i = Image.new(mode, size)
|
||||
i.fromstring(str_rep)
|
||||
return i
|
||||
|
||||
def _lazyloadImage(fp):
|
||||
import Image
|
||||
fp.seek(0) #works in almost any case
|
||||
return Image.open(fp)
|
||||
|
||||
"""Timeseries"""
|
||||
def _genTimeSeries(reduce_args, state):
|
||||
import scikits.timeseries.tseries as ts
|
||||
from numpy import ndarray
|
||||
from numpy.ma import MaskedArray
|
||||
|
||||
|
||||
time_series = ts._tsreconstruct(*reduce_args)
|
||||
|
||||
#from setstate modified
|
||||
(ver, shp, typ, isf, raw, msk, flv, dsh, dtm, dtyp, frq, infodict) = state
|
||||
#print 'regenerating %s' % dtyp
|
||||
|
||||
MaskedArray.__setstate__(time_series, (ver, shp, typ, isf, raw, msk, flv))
|
||||
_dates = time_series._dates
|
||||
#_dates.__setstate__((ver, dsh, typ, isf, dtm, frq)) #use remote typ
|
||||
ndarray.__setstate__(_dates,(dsh,dtyp, isf, dtm))
|
||||
_dates.freq = frq
|
||||
_dates._cachedinfo.update(dict(full=None, hasdups=None, steps=None,
|
||||
toobj=None, toord=None, tostr=None))
|
||||
# Update the _optinfo dictionary
|
||||
time_series._optinfo.update(infodict)
|
||||
return time_series
|
||||
|
||||
|
|
|
@ -44,7 +44,7 @@ u'/path'
|
|||
<pyspark.conf.SparkConf object at ...>
|
||||
>>> conf.get("spark.executorEnv.VAR1")
|
||||
u'value1'
|
||||
>>> print conf.toDebugString()
|
||||
>>> print(conf.toDebugString())
|
||||
spark.executorEnv.VAR1=value1
|
||||
spark.executorEnv.VAR3=value3
|
||||
spark.executorEnv.VAR4=value4
|
||||
|
@ -56,6 +56,13 @@ spark.home=/path
|
|||
|
||||
__all__ = ['SparkConf']
|
||||
|
||||
import sys
|
||||
import re
|
||||
|
||||
if sys.version > '3':
|
||||
unicode = str
|
||||
__doc__ = re.sub(r"(\W|^)[uU](['])", r'\1\2', __doc__)
|
||||
|
||||
|
||||
class SparkConf(object):
|
||||
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
|
@ -32,11 +34,14 @@ from pyspark.java_gateway import launch_gateway
|
|||
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
|
||||
PairDeserializer, AutoBatchedSerializer, NoOpSerializer
|
||||
from pyspark.storagelevel import StorageLevel
|
||||
from pyspark.rdd import RDD, _load_from_socket
|
||||
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
|
||||
from pyspark.traceback_utils import CallSite, first_spark_call
|
||||
from pyspark.status import StatusTracker
|
||||
from pyspark.profiler import ProfilerCollector, BasicProfiler
|
||||
|
||||
if sys.version > '3':
|
||||
xrange = range
|
||||
|
||||
|
||||
__all__ = ['SparkContext']
|
||||
|
||||
|
@ -133,7 +138,7 @@ class SparkContext(object):
|
|||
if sparkHome:
|
||||
self._conf.setSparkHome(sparkHome)
|
||||
if environment:
|
||||
for key, value in environment.iteritems():
|
||||
for key, value in environment.items():
|
||||
self._conf.setExecutorEnv(key, value)
|
||||
for key, value in DEFAULT_CONFIGS.items():
|
||||
self._conf.setIfMissing(key, value)
|
||||
|
@ -153,6 +158,10 @@ class SparkContext(object):
|
|||
if k.startswith("spark.executorEnv."):
|
||||
varName = k[len("spark.executorEnv."):]
|
||||
self.environment[varName] = v
|
||||
if sys.version >= '3.3' and 'PYTHONHASHSEED' not in os.environ:
|
||||
# disable randomness of hash of string in worker, if this is not
|
||||
# launched by spark-submit
|
||||
self.environment["PYTHONHASHSEED"] = "0"
|
||||
|
||||
# Create the Java SparkContext through Py4J
|
||||
self._jsc = jsc or self._initialize_context(self._conf._jconf)
|
||||
|
@ -323,7 +332,7 @@ class SparkContext(object):
|
|||
start0 = c[0]
|
||||
|
||||
def getStart(split):
|
||||
return start0 + (split * size / numSlices) * step
|
||||
return start0 + int((split * size / numSlices)) * step
|
||||
|
||||
def f(split, iterator):
|
||||
return xrange(getStart(split), getStart(split + 1), step)
|
||||
|
@ -357,6 +366,7 @@ class SparkContext(object):
|
|||
minPartitions = minPartitions or self.defaultMinPartitions
|
||||
return RDD(self._jsc.objectFile(name, minPartitions), self)
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def textFile(self, name, minPartitions=None, use_unicode=True):
|
||||
"""
|
||||
Read a text file from HDFS, a local file system (available on all
|
||||
|
@ -369,7 +379,7 @@ class SparkContext(object):
|
|||
|
||||
>>> path = os.path.join(tempdir, "sample-text.txt")
|
||||
>>> with open(path, "w") as testFile:
|
||||
... testFile.write("Hello world!")
|
||||
... _ = testFile.write("Hello world!")
|
||||
>>> textFile = sc.textFile(path)
|
||||
>>> textFile.collect()
|
||||
[u'Hello world!']
|
||||
|
@ -378,6 +388,7 @@ class SparkContext(object):
|
|||
return RDD(self._jsc.textFile(name, minPartitions), self,
|
||||
UTF8Deserializer(use_unicode))
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def wholeTextFiles(self, path, minPartitions=None, use_unicode=True):
|
||||
"""
|
||||
Read a directory of text files from HDFS, a local file system
|
||||
|
@ -411,9 +422,9 @@ class SparkContext(object):
|
|||
>>> dirPath = os.path.join(tempdir, "files")
|
||||
>>> os.mkdir(dirPath)
|
||||
>>> with open(os.path.join(dirPath, "1.txt"), "w") as file1:
|
||||
... file1.write("1")
|
||||
... _ = file1.write("1")
|
||||
>>> with open(os.path.join(dirPath, "2.txt"), "w") as file2:
|
||||
... file2.write("2")
|
||||
... _ = file2.write("2")
|
||||
>>> textFiles = sc.wholeTextFiles(dirPath)
|
||||
>>> sorted(textFiles.collect())
|
||||
[(u'.../1.txt', u'1'), (u'.../2.txt', u'2')]
|
||||
|
@ -456,7 +467,7 @@ class SparkContext(object):
|
|||
jm = self._jvm.java.util.HashMap()
|
||||
if not d:
|
||||
d = {}
|
||||
for k, v in d.iteritems():
|
||||
for k, v in d.items():
|
||||
jm[k] = v
|
||||
return jm
|
||||
|
||||
|
@ -608,6 +619,7 @@ class SparkContext(object):
|
|||
jrdd = self._jsc.checkpointFile(name)
|
||||
return RDD(jrdd, self, input_deserializer)
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def union(self, rdds):
|
||||
"""
|
||||
Build the union of a list of RDDs.
|
||||
|
@ -618,7 +630,7 @@ class SparkContext(object):
|
|||
|
||||
>>> path = os.path.join(tempdir, "union-text.txt")
|
||||
>>> with open(path, "w") as testFile:
|
||||
... testFile.write("Hello")
|
||||
... _ = testFile.write("Hello")
|
||||
>>> textFile = sc.textFile(path)
|
||||
>>> textFile.collect()
|
||||
[u'Hello']
|
||||
|
@ -677,7 +689,7 @@ class SparkContext(object):
|
|||
>>> from pyspark import SparkFiles
|
||||
>>> path = os.path.join(tempdir, "test.txt")
|
||||
>>> with open(path, "w") as testFile:
|
||||
... testFile.write("100")
|
||||
... _ = testFile.write("100")
|
||||
>>> sc.addFile(path)
|
||||
>>> def func(iterator):
|
||||
... with open(SparkFiles.get("test.txt")) as testFile:
|
||||
|
@ -705,11 +717,13 @@ class SparkContext(object):
|
|||
"""
|
||||
self.addFile(path)
|
||||
(dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix
|
||||
|
||||
if filename[-4:].lower() in self.PACKAGE_EXTENSIONS:
|
||||
self._python_includes.append(filename)
|
||||
# for tests in local mode
|
||||
sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename))
|
||||
if sys.version > '3':
|
||||
import importlib
|
||||
importlib.invalidate_caches()
|
||||
|
||||
def setCheckpointDir(self, dirName):
|
||||
"""
|
||||
|
@ -744,7 +758,7 @@ class SparkContext(object):
|
|||
The application can use L{SparkContext.cancelJobGroup} to cancel all
|
||||
running jobs in this group.
|
||||
|
||||
>>> import thread, threading
|
||||
>>> import threading
|
||||
>>> from time import sleep
|
||||
>>> result = "Not Set"
|
||||
>>> lock = threading.Lock()
|
||||
|
@ -763,10 +777,10 @@ class SparkContext(object):
|
|||
... sleep(5)
|
||||
... sc.cancelJobGroup("job_to_cancel")
|
||||
>>> supress = lock.acquire()
|
||||
>>> supress = thread.start_new_thread(start_job, (10,))
|
||||
>>> supress = thread.start_new_thread(stop_job, tuple())
|
||||
>>> supress = threading.Thread(target=start_job, args=(10,)).start()
|
||||
>>> supress = threading.Thread(target=stop_job).start()
|
||||
>>> supress = lock.acquire()
|
||||
>>> print result
|
||||
>>> print(result)
|
||||
Cancelled
|
||||
|
||||
If interruptOnCancel is set to true for the job group, then job cancellation will result
|
||||
|
|
|
@ -24,9 +24,10 @@ import sys
|
|||
import traceback
|
||||
import time
|
||||
import gc
|
||||
from errno import EINTR, ECHILD, EAGAIN
|
||||
from errno import EINTR, EAGAIN
|
||||
from socket import AF_INET, SOCK_STREAM, SOMAXCONN
|
||||
from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT
|
||||
|
||||
from pyspark.worker import main as worker_main
|
||||
from pyspark.serializers import read_int, write_int
|
||||
|
||||
|
@ -53,8 +54,8 @@ def worker(sock):
|
|||
# Read the socket using fdopen instead of socket.makefile() because the latter
|
||||
# seems to be very slow; note that we need to dup() the file descriptor because
|
||||
# otherwise writes also cause a seek that makes us miss data on the read side.
|
||||
infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
|
||||
outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
|
||||
infile = os.fdopen(os.dup(sock.fileno()), "rb", 65536)
|
||||
outfile = os.fdopen(os.dup(sock.fileno()), "wb", 65536)
|
||||
exit_code = 0
|
||||
try:
|
||||
worker_main(infile, outfile)
|
||||
|
@ -68,17 +69,6 @@ def worker(sock):
|
|||
return exit_code
|
||||
|
||||
|
||||
# Cleanup zombie children
|
||||
def cleanup_dead_children():
|
||||
try:
|
||||
while True:
|
||||
pid, _ = os.waitpid(0, os.WNOHANG)
|
||||
if not pid:
|
||||
break
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def manager():
|
||||
# Create a new process group to corral our children
|
||||
os.setpgid(0, 0)
|
||||
|
@ -88,8 +78,12 @@ def manager():
|
|||
listen_sock.bind(('127.0.0.1', 0))
|
||||
listen_sock.listen(max(1024, SOMAXCONN))
|
||||
listen_host, listen_port = listen_sock.getsockname()
|
||||
write_int(listen_port, sys.stdout)
|
||||
sys.stdout.flush()
|
||||
|
||||
# re-open stdin/stdout in 'wb' mode
|
||||
stdin_bin = os.fdopen(sys.stdin.fileno(), 'rb', 4)
|
||||
stdout_bin = os.fdopen(sys.stdout.fileno(), 'wb', 4)
|
||||
write_int(listen_port, stdout_bin)
|
||||
stdout_bin.flush()
|
||||
|
||||
def shutdown(code):
|
||||
signal.signal(SIGTERM, SIG_DFL)
|
||||
|
@ -101,6 +95,7 @@ def manager():
|
|||
shutdown(1)
|
||||
signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM
|
||||
signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP
|
||||
signal.signal(SIGCHLD, SIG_IGN)
|
||||
|
||||
reuse = os.environ.get("SPARK_REUSE_WORKER")
|
||||
|
||||
|
@ -115,12 +110,9 @@ def manager():
|
|||
else:
|
||||
raise
|
||||
|
||||
# cleanup in signal handler will cause deadlock
|
||||
cleanup_dead_children()
|
||||
|
||||
if 0 in ready_fds:
|
||||
try:
|
||||
worker_pid = read_int(sys.stdin)
|
||||
worker_pid = read_int(stdin_bin)
|
||||
except EOFError:
|
||||
# Spark told us to exit by closing stdin
|
||||
shutdown(0)
|
||||
|
@ -145,7 +137,7 @@ def manager():
|
|||
time.sleep(1)
|
||||
pid = os.fork() # error here will shutdown daemon
|
||||
else:
|
||||
outfile = sock.makefile('w')
|
||||
outfile = sock.makefile(mode='wb')
|
||||
write_int(e.errno, outfile) # Signal that the fork failed
|
||||
outfile.flush()
|
||||
outfile.close()
|
||||
|
@ -157,7 +149,7 @@ def manager():
|
|||
listen_sock.close()
|
||||
try:
|
||||
# Acknowledge that the fork was successful
|
||||
outfile = sock.makefile("w")
|
||||
outfile = sock.makefile(mode="wb")
|
||||
write_int(os.getpid(), outfile)
|
||||
outfile.flush()
|
||||
outfile.close()
|
||||
|
|
|
@ -627,51 +627,49 @@ def merge(iterables, key=None, reverse=False):
|
|||
if key is None:
|
||||
for order, it in enumerate(map(iter, iterables)):
|
||||
try:
|
||||
next = it.next
|
||||
h_append([next(), order * direction, next])
|
||||
h_append([next(it), order * direction, it])
|
||||
except StopIteration:
|
||||
pass
|
||||
_heapify(h)
|
||||
while len(h) > 1:
|
||||
try:
|
||||
while True:
|
||||
value, order, next = s = h[0]
|
||||
value, order, it = s = h[0]
|
||||
yield value
|
||||
s[0] = next() # raises StopIteration when exhausted
|
||||
s[0] = next(it) # raises StopIteration when exhausted
|
||||
_heapreplace(h, s) # restore heap condition
|
||||
except StopIteration:
|
||||
_heappop(h) # remove empty iterator
|
||||
if h:
|
||||
# fast case when only a single iterator remains
|
||||
value, order, next = h[0]
|
||||
value, order, it = h[0]
|
||||
yield value
|
||||
for value in next.__self__:
|
||||
for value in it:
|
||||
yield value
|
||||
return
|
||||
|
||||
for order, it in enumerate(map(iter, iterables)):
|
||||
try:
|
||||
next = it.next
|
||||
value = next()
|
||||
h_append([key(value), order * direction, value, next])
|
||||
value = next(it)
|
||||
h_append([key(value), order * direction, value, it])
|
||||
except StopIteration:
|
||||
pass
|
||||
_heapify(h)
|
||||
while len(h) > 1:
|
||||
try:
|
||||
while True:
|
||||
key_value, order, value, next = s = h[0]
|
||||
key_value, order, value, it = s = h[0]
|
||||
yield value
|
||||
value = next()
|
||||
value = next(it)
|
||||
s[0] = key(value)
|
||||
s[2] = value
|
||||
_heapreplace(h, s)
|
||||
except StopIteration:
|
||||
_heappop(h)
|
||||
if h:
|
||||
key_value, order, value, next = h[0]
|
||||
key_value, order, value, it = h[0]
|
||||
yield value
|
||||
for value in next.__self__:
|
||||
for value in it:
|
||||
yield value
|
||||
|
||||
|
||||
|
|
|
@ -69,7 +69,7 @@ def launch_gateway():
|
|||
if callback_socket in readable:
|
||||
gateway_connection = callback_socket.accept()[0]
|
||||
# Determine which ephemeral port the server started on:
|
||||
gateway_port = read_int(gateway_connection.makefile())
|
||||
gateway_port = read_int(gateway_connection.makefile(mode="rb"))
|
||||
gateway_connection.close()
|
||||
callback_socket.close()
|
||||
if gateway_port is None:
|
||||
|
|
|
@ -32,6 +32,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|||
"""
|
||||
|
||||
from pyspark.resultiterable import ResultIterable
|
||||
from functools import reduce
|
||||
|
||||
|
||||
def _do_python_join(rdd, other, numPartitions, dispatch):
|
||||
|
|
|
@ -39,10 +39,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
|
|||
>>> lr = LogisticRegression(maxIter=5, regParam=0.01)
|
||||
>>> model = lr.fit(df)
|
||||
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF()
|
||||
>>> print model.transform(test0).head().prediction
|
||||
>>> model.transform(test0).head().prediction
|
||||
0.0
|
||||
>>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF()
|
||||
>>> print model.transform(test1).head().prediction
|
||||
>>> model.transform(test1).head().prediction
|
||||
1.0
|
||||
>>> lr.setParams("vector")
|
||||
Traceback (most recent call last):
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
from pyspark.rdd import ignore_unicode_prefix
|
||||
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures
|
||||
from pyspark.ml.util import keyword_only
|
||||
from pyspark.ml.wrapper import JavaTransformer
|
||||
|
@ -24,6 +25,7 @@ __all__ = ['Tokenizer', 'HashingTF']
|
|||
|
||||
|
||||
@inherit_doc
|
||||
@ignore_unicode_prefix
|
||||
class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
|
||||
"""
|
||||
A tokenizer that converts the input string to lowercase and then
|
||||
|
@ -32,15 +34,15 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
|
|||
>>> from pyspark.sql import Row
|
||||
>>> df = sc.parallelize([Row(text="a b c")]).toDF()
|
||||
>>> tokenizer = Tokenizer(inputCol="text", outputCol="words")
|
||||
>>> print tokenizer.transform(df).head()
|
||||
>>> tokenizer.transform(df).head()
|
||||
Row(text=u'a b c', words=[u'a', u'b', u'c'])
|
||||
>>> # Change a parameter.
|
||||
>>> print tokenizer.setParams(outputCol="tokens").transform(df).head()
|
||||
>>> tokenizer.setParams(outputCol="tokens").transform(df).head()
|
||||
Row(text=u'a b c', tokens=[u'a', u'b', u'c'])
|
||||
>>> # Temporarily modify a parameter.
|
||||
>>> print tokenizer.transform(df, {tokenizer.outputCol: "words"}).head()
|
||||
>>> tokenizer.transform(df, {tokenizer.outputCol: "words"}).head()
|
||||
Row(text=u'a b c', words=[u'a', u'b', u'c'])
|
||||
>>> print tokenizer.transform(df).head()
|
||||
>>> tokenizer.transform(df).head()
|
||||
Row(text=u'a b c', tokens=[u'a', u'b', u'c'])
|
||||
>>> # Must use keyword arguments to specify params.
|
||||
>>> tokenizer.setParams("text")
|
||||
|
@ -79,13 +81,13 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
|
|||
>>> from pyspark.sql import Row
|
||||
>>> df = sc.parallelize([Row(words=["a", "b", "c"])]).toDF()
|
||||
>>> hashingTF = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
|
||||
>>> print hashingTF.transform(df).head().features
|
||||
(10,[7,8,9],[1.0,1.0,1.0])
|
||||
>>> print hashingTF.setParams(outputCol="freqs").transform(df).head().freqs
|
||||
(10,[7,8,9],[1.0,1.0,1.0])
|
||||
>>> hashingTF.transform(df).head().features
|
||||
SparseVector(10, {7: 1.0, 8: 1.0, 9: 1.0})
|
||||
>>> hashingTF.setParams(outputCol="freqs").transform(df).head().freqs
|
||||
SparseVector(10, {7: 1.0, 8: 1.0, 9: 1.0})
|
||||
>>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"}
|
||||
>>> print hashingTF.transform(df, params).head().vector
|
||||
(5,[2,3,4],[1.0,1.0,1.0])
|
||||
>>> hashingTF.transform(df, params).head().vector
|
||||
SparseVector(5, {2: 1.0, 3: 1.0, 4: 1.0})
|
||||
"""
|
||||
|
||||
_java_class = "org.apache.spark.ml.feature.HashingTF"
|
||||
|
|
|
@ -63,8 +63,8 @@ class Params(Identifiable):
|
|||
uses :py:func:`dir` to get all attributes of type
|
||||
:py:class:`Param`.
|
||||
"""
|
||||
return filter(lambda attr: isinstance(attr, Param),
|
||||
[getattr(self, x) for x in dir(self) if x != "params"])
|
||||
return list(filter(lambda attr: isinstance(attr, Param),
|
||||
[getattr(self, x) for x in dir(self) if x != "params"]))
|
||||
|
||||
def _explain(self, param):
|
||||
"""
|
||||
|
@ -185,7 +185,7 @@ class Params(Identifiable):
|
|||
"""
|
||||
Sets user-supplied params.
|
||||
"""
|
||||
for param, value in kwargs.iteritems():
|
||||
for param, value in kwargs.items():
|
||||
self.paramMap[getattr(self, param)] = value
|
||||
return self
|
||||
|
||||
|
@ -193,6 +193,6 @@ class Params(Identifiable):
|
|||
"""
|
||||
Sets default params.
|
||||
"""
|
||||
for param, value in kwargs.iteritems():
|
||||
for param, value in kwargs.items():
|
||||
self.defaultParamMap[getattr(self, param)] = value
|
||||
return self
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
header = """#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
|
@ -82,9 +84,9 @@ def _gen_param_code(name, doc, defaultValueStr):
|
|||
.replace("$defaultValueStr", str(defaultValueStr))
|
||||
|
||||
if __name__ == "__main__":
|
||||
print header
|
||||
print "\n# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.\n"
|
||||
print "from pyspark.ml.param import Param, Params\n\n"
|
||||
print(header)
|
||||
print("\n# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.\n")
|
||||
print("from pyspark.ml.param import Param, Params\n\n")
|
||||
shared = [
|
||||
("maxIter", "max number of iterations", None),
|
||||
("regParam", "regularization constant", None),
|
||||
|
@ -97,4 +99,4 @@ if __name__ == "__main__":
|
|||
code = []
|
||||
for name, doc, defaultValueStr in shared:
|
||||
code.append(_gen_param_code(name, doc, defaultValueStr))
|
||||
print "\n\n\n".join(code)
|
||||
print("\n\n\n".join(code))
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
"""
|
||||
Python bindings for MLlib.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
|
||||
# MLlib currently needs NumPy 1.4+, so complain if lower
|
||||
|
||||
|
@ -29,7 +30,9 @@ __all__ = ['classification', 'clustering', 'feature', 'fpm', 'linalg', 'random',
|
|||
'recommendation', 'regression', 'stat', 'tree', 'util']
|
||||
|
||||
import sys
|
||||
import rand as random
|
||||
random.__name__ = 'random'
|
||||
random.RandomRDDs.__module__ = __name__ + '.random'
|
||||
sys.modules[__name__ + '.random'] = random
|
||||
from . import rand as random
|
||||
modname = __name__ + '.random'
|
||||
random.__name__ = modname
|
||||
random.RandomRDDs.__module__ = modname
|
||||
sys.modules[modname] = random
|
||||
del modname, sys
|
||||
|
|
|
@ -510,9 +510,10 @@ class NaiveBayesModel(Saveable, Loader):
|
|||
def load(cls, sc, path):
|
||||
java_model = sc._jvm.org.apache.spark.mllib.classification.NaiveBayesModel.load(
|
||||
sc._jsc.sc(), path)
|
||||
py_labels = _java2py(sc, java_model.labels())
|
||||
py_pi = _java2py(sc, java_model.pi())
|
||||
py_theta = _java2py(sc, java_model.theta())
|
||||
# Can not unpickle array.array from Pyrolite in Python3 with "bytes"
|
||||
py_labels = _java2py(sc, java_model.labels(), "latin1")
|
||||
py_pi = _java2py(sc, java_model.pi(), "latin1")
|
||||
py_theta = _java2py(sc, java_model.theta(), "latin1")
|
||||
return NaiveBayesModel(py_labels, py_pi, numpy.array(py_theta))
|
||||
|
||||
|
||||
|
|
|
@ -15,6 +15,12 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
import sys
|
||||
import array as pyarray
|
||||
|
||||
if sys.version > '3':
|
||||
xrange = range
|
||||
|
||||
from numpy import array
|
||||
|
||||
from pyspark import RDD
|
||||
|
@ -55,8 +61,8 @@ class KMeansModel(Saveable, Loader):
|
|||
True
|
||||
>>> model.predict(sparse_data[2]) == model.predict(sparse_data[3])
|
||||
True
|
||||
>>> type(model.clusterCenters)
|
||||
<type 'list'>
|
||||
>>> isinstance(model.clusterCenters, list)
|
||||
True
|
||||
>>> import os, tempfile
|
||||
>>> path = tempfile.mkdtemp()
|
||||
>>> model.save(sc, path)
|
||||
|
@ -90,7 +96,7 @@ class KMeansModel(Saveable, Loader):
|
|||
return best
|
||||
|
||||
def save(self, sc, path):
|
||||
java_centers = _py2java(sc, map(_convert_to_vector, self.centers))
|
||||
java_centers = _py2java(sc, [_convert_to_vector(c) for c in self.centers])
|
||||
java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel(java_centers)
|
||||
java_model.save(sc._jsc.sc(), path)
|
||||
|
||||
|
@ -133,7 +139,7 @@ class GaussianMixtureModel(object):
|
|||
... 5.7048, 4.6567, 5.5026,
|
||||
... 4.5605, 5.2043, 6.2734]).reshape(5, 3))
|
||||
>>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001,
|
||||
... maxIterations=150, seed=10)
|
||||
... maxIterations=150, seed=10)
|
||||
>>> labels = model.predict(clusterdata_2).collect()
|
||||
>>> labels[0]==labels[1]==labels[2]
|
||||
True
|
||||
|
@ -168,8 +174,8 @@ class GaussianMixtureModel(object):
|
|||
if isinstance(x, RDD):
|
||||
means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians])
|
||||
membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector),
|
||||
self.weights, means, sigmas)
|
||||
return membership_matrix
|
||||
_convert_to_vector(self.weights), means, sigmas)
|
||||
return membership_matrix.map(lambda x: pyarray.array('d', x))
|
||||
|
||||
|
||||
class GaussianMixture(object):
|
||||
|
|
|
@ -15,6 +15,11 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
import sys
|
||||
if sys.version >= '3':
|
||||
long = int
|
||||
unicode = str
|
||||
|
||||
import py4j.protocol
|
||||
from py4j.protocol import Py4JJavaError
|
||||
from py4j.java_gateway import JavaObject
|
||||
|
@ -36,7 +41,7 @@ _float_str_mapping = {
|
|||
|
||||
def _new_smart_decode(obj):
|
||||
if isinstance(obj, float):
|
||||
s = unicode(obj)
|
||||
s = str(obj)
|
||||
return _float_str_mapping.get(s, s)
|
||||
return _old_smart_decode(obj)
|
||||
|
||||
|
@ -74,15 +79,15 @@ def _py2java(sc, obj):
|
|||
obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client)
|
||||
elif isinstance(obj, JavaObject):
|
||||
pass
|
||||
elif isinstance(obj, (int, long, float, bool, basestring)):
|
||||
elif isinstance(obj, (int, long, float, bool, bytes, unicode)):
|
||||
pass
|
||||
else:
|
||||
bytes = bytearray(PickleSerializer().dumps(obj))
|
||||
obj = sc._jvm.SerDe.loads(bytes)
|
||||
data = bytearray(PickleSerializer().dumps(obj))
|
||||
obj = sc._jvm.SerDe.loads(data)
|
||||
return obj
|
||||
|
||||
|
||||
def _java2py(sc, r):
|
||||
def _java2py(sc, r, encoding="bytes"):
|
||||
if isinstance(r, JavaObject):
|
||||
clsName = r.getClass().getSimpleName()
|
||||
# convert RDD into JavaRDD
|
||||
|
@ -102,8 +107,8 @@ def _java2py(sc, r):
|
|||
except Py4JJavaError:
|
||||
pass # not pickable
|
||||
|
||||
if isinstance(r, bytearray):
|
||||
r = PickleSerializer().loads(str(r))
|
||||
if isinstance(r, (bytearray, bytes)):
|
||||
r = PickleSerializer().loads(bytes(r), encoding=encoding)
|
||||
return r
|
||||
|
||||
|
||||
|
|
|
@ -23,12 +23,17 @@ from __future__ import absolute_import
|
|||
import sys
|
||||
import warnings
|
||||
import random
|
||||
import binascii
|
||||
if sys.version >= '3':
|
||||
basestring = str
|
||||
unicode = str
|
||||
|
||||
from py4j.protocol import Py4JJavaError
|
||||
|
||||
from pyspark import RDD, SparkContext
|
||||
from pyspark import SparkContext
|
||||
from pyspark.rdd import RDD, ignore_unicode_prefix
|
||||
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
|
||||
from pyspark.mllib.linalg import Vectors, Vector, _convert_to_vector
|
||||
from pyspark.mllib.linalg import Vectors, _convert_to_vector
|
||||
|
||||
__all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
|
||||
'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel']
|
||||
|
@ -206,7 +211,7 @@ class HashingTF(object):
|
|||
>>> htf = HashingTF(100)
|
||||
>>> doc = "a a b b c d".split(" ")
|
||||
>>> htf.transform(doc)
|
||||
SparseVector(100, {1: 1.0, 14: 1.0, 31: 2.0, 44: 2.0})
|
||||
SparseVector(100, {...})
|
||||
"""
|
||||
def __init__(self, numFeatures=1 << 20):
|
||||
"""
|
||||
|
@ -360,6 +365,7 @@ class Word2VecModel(JavaVectorTransformer):
|
|||
return self.call("getVectors")
|
||||
|
||||
|
||||
@ignore_unicode_prefix
|
||||
class Word2Vec(object):
|
||||
"""
|
||||
Word2Vec creates vector representation of words in a text corpus.
|
||||
|
@ -382,7 +388,7 @@ class Word2Vec(object):
|
|||
>>> sentence = "a b " * 100 + "a c " * 10
|
||||
>>> localDoc = [sentence, sentence]
|
||||
>>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" "))
|
||||
>>> model = Word2Vec().setVectorSize(10).setSeed(42L).fit(doc)
|
||||
>>> model = Word2Vec().setVectorSize(10).setSeed(42).fit(doc)
|
||||
|
||||
>>> syms = model.findSynonyms("a", 2)
|
||||
>>> [s[0] for s in syms]
|
||||
|
@ -400,7 +406,7 @@ class Word2Vec(object):
|
|||
self.learningRate = 0.025
|
||||
self.numPartitions = 1
|
||||
self.numIterations = 1
|
||||
self.seed = random.randint(0, sys.maxint)
|
||||
self.seed = random.randint(0, sys.maxsize)
|
||||
self.minCount = 5
|
||||
|
||||
def setVectorSize(self, vectorSize):
|
||||
|
@ -459,7 +465,7 @@ class Word2Vec(object):
|
|||
raise TypeError("data should be an RDD of list of string")
|
||||
jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize),
|
||||
float(self.learningRate), int(self.numPartitions),
|
||||
int(self.numIterations), long(self.seed),
|
||||
int(self.numIterations), int(self.seed),
|
||||
int(self.minCount))
|
||||
return Word2VecModel(jmodel)
|
||||
|
||||
|
|
|
@ -16,12 +16,14 @@
|
|||
#
|
||||
|
||||
from pyspark import SparkContext
|
||||
from pyspark.rdd import ignore_unicode_prefix
|
||||
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
|
||||
|
||||
__all__ = ['FPGrowth', 'FPGrowthModel']
|
||||
|
||||
|
||||
@inherit_doc
|
||||
@ignore_unicode_prefix
|
||||
class FPGrowthModel(JavaModelWrapper):
|
||||
|
||||
"""
|
||||
|
|
|
@ -25,7 +25,13 @@ SciPy is available in their environment.
|
|||
|
||||
import sys
|
||||
import array
|
||||
import copy_reg
|
||||
|
||||
if sys.version >= '3':
|
||||
basestring = str
|
||||
xrange = range
|
||||
import copyreg as copy_reg
|
||||
else:
|
||||
import copy_reg
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -57,7 +63,7 @@ except:
|
|||
def _convert_to_vector(l):
|
||||
if isinstance(l, Vector):
|
||||
return l
|
||||
elif type(l) in (array.array, np.array, np.ndarray, list, tuple):
|
||||
elif type(l) in (array.array, np.array, np.ndarray, list, tuple, xrange):
|
||||
return DenseVector(l)
|
||||
elif _have_scipy and scipy.sparse.issparse(l):
|
||||
assert l.shape[1] == 1, "Expected column vector"
|
||||
|
@ -88,7 +94,7 @@ def _vector_size(v):
|
|||
"""
|
||||
if isinstance(v, Vector):
|
||||
return len(v)
|
||||
elif type(v) in (array.array, list, tuple):
|
||||
elif type(v) in (array.array, list, tuple, xrange):
|
||||
return len(v)
|
||||
elif type(v) == np.ndarray:
|
||||
if v.ndim == 1 or (v.ndim == 2 and v.shape[1] == 1):
|
||||
|
@ -193,7 +199,7 @@ class DenseVector(Vector):
|
|||
DenseVector([1.0, 0.0])
|
||||
"""
|
||||
def __init__(self, ar):
|
||||
if isinstance(ar, basestring):
|
||||
if isinstance(ar, bytes):
|
||||
ar = np.frombuffer(ar, dtype=np.float64)
|
||||
elif not isinstance(ar, np.ndarray):
|
||||
ar = np.array(ar, dtype=np.float64)
|
||||
|
@ -321,11 +327,13 @@ class DenseVector(Vector):
|
|||
__sub__ = _delegate("__sub__")
|
||||
__mul__ = _delegate("__mul__")
|
||||
__div__ = _delegate("__div__")
|
||||
__truediv__ = _delegate("__truediv__")
|
||||
__mod__ = _delegate("__mod__")
|
||||
__radd__ = _delegate("__radd__")
|
||||
__rsub__ = _delegate("__rsub__")
|
||||
__rmul__ = _delegate("__rmul__")
|
||||
__rdiv__ = _delegate("__rdiv__")
|
||||
__rtruediv__ = _delegate("__rtruediv__")
|
||||
__rmod__ = _delegate("__rmod__")
|
||||
|
||||
|
||||
|
@ -344,12 +352,12 @@ class SparseVector(Vector):
|
|||
:param args: Non-zero entries, as a dictionary, list of tupes,
|
||||
or two sorted lists containing indices and values.
|
||||
|
||||
>>> print SparseVector(4, {1: 1.0, 3: 5.5})
|
||||
(4,[1,3],[1.0,5.5])
|
||||
>>> print SparseVector(4, [(1, 1.0), (3, 5.5)])
|
||||
(4,[1,3],[1.0,5.5])
|
||||
>>> print SparseVector(4, [1, 3], [1.0, 5.5])
|
||||
(4,[1,3],[1.0,5.5])
|
||||
>>> SparseVector(4, {1: 1.0, 3: 5.5})
|
||||
SparseVector(4, {1: 1.0, 3: 5.5})
|
||||
>>> SparseVector(4, [(1, 1.0), (3, 5.5)])
|
||||
SparseVector(4, {1: 1.0, 3: 5.5})
|
||||
>>> SparseVector(4, [1, 3], [1.0, 5.5])
|
||||
SparseVector(4, {1: 1.0, 3: 5.5})
|
||||
"""
|
||||
self.size = int(size)
|
||||
assert 1 <= len(args) <= 2, "must pass either 2 or 3 arguments"
|
||||
|
@ -361,8 +369,8 @@ class SparseVector(Vector):
|
|||
self.indices = np.array([p[0] for p in pairs], dtype=np.int32)
|
||||
self.values = np.array([p[1] for p in pairs], dtype=np.float64)
|
||||
else:
|
||||
if isinstance(args[0], basestring):
|
||||
assert isinstance(args[1], str), "values should be string too"
|
||||
if isinstance(args[0], bytes):
|
||||
assert isinstance(args[1], bytes), "values should be string too"
|
||||
if args[0]:
|
||||
self.indices = np.frombuffer(args[0], np.int32)
|
||||
self.values = np.frombuffer(args[1], np.float64)
|
||||
|
@ -591,12 +599,12 @@ class Vectors(object):
|
|||
:param args: Non-zero entries, as a dictionary, list of tupes,
|
||||
or two sorted lists containing indices and values.
|
||||
|
||||
>>> print Vectors.sparse(4, {1: 1.0, 3: 5.5})
|
||||
(4,[1,3],[1.0,5.5])
|
||||
>>> print Vectors.sparse(4, [(1, 1.0), (3, 5.5)])
|
||||
(4,[1,3],[1.0,5.5])
|
||||
>>> print Vectors.sparse(4, [1, 3], [1.0, 5.5])
|
||||
(4,[1,3],[1.0,5.5])
|
||||
>>> Vectors.sparse(4, {1: 1.0, 3: 5.5})
|
||||
SparseVector(4, {1: 1.0, 3: 5.5})
|
||||
>>> Vectors.sparse(4, [(1, 1.0), (3, 5.5)])
|
||||
SparseVector(4, {1: 1.0, 3: 5.5})
|
||||
>>> Vectors.sparse(4, [1, 3], [1.0, 5.5])
|
||||
SparseVector(4, {1: 1.0, 3: 5.5})
|
||||
"""
|
||||
return SparseVector(size, *args)
|
||||
|
||||
|
@ -645,7 +653,7 @@ class Matrix(object):
|
|||
"""
|
||||
Convert Matrix attributes which are array-like or buffer to array.
|
||||
"""
|
||||
if isinstance(array_like, basestring):
|
||||
if isinstance(array_like, bytes):
|
||||
return np.frombuffer(array_like, dtype=dtype)
|
||||
return np.asarray(array_like, dtype=dtype)
|
||||
|
||||
|
@ -677,7 +685,7 @@ class DenseMatrix(Matrix):
|
|||
def toSparse(self):
|
||||
"""Convert to SparseMatrix"""
|
||||
indices = np.nonzero(self.values)[0]
|
||||
colCounts = np.bincount(indices / self.numRows)
|
||||
colCounts = np.bincount(indices // self.numRows)
|
||||
colPtrs = np.cumsum(np.hstack(
|
||||
(0, colCounts, np.zeros(self.numCols - colCounts.size))))
|
||||
values = self.values[indices]
|
||||
|
|
|
@ -88,10 +88,10 @@ class RandomRDDs(object):
|
|||
:param seed: Random seed (default: a random long integer).
|
||||
:return: RDD of float comprised of i.i.d. samples ~ N(0.0, 1.0).
|
||||
|
||||
>>> x = RandomRDDs.normalRDD(sc, 1000, seed=1L)
|
||||
>>> x = RandomRDDs.normalRDD(sc, 1000, seed=1)
|
||||
>>> stats = x.stats()
|
||||
>>> stats.count()
|
||||
1000L
|
||||
1000
|
||||
>>> abs(stats.mean() - 0.0) < 0.1
|
||||
True
|
||||
>>> abs(stats.stdev() - 1.0) < 0.1
|
||||
|
@ -118,10 +118,10 @@ class RandomRDDs(object):
|
|||
>>> std = 1.0
|
||||
>>> expMean = exp(mean + 0.5 * std * std)
|
||||
>>> expStd = sqrt((exp(std * std) - 1.0) * exp(2.0 * mean + std * std))
|
||||
>>> x = RandomRDDs.logNormalRDD(sc, mean, std, 1000, seed=2L)
|
||||
>>> x = RandomRDDs.logNormalRDD(sc, mean, std, 1000, seed=2)
|
||||
>>> stats = x.stats()
|
||||
>>> stats.count()
|
||||
1000L
|
||||
1000
|
||||
>>> abs(stats.mean() - expMean) < 0.5
|
||||
True
|
||||
>>> from math import sqrt
|
||||
|
@ -145,10 +145,10 @@ class RandomRDDs(object):
|
|||
:return: RDD of float comprised of i.i.d. samples ~ Pois(mean).
|
||||
|
||||
>>> mean = 100.0
|
||||
>>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=2L)
|
||||
>>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=2)
|
||||
>>> stats = x.stats()
|
||||
>>> stats.count()
|
||||
1000L
|
||||
1000
|
||||
>>> abs(stats.mean() - mean) < 0.5
|
||||
True
|
||||
>>> from math import sqrt
|
||||
|
@ -171,10 +171,10 @@ class RandomRDDs(object):
|
|||
:return: RDD of float comprised of i.i.d. samples ~ Exp(mean).
|
||||
|
||||
>>> mean = 2.0
|
||||
>>> x = RandomRDDs.exponentialRDD(sc, mean, 1000, seed=2L)
|
||||
>>> x = RandomRDDs.exponentialRDD(sc, mean, 1000, seed=2)
|
||||
>>> stats = x.stats()
|
||||
>>> stats.count()
|
||||
1000L
|
||||
1000
|
||||
>>> abs(stats.mean() - mean) < 0.5
|
||||
True
|
||||
>>> from math import sqrt
|
||||
|
@ -202,10 +202,10 @@ class RandomRDDs(object):
|
|||
>>> scale = 2.0
|
||||
>>> expMean = shape * scale
|
||||
>>> expStd = sqrt(shape * scale * scale)
|
||||
>>> x = RandomRDDs.gammaRDD(sc, shape, scale, 1000, seed=2L)
|
||||
>>> x = RandomRDDs.gammaRDD(sc, shape, scale, 1000, seed=2)
|
||||
>>> stats = x.stats()
|
||||
>>> stats.count()
|
||||
1000L
|
||||
1000
|
||||
>>> abs(stats.mean() - expMean) < 0.5
|
||||
True
|
||||
>>> abs(stats.stdev() - expStd) < 0.5
|
||||
|
@ -254,7 +254,7 @@ class RandomRDDs(object):
|
|||
:return: RDD of Vector with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`.
|
||||
|
||||
>>> import numpy as np
|
||||
>>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1L).collect())
|
||||
>>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1).collect())
|
||||
>>> mat.shape
|
||||
(100, 100)
|
||||
>>> abs(mat.mean() - 0.0) < 0.1
|
||||
|
@ -286,8 +286,8 @@ class RandomRDDs(object):
|
|||
>>> std = 1.0
|
||||
>>> expMean = exp(mean + 0.5 * std * std)
|
||||
>>> expStd = sqrt((exp(std * std) - 1.0) * exp(2.0 * mean + std * std))
|
||||
>>> mat = np.matrix(RandomRDDs.logNormalVectorRDD(sc, mean, std, \
|
||||
100, 100, seed=1L).collect())
|
||||
>>> m = RandomRDDs.logNormalVectorRDD(sc, mean, std, 100, 100, seed=1).collect()
|
||||
>>> mat = np.matrix(m)
|
||||
>>> mat.shape
|
||||
(100, 100)
|
||||
>>> abs(mat.mean() - expMean) < 0.1
|
||||
|
@ -315,7 +315,7 @@ class RandomRDDs(object):
|
|||
|
||||
>>> import numpy as np
|
||||
>>> mean = 100.0
|
||||
>>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1L)
|
||||
>>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1)
|
||||
>>> mat = np.mat(rdd.collect())
|
||||
>>> mat.shape
|
||||
(100, 100)
|
||||
|
@ -345,7 +345,7 @@ class RandomRDDs(object):
|
|||
|
||||
>>> import numpy as np
|
||||
>>> mean = 0.5
|
||||
>>> rdd = RandomRDDs.exponentialVectorRDD(sc, mean, 100, 100, seed=1L)
|
||||
>>> rdd = RandomRDDs.exponentialVectorRDD(sc, mean, 100, 100, seed=1)
|
||||
>>> mat = np.mat(rdd.collect())
|
||||
>>> mat.shape
|
||||
(100, 100)
|
||||
|
@ -380,8 +380,7 @@ class RandomRDDs(object):
|
|||
>>> scale = 2.0
|
||||
>>> expMean = shape * scale
|
||||
>>> expStd = sqrt(shape * scale * scale)
|
||||
>>> mat = np.matrix(RandomRDDs.gammaVectorRDD(sc, shape, scale, \
|
||||
100, 100, seed=1L).collect())
|
||||
>>> mat = np.matrix(RandomRDDs.gammaVectorRDD(sc, shape, scale, 100, 100, seed=1).collect())
|
||||
>>> mat.shape
|
||||
(100, 100)
|
||||
>>> abs(mat.mean() - expMean) < 0.1
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
import array
|
||||
from collections import namedtuple
|
||||
|
||||
from pyspark import SparkContext
|
||||
|
@ -104,14 +105,14 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
|
|||
assert isinstance(user_product, RDD), "user_product should be RDD of (user, product)"
|
||||
first = user_product.first()
|
||||
assert len(first) == 2, "user_product should be RDD of (user, product)"
|
||||
user_product = user_product.map(lambda (u, p): (int(u), int(p)))
|
||||
user_product = user_product.map(lambda u_p: (int(u_p[0]), int(u_p[1])))
|
||||
return self.call("predict", user_product)
|
||||
|
||||
def userFeatures(self):
|
||||
return self.call("getUserFeatures")
|
||||
return self.call("getUserFeatures").mapValues(lambda v: array.array('d', v))
|
||||
|
||||
def productFeatures(self):
|
||||
return self.call("getProductFeatures")
|
||||
return self.call("getProductFeatures").mapValues(lambda v: array.array('d', v))
|
||||
|
||||
@classmethod
|
||||
def load(cls, sc, path):
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
from pyspark import RDD
|
||||
from pyspark.rdd import RDD, ignore_unicode_prefix
|
||||
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
|
||||
from pyspark.mllib.linalg import Matrix, _convert_to_vector
|
||||
from pyspark.mllib.regression import LabeledPoint
|
||||
|
@ -38,7 +38,7 @@ class MultivariateStatisticalSummary(JavaModelWrapper):
|
|||
return self.call("variance").toArray()
|
||||
|
||||
def count(self):
|
||||
return self.call("count")
|
||||
return int(self.call("count"))
|
||||
|
||||
def numNonzeros(self):
|
||||
return self.call("numNonzeros").toArray()
|
||||
|
@ -78,7 +78,7 @@ class Statistics(object):
|
|||
>>> cStats.variance()
|
||||
array([ 4., 13., 0., 25.])
|
||||
>>> cStats.count()
|
||||
3L
|
||||
3
|
||||
>>> cStats.numNonzeros()
|
||||
array([ 3., 2., 0., 3.])
|
||||
>>> cStats.max()
|
||||
|
@ -124,20 +124,20 @@ class Statistics(object):
|
|||
>>> rdd = sc.parallelize([Vectors.dense([1, 0, 0, -2]), Vectors.dense([4, 5, 0, 3]),
|
||||
... Vectors.dense([6, 7, 0, 8]), Vectors.dense([9, 0, 0, 1])])
|
||||
>>> pearsonCorr = Statistics.corr(rdd)
|
||||
>>> print str(pearsonCorr).replace('nan', 'NaN')
|
||||
>>> print(str(pearsonCorr).replace('nan', 'NaN'))
|
||||
[[ 1. 0.05564149 NaN 0.40047142]
|
||||
[ 0.05564149 1. NaN 0.91359586]
|
||||
[ NaN NaN 1. NaN]
|
||||
[ 0.40047142 0.91359586 NaN 1. ]]
|
||||
>>> spearmanCorr = Statistics.corr(rdd, method="spearman")
|
||||
>>> print str(spearmanCorr).replace('nan', 'NaN')
|
||||
>>> print(str(spearmanCorr).replace('nan', 'NaN'))
|
||||
[[ 1. 0.10540926 NaN 0.4 ]
|
||||
[ 0.10540926 1. NaN 0.9486833 ]
|
||||
[ NaN NaN 1. NaN]
|
||||
[ 0.4 0.9486833 NaN 1. ]]
|
||||
>>> try:
|
||||
... Statistics.corr(rdd, "spearman")
|
||||
... print "Method name as second argument without 'method=' shouldn't be allowed."
|
||||
... print("Method name as second argument without 'method=' shouldn't be allowed.")
|
||||
... except TypeError:
|
||||
... pass
|
||||
"""
|
||||
|
@ -153,6 +153,7 @@ class Statistics(object):
|
|||
return callMLlibFunc("corr", x.map(float), y.map(float), method)
|
||||
|
||||
@staticmethod
|
||||
@ignore_unicode_prefix
|
||||
def chiSqTest(observed, expected=None):
|
||||
"""
|
||||
.. note:: Experimental
|
||||
|
@ -188,11 +189,11 @@ class Statistics(object):
|
|||
>>> from pyspark.mllib.linalg import Vectors, Matrices
|
||||
>>> observed = Vectors.dense([4, 6, 5])
|
||||
>>> pearson = Statistics.chiSqTest(observed)
|
||||
>>> print pearson.statistic
|
||||
>>> print(pearson.statistic)
|
||||
0.4
|
||||
>>> pearson.degreesOfFreedom
|
||||
2
|
||||
>>> print round(pearson.pValue, 4)
|
||||
>>> print(round(pearson.pValue, 4))
|
||||
0.8187
|
||||
>>> pearson.method
|
||||
u'pearson'
|
||||
|
@ -202,12 +203,12 @@ class Statistics(object):
|
|||
>>> observed = Vectors.dense([21, 38, 43, 80])
|
||||
>>> expected = Vectors.dense([3, 5, 7, 20])
|
||||
>>> pearson = Statistics.chiSqTest(observed, expected)
|
||||
>>> print round(pearson.pValue, 4)
|
||||
>>> print(round(pearson.pValue, 4))
|
||||
0.0027
|
||||
|
||||
>>> data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0]
|
||||
>>> chi = Statistics.chiSqTest(Matrices.dense(3, 4, data))
|
||||
>>> print round(chi.statistic, 4)
|
||||
>>> print(round(chi.statistic, 4))
|
||||
21.9958
|
||||
|
||||
>>> data = [LabeledPoint(0.0, Vectors.dense([0.5, 10.0])),
|
||||
|
@ -218,9 +219,9 @@ class Statistics(object):
|
|||
... LabeledPoint(1.0, Vectors.dense([3.5, 40.0])),]
|
||||
>>> rdd = sc.parallelize(data, 4)
|
||||
>>> chi = Statistics.chiSqTest(rdd)
|
||||
>>> print chi[0].statistic
|
||||
>>> print(chi[0].statistic)
|
||||
0.75
|
||||
>>> print chi[1].statistic
|
||||
>>> print(chi[1].statistic)
|
||||
1.5
|
||||
"""
|
||||
if isinstance(observed, RDD):
|
||||
|
|
|
@ -72,11 +72,11 @@ class VectorTests(PySparkTestCase):
|
|||
def _test_serialize(self, v):
|
||||
self.assertEqual(v, ser.loads(ser.dumps(v)))
|
||||
jvec = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(v)))
|
||||
nv = ser.loads(str(self.sc._jvm.SerDe.dumps(jvec)))
|
||||
nv = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvec)))
|
||||
self.assertEqual(v, nv)
|
||||
vs = [v] * 100
|
||||
jvecs = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(vs)))
|
||||
nvs = ser.loads(str(self.sc._jvm.SerDe.dumps(jvecs)))
|
||||
nvs = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvecs)))
|
||||
self.assertEqual(vs, nvs)
|
||||
|
||||
def test_serialize(self):
|
||||
|
@ -412,11 +412,11 @@ class StatTests(PySparkTestCase):
|
|||
self.assertEqual(10, len(summary.normL1()))
|
||||
self.assertEqual(10, len(summary.normL2()))
|
||||
|
||||
data2 = self.sc.parallelize(xrange(10)).map(lambda x: Vectors.dense(x))
|
||||
data2 = self.sc.parallelize(range(10)).map(lambda x: Vectors.dense(x))
|
||||
summary2 = Statistics.colStats(data2)
|
||||
self.assertEqual(array([45.0]), summary2.normL1())
|
||||
import math
|
||||
expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, xrange(10))))
|
||||
expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, range(10))))
|
||||
self.assertTrue(math.fabs(summary2.normL2()[0] - expectedNormL2) < 1e-14)
|
||||
|
||||
|
||||
|
@ -438,11 +438,11 @@ class VectorUDTTests(PySparkTestCase):
|
|||
def test_infer_schema(self):
|
||||
sqlCtx = SQLContext(self.sc)
|
||||
rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)])
|
||||
srdd = sqlCtx.inferSchema(rdd)
|
||||
schema = srdd.schema
|
||||
df = rdd.toDF()
|
||||
schema = df.schema
|
||||
field = [f for f in schema.fields if f.name == "features"][0]
|
||||
self.assertEqual(field.dataType, self.udt)
|
||||
vectors = srdd.map(lambda p: p.features).collect()
|
||||
vectors = df.map(lambda p: p.features).collect()
|
||||
self.assertEqual(len(vectors), 2)
|
||||
for v in vectors:
|
||||
if isinstance(v, SparseVector):
|
||||
|
@ -695,7 +695,7 @@ class ChiSqTestTests(PySparkTestCase):
|
|||
|
||||
class SerDeTest(PySparkTestCase):
|
||||
def test_to_java_object_rdd(self): # SPARK-6660
|
||||
data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L)
|
||||
data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0)
|
||||
self.assertEqual(_to_java_object_rdd(data).count(), 10)
|
||||
|
||||
|
||||
|
@ -771,7 +771,7 @@ class StandardScalerTests(PySparkTestCase):
|
|||
|
||||
if __name__ == "__main__":
|
||||
if not _have_scipy:
|
||||
print "NOTE: Skipping SciPy tests as it does not seem to be installed"
|
||||
print("NOTE: Skipping SciPy tests as it does not seem to be installed")
|
||||
unittest.main()
|
||||
if not _have_scipy:
|
||||
print "NOTE: SciPy tests were skipped as it does not seem to be installed"
|
||||
print("NOTE: SciPy tests were skipped as it does not seem to be installed")
|
||||
|
|
|
@ -163,14 +163,16 @@ class DecisionTree(object):
|
|||
... LabeledPoint(1.0, [3.0])
|
||||
... ]
|
||||
>>> model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {})
|
||||
>>> print model, # it already has newline
|
||||
>>> print(model)
|
||||
DecisionTreeModel classifier of depth 1 with 3 nodes
|
||||
>>> print model.toDebugString(), # it already has newline
|
||||
|
||||
>>> print(model.toDebugString())
|
||||
DecisionTreeModel classifier of depth 1 with 3 nodes
|
||||
If (feature 0 <= 0.0)
|
||||
Predict: 0.0
|
||||
Else (feature 0 > 0.0)
|
||||
Predict: 1.0
|
||||
<BLANKLINE>
|
||||
>>> model.predict(array([1.0]))
|
||||
1.0
|
||||
>>> model.predict(array([0.0]))
|
||||
|
@ -318,9 +320,10 @@ class RandomForest(object):
|
|||
3
|
||||
>>> model.totalNumNodes()
|
||||
7
|
||||
>>> print model,
|
||||
>>> print(model)
|
||||
TreeEnsembleModel classifier with 3 trees
|
||||
>>> print model.toDebugString(),
|
||||
<BLANKLINE>
|
||||
>>> print(model.toDebugString())
|
||||
TreeEnsembleModel classifier with 3 trees
|
||||
<BLANKLINE>
|
||||
Tree 0:
|
||||
|
@ -335,6 +338,7 @@ class RandomForest(object):
|
|||
Predict: 0.0
|
||||
Else (feature 0 > 1.0)
|
||||
Predict: 1.0
|
||||
<BLANKLINE>
|
||||
>>> model.predict([2.0])
|
||||
1.0
|
||||
>>> model.predict([0.0])
|
||||
|
@ -483,8 +487,9 @@ class GradientBoostedTrees(object):
|
|||
100
|
||||
>>> model.totalNumNodes()
|
||||
300
|
||||
>>> print model, # it already has newline
|
||||
>>> print(model) # it already has newline
|
||||
TreeEnsembleModel classifier with 100 trees
|
||||
<BLANKLINE>
|
||||
>>> model.predict([2.0])
|
||||
1.0
|
||||
>>> model.predict([0.0])
|
||||
|
|
|
@ -15,10 +15,14 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
import sys
|
||||
import numpy as np
|
||||
import warnings
|
||||
|
||||
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper, inherit_doc
|
||||
if sys.version > '3':
|
||||
xrange = range
|
||||
|
||||
from pyspark.mllib.common import callMLlibFunc, inherit_doc
|
||||
from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector
|
||||
|
||||
|
||||
|
@ -94,22 +98,16 @@ class MLUtils(object):
|
|||
>>> from pyspark.mllib.util import MLUtils
|
||||
>>> from pyspark.mllib.regression import LabeledPoint
|
||||
>>> tempFile = NamedTemporaryFile(delete=True)
|
||||
>>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0")
|
||||
>>> _ = tempFile.write(b"+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0")
|
||||
>>> tempFile.flush()
|
||||
>>> examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect()
|
||||
>>> tempFile.close()
|
||||
>>> type(examples[0]) == LabeledPoint
|
||||
True
|
||||
>>> print examples[0]
|
||||
(1.0,(6,[0,2,4],[1.0,2.0,3.0]))
|
||||
>>> type(examples[1]) == LabeledPoint
|
||||
True
|
||||
>>> print examples[1]
|
||||
(-1.0,(6,[],[]))
|
||||
>>> type(examples[2]) == LabeledPoint
|
||||
True
|
||||
>>> print examples[2]
|
||||
(-1.0,(6,[1,3,5],[4.0,5.0,6.0]))
|
||||
>>> examples[0]
|
||||
LabeledPoint(1.0, (6,[0,2,4],[1.0,2.0,3.0]))
|
||||
>>> examples[1]
|
||||
LabeledPoint(-1.0, (6,[],[]))
|
||||
>>> examples[2]
|
||||
LabeledPoint(-1.0, (6,[1,3,5],[4.0,5.0,6.0]))
|
||||
"""
|
||||
from pyspark.mllib.regression import LabeledPoint
|
||||
if multiclass is not None:
|
||||
|
|
|
@ -84,11 +84,11 @@ class Profiler(object):
|
|||
>>> from pyspark import BasicProfiler
|
||||
>>> class MyCustomProfiler(BasicProfiler):
|
||||
... def show(self, id):
|
||||
... print "My custom profiles for RDD:%s" % id
|
||||
... print("My custom profiles for RDD:%s" % id)
|
||||
...
|
||||
>>> conf = SparkConf().set("spark.python.profile", "true")
|
||||
>>> sc = SparkContext('local', 'test', conf=conf, profiler_cls=MyCustomProfiler)
|
||||
>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
|
||||
>>> sc.parallelize(range(1000)).map(lambda x: 2 * x).take(10)
|
||||
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
|
||||
>>> sc.show_profiles()
|
||||
My custom profiles for RDD:1
|
||||
|
@ -111,9 +111,9 @@ class Profiler(object):
|
|||
""" Print the profile stats to stdout, id is the RDD id """
|
||||
stats = self.stats()
|
||||
if stats:
|
||||
print "=" * 60
|
||||
print "Profile of RDD<id=%d>" % id
|
||||
print "=" * 60
|
||||
print("=" * 60)
|
||||
print("Profile of RDD<id=%d>" % id)
|
||||
print("=" * 60)
|
||||
stats.sort_stats("time", "cumulative").print_stats()
|
||||
|
||||
def dump(self, id, path):
|
||||
|
|
|
@ -16,21 +16,29 @@
|
|||
#
|
||||
|
||||
import copy
|
||||
from collections import defaultdict
|
||||
from itertools import chain, ifilter, imap
|
||||
import operator
|
||||
import sys
|
||||
import os
|
||||
import re
|
||||
import operator
|
||||
import shlex
|
||||
from subprocess import Popen, PIPE
|
||||
from tempfile import NamedTemporaryFile
|
||||
from threading import Thread
|
||||
import warnings
|
||||
import heapq
|
||||
import bisect
|
||||
import random
|
||||
import socket
|
||||
from subprocess import Popen, PIPE
|
||||
from tempfile import NamedTemporaryFile
|
||||
from threading import Thread
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from functools import reduce
|
||||
from math import sqrt, log, isinf, isnan, pow, ceil
|
||||
|
||||
if sys.version > '3':
|
||||
basestring = unicode = str
|
||||
else:
|
||||
from itertools import imap as map, ifilter as filter
|
||||
|
||||
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
|
||||
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
|
||||
PickleSerializer, pack_long, AutoBatchedSerializer
|
||||
|
@ -50,20 +58,21 @@ from py4j.java_collections import ListConverter, MapConverter
|
|||
__all__ = ["RDD"]
|
||||
|
||||
|
||||
# TODO: for Python 3.3+, PYTHONHASHSEED should be reset to disable randomized
|
||||
# hash for string
|
||||
def portable_hash(x):
|
||||
"""
|
||||
This function returns consistant hash code for builtin types, especially
|
||||
This function returns consistent hash code for builtin types, especially
|
||||
for None and tuple with None.
|
||||
|
||||
The algrithm is similar to that one used by CPython 2.7
|
||||
The algorithm is similar to that one used by CPython 2.7
|
||||
|
||||
>>> portable_hash(None)
|
||||
0
|
||||
>>> portable_hash((None, 1)) & 0xffffffff
|
||||
219750521
|
||||
"""
|
||||
if sys.version >= '3.3' and 'PYTHONHASHSEED' not in os.environ:
|
||||
raise Exception("Randomness of hash of string should be disabled via PYTHONHASHSEED")
|
||||
|
||||
if x is None:
|
||||
return 0
|
||||
if isinstance(x, tuple):
|
||||
|
@ -71,7 +80,7 @@ def portable_hash(x):
|
|||
for i in x:
|
||||
h ^= portable_hash(i)
|
||||
h *= 1000003
|
||||
h &= sys.maxint
|
||||
h &= sys.maxsize
|
||||
h ^= len(x)
|
||||
if h == -1:
|
||||
h = -2
|
||||
|
@ -123,6 +132,19 @@ def _load_from_socket(port, serializer):
|
|||
sock.close()
|
||||
|
||||
|
||||
def ignore_unicode_prefix(f):
|
||||
"""
|
||||
Ignore the 'u' prefix of string in doc tests, to make it works
|
||||
in both python 2 and 3
|
||||
"""
|
||||
if sys.version >= '3':
|
||||
# the representation of unicode string in Python 3 does not have prefix 'u',
|
||||
# so remove the prefix 'u' for doc tests
|
||||
literal_re = re.compile(r"(\W|^)[uU](['])", re.UNICODE)
|
||||
f.__doc__ = literal_re.sub(r'\1\2', f.__doc__)
|
||||
return f
|
||||
|
||||
|
||||
class Partitioner(object):
|
||||
def __init__(self, numPartitions, partitionFunc):
|
||||
self.numPartitions = numPartitions
|
||||
|
@ -251,7 +273,7 @@ class RDD(object):
|
|||
[('a', 1), ('b', 1), ('c', 1)]
|
||||
"""
|
||||
def func(_, iterator):
|
||||
return imap(f, iterator)
|
||||
return map(f, iterator)
|
||||
return self.mapPartitionsWithIndex(func, preservesPartitioning)
|
||||
|
||||
def flatMap(self, f, preservesPartitioning=False):
|
||||
|
@ -266,7 +288,7 @@ class RDD(object):
|
|||
[(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
|
||||
"""
|
||||
def func(s, iterator):
|
||||
return chain.from_iterable(imap(f, iterator))
|
||||
return chain.from_iterable(map(f, iterator))
|
||||
return self.mapPartitionsWithIndex(func, preservesPartitioning)
|
||||
|
||||
def mapPartitions(self, f, preservesPartitioning=False):
|
||||
|
@ -329,7 +351,7 @@ class RDD(object):
|
|||
[2, 4]
|
||||
"""
|
||||
def func(iterator):
|
||||
return ifilter(f, iterator)
|
||||
return filter(f, iterator)
|
||||
return self.mapPartitions(func, True)
|
||||
|
||||
def distinct(self, numPartitions=None):
|
||||
|
@ -341,7 +363,7 @@ class RDD(object):
|
|||
"""
|
||||
return self.map(lambda x: (x, None)) \
|
||||
.reduceByKey(lambda x, _: x, numPartitions) \
|
||||
.map(lambda (x, _): x)
|
||||
.map(lambda x: x[0])
|
||||
|
||||
def sample(self, withReplacement, fraction, seed=None):
|
||||
"""
|
||||
|
@ -354,8 +376,8 @@ class RDD(object):
|
|||
:param seed: seed for the random number generator
|
||||
|
||||
>>> rdd = sc.parallelize(range(100), 4)
|
||||
>>> rdd.sample(False, 0.1, 81).count()
|
||||
10
|
||||
>>> 6 <= rdd.sample(False, 0.1, 81).count() <= 14
|
||||
True
|
||||
"""
|
||||
assert fraction >= 0.0, "Negative fraction value: %s" % fraction
|
||||
return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
|
||||
|
@ -368,12 +390,14 @@ class RDD(object):
|
|||
:param seed: random seed
|
||||
:return: split RDDs in a list
|
||||
|
||||
>>> rdd = sc.parallelize(range(5), 1)
|
||||
>>> rdd = sc.parallelize(range(500), 1)
|
||||
>>> rdd1, rdd2 = rdd.randomSplit([2, 3], 17)
|
||||
>>> rdd1.collect()
|
||||
[1, 3]
|
||||
>>> rdd2.collect()
|
||||
[0, 2, 4]
|
||||
>>> len(rdd1.collect() + rdd2.collect())
|
||||
500
|
||||
>>> 150 < rdd1.count() < 250
|
||||
True
|
||||
>>> 250 < rdd2.count() < 350
|
||||
True
|
||||
"""
|
||||
s = float(sum(weights))
|
||||
cweights = [0.0]
|
||||
|
@ -416,7 +440,7 @@ class RDD(object):
|
|||
rand.shuffle(samples)
|
||||
return samples
|
||||
|
||||
maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint))
|
||||
maxSampleSize = sys.maxsize - int(numStDev * sqrt(sys.maxsize))
|
||||
if num > maxSampleSize:
|
||||
raise ValueError(
|
||||
"Sample size cannot be greater than %d." % maxSampleSize)
|
||||
|
@ -430,7 +454,7 @@ class RDD(object):
|
|||
# See: scala/spark/RDD.scala
|
||||
while len(samples) < num:
|
||||
# TODO: add log warning for when more than one iteration was run
|
||||
seed = rand.randint(0, sys.maxint)
|
||||
seed = rand.randint(0, sys.maxsize)
|
||||
samples = self.sample(withReplacement, fraction, seed).collect()
|
||||
|
||||
rand.shuffle(samples)
|
||||
|
@ -507,7 +531,7 @@ class RDD(object):
|
|||
"""
|
||||
return self.map(lambda v: (v, None)) \
|
||||
.cogroup(other.map(lambda v: (v, None))) \
|
||||
.filter(lambda (k, vs): all(vs)) \
|
||||
.filter(lambda k_vs: all(k_vs[1])) \
|
||||
.keys()
|
||||
|
||||
def _reserialize(self, serializer=None):
|
||||
|
@ -549,7 +573,7 @@ class RDD(object):
|
|||
|
||||
def sortPartition(iterator):
|
||||
sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted
|
||||
return iter(sort(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending)))
|
||||
return iter(sort(iterator, key=lambda k_v: keyfunc(k_v[0]), reverse=(not ascending)))
|
||||
|
||||
return self.partitionBy(numPartitions, partitionFunc).mapPartitions(sortPartition, True)
|
||||
|
||||
|
@ -579,7 +603,7 @@ class RDD(object):
|
|||
|
||||
def sortPartition(iterator):
|
||||
sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted
|
||||
return iter(sort(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending)))
|
||||
return iter(sort(iterator, key=lambda kv: keyfunc(kv[0]), reverse=(not ascending)))
|
||||
|
||||
if numPartitions == 1:
|
||||
if self.getNumPartitions() > 1:
|
||||
|
@ -594,12 +618,12 @@ class RDD(object):
|
|||
return self # empty RDD
|
||||
maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner
|
||||
fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
|
||||
samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect()
|
||||
samples = self.sample(False, fraction, 1).map(lambda kv: kv[0]).collect()
|
||||
samples = sorted(samples, key=keyfunc)
|
||||
|
||||
# we have numPartitions many parts but one of the them has
|
||||
# an implicit boundary
|
||||
bounds = [samples[len(samples) * (i + 1) / numPartitions]
|
||||
bounds = [samples[int(len(samples) * (i + 1) / numPartitions)]
|
||||
for i in range(0, numPartitions - 1)]
|
||||
|
||||
def rangePartitioner(k):
|
||||
|
@ -662,12 +686,13 @@ class RDD(object):
|
|||
"""
|
||||
return self.map(lambda x: (f(x), x)).groupByKey(numPartitions)
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def pipe(self, command, env={}):
|
||||
"""
|
||||
Return an RDD created by piping elements to a forked external process.
|
||||
|
||||
>>> sc.parallelize(['1', '2', '', '3']).pipe('cat').collect()
|
||||
['1', '2', '', '3']
|
||||
[u'1', u'2', u'', u'3']
|
||||
"""
|
||||
def func(iterator):
|
||||
pipe = Popen(
|
||||
|
@ -675,17 +700,18 @@ class RDD(object):
|
|||
|
||||
def pipe_objs(out):
|
||||
for obj in iterator:
|
||||
out.write(str(obj).rstrip('\n') + '\n')
|
||||
s = str(obj).rstrip('\n') + '\n'
|
||||
out.write(s.encode('utf-8'))
|
||||
out.close()
|
||||
Thread(target=pipe_objs, args=[pipe.stdin]).start()
|
||||
return (x.rstrip('\n') for x in iter(pipe.stdout.readline, ''))
|
||||
return (x.rstrip(b'\n').decode('utf-8') for x in iter(pipe.stdout.readline, b''))
|
||||
return self.mapPartitions(func)
|
||||
|
||||
def foreach(self, f):
|
||||
"""
|
||||
Applies a function to all elements of this RDD.
|
||||
|
||||
>>> def f(x): print x
|
||||
>>> def f(x): print(x)
|
||||
>>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
|
||||
"""
|
||||
def processPartition(iterator):
|
||||
|
@ -700,7 +726,7 @@ class RDD(object):
|
|||
|
||||
>>> def f(iterator):
|
||||
... for x in iterator:
|
||||
... print x
|
||||
... print(x)
|
||||
>>> sc.parallelize([1, 2, 3, 4, 5]).foreachPartition(f)
|
||||
"""
|
||||
def func(it):
|
||||
|
@ -874,7 +900,7 @@ class RDD(object):
|
|||
# aggregation.
|
||||
while numPartitions > scale + numPartitions / scale:
|
||||
numPartitions /= scale
|
||||
curNumPartitions = numPartitions
|
||||
curNumPartitions = int(numPartitions)
|
||||
|
||||
def mapPartition(i, iterator):
|
||||
for obj in iterator:
|
||||
|
@ -984,7 +1010,7 @@ class RDD(object):
|
|||
(('a', 'b', 'c'), [2, 2])
|
||||
"""
|
||||
|
||||
if isinstance(buckets, (int, long)):
|
||||
if isinstance(buckets, int):
|
||||
if buckets < 1:
|
||||
raise ValueError("number of buckets must be >= 1")
|
||||
|
||||
|
@ -1020,6 +1046,7 @@ class RDD(object):
|
|||
raise ValueError("Can not generate buckets with infinite value")
|
||||
|
||||
# keep them as integer if possible
|
||||
inc = int(inc)
|
||||
if inc * buckets != maxv - minv:
|
||||
inc = (maxv - minv) * 1.0 / buckets
|
||||
|
||||
|
@ -1137,7 +1164,7 @@ class RDD(object):
|
|||
yield counts
|
||||
|
||||
def mergeMaps(m1, m2):
|
||||
for k, v in m2.iteritems():
|
||||
for k, v in m2.items():
|
||||
m1[k] += v
|
||||
return m1
|
||||
return self.mapPartitions(countPartition).reduce(mergeMaps)
|
||||
|
@ -1378,8 +1405,8 @@ class RDD(object):
|
|||
>>> tmpFile = NamedTemporaryFile(delete=True)
|
||||
>>> tmpFile.close()
|
||||
>>> sc.parallelize([1, 2, 'spark', 'rdd']).saveAsPickleFile(tmpFile.name, 3)
|
||||
>>> sorted(sc.pickleFile(tmpFile.name, 5).collect())
|
||||
[1, 2, 'rdd', 'spark']
|
||||
>>> sorted(sc.pickleFile(tmpFile.name, 5).map(str).collect())
|
||||
['1', '2', 'rdd', 'spark']
|
||||
"""
|
||||
if batchSize == 0:
|
||||
ser = AutoBatchedSerializer(PickleSerializer())
|
||||
|
@ -1387,6 +1414,7 @@ class RDD(object):
|
|||
ser = BatchedSerializer(PickleSerializer(), batchSize)
|
||||
self._reserialize(ser)._jrdd.saveAsObjectFile(path)
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def saveAsTextFile(self, path, compressionCodecClass=None):
|
||||
"""
|
||||
Save this RDD as a text file, using string representations of elements.
|
||||
|
@ -1418,12 +1446,13 @@ class RDD(object):
|
|||
>>> codec = "org.apache.hadoop.io.compress.GzipCodec"
|
||||
>>> sc.parallelize(['foo', 'bar']).saveAsTextFile(tempFile3.name, codec)
|
||||
>>> from fileinput import input, hook_compressed
|
||||
>>> ''.join(sorted(input(glob(tempFile3.name + "/part*.gz"), openhook=hook_compressed)))
|
||||
'bar\\nfoo\\n'
|
||||
>>> result = sorted(input(glob(tempFile3.name + "/part*.gz"), openhook=hook_compressed))
|
||||
>>> b''.join(result).decode('utf-8')
|
||||
u'bar\\nfoo\\n'
|
||||
"""
|
||||
def func(split, iterator):
|
||||
for x in iterator:
|
||||
if not isinstance(x, basestring):
|
||||
if not isinstance(x, (unicode, bytes)):
|
||||
x = unicode(x)
|
||||
if isinstance(x, unicode):
|
||||
x = x.encode("utf-8")
|
||||
|
@ -1458,7 +1487,7 @@ class RDD(object):
|
|||
>>> m.collect()
|
||||
[1, 3]
|
||||
"""
|
||||
return self.map(lambda (k, v): k)
|
||||
return self.map(lambda x: x[0])
|
||||
|
||||
def values(self):
|
||||
"""
|
||||
|
@ -1468,7 +1497,7 @@ class RDD(object):
|
|||
>>> m.collect()
|
||||
[2, 4]
|
||||
"""
|
||||
return self.map(lambda (k, v): v)
|
||||
return self.map(lambda x: x[1])
|
||||
|
||||
def reduceByKey(self, func, numPartitions=None):
|
||||
"""
|
||||
|
@ -1507,7 +1536,7 @@ class RDD(object):
|
|||
yield m
|
||||
|
||||
def mergeMaps(m1, m2):
|
||||
for k, v in m2.iteritems():
|
||||
for k, v in m2.items():
|
||||
m1[k] = func(m1[k], v) if k in m1 else v
|
||||
return m1
|
||||
return self.mapPartitions(reducePartition).reduce(mergeMaps)
|
||||
|
@ -1604,8 +1633,8 @@ class RDD(object):
|
|||
|
||||
>>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x))
|
||||
>>> sets = pairs.partitionBy(2).glom().collect()
|
||||
>>> set(sets[0]).intersection(set(sets[1]))
|
||||
set([])
|
||||
>>> len(set(sets[0]).intersection(set(sets[1])))
|
||||
0
|
||||
"""
|
||||
if numPartitions is None:
|
||||
numPartitions = self._defaultReducePartitions()
|
||||
|
@ -1637,22 +1666,22 @@ class RDD(object):
|
|||
if (c % 1000 == 0 and get_used_memory() > limit
|
||||
or c > batch):
|
||||
n, size = len(buckets), 0
|
||||
for split in buckets.keys():
|
||||
for split in list(buckets.keys()):
|
||||
yield pack_long(split)
|
||||
d = outputSerializer.dumps(buckets[split])
|
||||
del buckets[split]
|
||||
yield d
|
||||
size += len(d)
|
||||
|
||||
avg = (size / n) >> 20
|
||||
avg = int(size / n) >> 20
|
||||
# let 1M < avg < 10M
|
||||
if avg < 1:
|
||||
batch *= 1.5
|
||||
elif avg > 10:
|
||||
batch = max(batch / 1.5, 1)
|
||||
batch = max(int(batch / 1.5), 1)
|
||||
c = 0
|
||||
|
||||
for split, items in buckets.iteritems():
|
||||
for split, items in buckets.items():
|
||||
yield pack_long(split)
|
||||
yield outputSerializer.dumps(items)
|
||||
|
||||
|
@ -1707,7 +1736,7 @@ class RDD(object):
|
|||
merger = ExternalMerger(agg, memory * 0.9, serializer) \
|
||||
if spill else InMemoryMerger(agg)
|
||||
merger.mergeValues(iterator)
|
||||
return merger.iteritems()
|
||||
return merger.items()
|
||||
|
||||
locally_combined = self.mapPartitions(combineLocally, preservesPartitioning=True)
|
||||
shuffled = locally_combined.partitionBy(numPartitions)
|
||||
|
@ -1716,7 +1745,7 @@ class RDD(object):
|
|||
merger = ExternalMerger(agg, memory, serializer) \
|
||||
if spill else InMemoryMerger(agg)
|
||||
merger.mergeCombiners(iterator)
|
||||
return merger.iteritems()
|
||||
return merger.items()
|
||||
|
||||
return shuffled.mapPartitions(_mergeCombiners, preservesPartitioning=True)
|
||||
|
||||
|
@ -1745,7 +1774,7 @@ class RDD(object):
|
|||
|
||||
>>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
|
||||
>>> from operator import add
|
||||
>>> rdd.foldByKey(0, add).collect()
|
||||
>>> sorted(rdd.foldByKey(0, add).collect())
|
||||
[('a', 2), ('b', 1)]
|
||||
"""
|
||||
def createZero():
|
||||
|
@ -1769,10 +1798,10 @@ class RDD(object):
|
|||
sum or average) over each key, using reduceByKey or aggregateByKey will
|
||||
provide much better performance.
|
||||
|
||||
>>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
|
||||
>>> sorted(x.groupByKey().mapValues(len).collect())
|
||||
>>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
|
||||
>>> sorted(rdd.groupByKey().mapValues(len).collect())
|
||||
[('a', 2), ('b', 1)]
|
||||
>>> sorted(x.groupByKey().mapValues(list).collect())
|
||||
>>> sorted(rdd.groupByKey().mapValues(list).collect())
|
||||
[('a', [1, 1]), ('b', [1])]
|
||||
"""
|
||||
def createCombiner(x):
|
||||
|
@ -1795,7 +1824,7 @@ class RDD(object):
|
|||
merger = ExternalMerger(agg, memory * 0.9, serializer) \
|
||||
if spill else InMemoryMerger(agg)
|
||||
merger.mergeValues(iterator)
|
||||
return merger.iteritems()
|
||||
return merger.items()
|
||||
|
||||
locally_combined = self.mapPartitions(combine, preservesPartitioning=True)
|
||||
shuffled = locally_combined.partitionBy(numPartitions)
|
||||
|
@ -1804,7 +1833,7 @@ class RDD(object):
|
|||
merger = ExternalGroupBy(agg, memory, serializer)\
|
||||
if spill else InMemoryMerger(agg)
|
||||
merger.mergeCombiners(it)
|
||||
return merger.iteritems()
|
||||
return merger.items()
|
||||
|
||||
return shuffled.mapPartitions(groupByKey, True).mapValues(ResultIterable)
|
||||
|
||||
|
@ -1819,7 +1848,7 @@ class RDD(object):
|
|||
>>> x.flatMapValues(f).collect()
|
||||
[('a', 'x'), ('a', 'y'), ('a', 'z'), ('b', 'p'), ('b', 'r')]
|
||||
"""
|
||||
flat_map_fn = lambda (k, v): ((k, x) for x in f(v))
|
||||
flat_map_fn = lambda kv: ((kv[0], x) for x in f(kv[1]))
|
||||
return self.flatMap(flat_map_fn, preservesPartitioning=True)
|
||||
|
||||
def mapValues(self, f):
|
||||
|
@ -1833,7 +1862,7 @@ class RDD(object):
|
|||
>>> x.mapValues(f).collect()
|
||||
[('a', 3), ('b', 1)]
|
||||
"""
|
||||
map_values_fn = lambda (k, v): (k, f(v))
|
||||
map_values_fn = lambda kv: (kv[0], f(kv[1]))
|
||||
return self.map(map_values_fn, preservesPartitioning=True)
|
||||
|
||||
def groupWith(self, other, *others):
|
||||
|
@ -1844,8 +1873,7 @@ class RDD(object):
|
|||
>>> x = sc.parallelize([("a", 1), ("b", 4)])
|
||||
>>> y = sc.parallelize([("a", 2)])
|
||||
>>> z = sc.parallelize([("b", 42)])
|
||||
>>> map((lambda (x,y): (x, (list(y[0]), list(y[1]), list(y[2]), list(y[3])))), \
|
||||
sorted(list(w.groupWith(x, y, z).collect())))
|
||||
>>> [(x, tuple(map(list, y))) for x, y in sorted(list(w.groupWith(x, y, z).collect()))]
|
||||
[('a', ([5], [1], [2], [])), ('b', ([6], [4], [], [42]))]
|
||||
|
||||
"""
|
||||
|
@ -1860,7 +1888,7 @@ class RDD(object):
|
|||
|
||||
>>> x = sc.parallelize([("a", 1), ("b", 4)])
|
||||
>>> y = sc.parallelize([("a", 2)])
|
||||
>>> map((lambda (x,y): (x, (list(y[0]), list(y[1])))), sorted(list(x.cogroup(y).collect())))
|
||||
>>> [(x, tuple(map(list, y))) for x, y in sorted(list(x.cogroup(y).collect()))]
|
||||
[('a', ([1], [2])), ('b', ([4], []))]
|
||||
"""
|
||||
return python_cogroup((self, other), numPartitions)
|
||||
|
@ -1896,8 +1924,9 @@ class RDD(object):
|
|||
>>> sorted(x.subtractByKey(y).collect())
|
||||
[('b', 4), ('b', 5)]
|
||||
"""
|
||||
def filter_func((key, vals)):
|
||||
return vals[0] and not vals[1]
|
||||
def filter_func(pair):
|
||||
key, (val1, val2) = pair
|
||||
return val1 and not val2
|
||||
return self.cogroup(other, numPartitions).filter(filter_func).flatMapValues(lambda x: x[0])
|
||||
|
||||
def subtract(self, other, numPartitions=None):
|
||||
|
@ -1919,8 +1948,8 @@ class RDD(object):
|
|||
|
||||
>>> x = sc.parallelize(range(0,3)).keyBy(lambda x: x*x)
|
||||
>>> y = sc.parallelize(zip(range(0,5), range(0,5)))
|
||||
>>> map((lambda (x,y): (x, (list(y[0]), (list(y[1]))))), sorted(x.cogroup(y).collect()))
|
||||
[(0, ([0], [0])), (1, ([1], [1])), (2, ([], [2])), (3, ([], [3])), (4, ([2], [4]))]
|
||||
>>> [(x, list(map(list, y))) for x, y in sorted(x.cogroup(y).collect())]
|
||||
[(0, [[0], [0]]), (1, [[1], [1]]), (2, [[], [2]]), (3, [[], [3]]), (4, [[2], [4]])]
|
||||
"""
|
||||
return self.map(lambda x: (f(x), x))
|
||||
|
||||
|
@ -2049,17 +2078,18 @@ class RDD(object):
|
|||
"""
|
||||
Return the name of this RDD.
|
||||
"""
|
||||
name_ = self._jrdd.name()
|
||||
if name_:
|
||||
return name_.encode('utf-8')
|
||||
n = self._jrdd.name()
|
||||
if n:
|
||||
return n
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def setName(self, name):
|
||||
"""
|
||||
Assign a name to this RDD.
|
||||
|
||||
>>> rdd1 = sc.parallelize([1,2])
|
||||
>>> rdd1 = sc.parallelize([1, 2])
|
||||
>>> rdd1.setName('RDD1').name()
|
||||
'RDD1'
|
||||
u'RDD1'
|
||||
"""
|
||||
self._jrdd.setName(name)
|
||||
return self
|
||||
|
@ -2121,7 +2151,7 @@ class RDD(object):
|
|||
>>> sorted.lookup(1024)
|
||||
[]
|
||||
"""
|
||||
values = self.filter(lambda (k, v): k == key).values()
|
||||
values = self.filter(lambda kv: kv[0] == key).values()
|
||||
|
||||
if self.partitioner is not None:
|
||||
return self.ctx.runJob(values, lambda x: x, [self.partitioner(key)], False)
|
||||
|
@ -2159,7 +2189,7 @@ class RDD(object):
|
|||
or meet the confidence.
|
||||
|
||||
>>> rdd = sc.parallelize(range(1000), 10)
|
||||
>>> r = sum(xrange(1000))
|
||||
>>> r = sum(range(1000))
|
||||
>>> (rdd.sumApprox(1000) - r) / r < 0.05
|
||||
True
|
||||
"""
|
||||
|
@ -2176,7 +2206,7 @@ class RDD(object):
|
|||
or meet the confidence.
|
||||
|
||||
>>> rdd = sc.parallelize(range(1000), 10)
|
||||
>>> r = sum(xrange(1000)) / 1000.0
|
||||
>>> r = sum(range(1000)) / 1000.0
|
||||
>>> (rdd.meanApprox(1000) - r) / r < 0.05
|
||||
True
|
||||
"""
|
||||
|
@ -2201,10 +2231,10 @@ class RDD(object):
|
|||
It must be greater than 0.000017.
|
||||
|
||||
>>> n = sc.parallelize(range(1000)).map(str).countApproxDistinct()
|
||||
>>> 950 < n < 1050
|
||||
>>> 900 < n < 1100
|
||||
True
|
||||
>>> n = sc.parallelize([i % 20 for i in range(1000)]).countApproxDistinct()
|
||||
>>> 18 < n < 22
|
||||
>>> 16 < n < 24
|
||||
True
|
||||
"""
|
||||
if relativeSD < 0.000017:
|
||||
|
@ -2223,8 +2253,7 @@ class RDD(object):
|
|||
>>> [x for x in rdd.toLocalIterator()]
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
"""
|
||||
partitions = xrange(self.getNumPartitions())
|
||||
for partition in partitions:
|
||||
for partition in range(self.getNumPartitions()):
|
||||
rows = self.context.runJob(self, lambda x: x, [partition])
|
||||
for row in rows:
|
||||
yield row
|
||||
|
|
|
@ -23,7 +23,7 @@ import math
|
|||
class RDDSamplerBase(object):
|
||||
|
||||
def __init__(self, withReplacement, seed=None):
|
||||
self._seed = seed if seed is not None else random.randint(0, sys.maxint)
|
||||
self._seed = seed if seed is not None else random.randint(0, sys.maxsize)
|
||||
self._withReplacement = withReplacement
|
||||
self._random = None
|
||||
|
||||
|
@ -31,7 +31,7 @@ class RDDSamplerBase(object):
|
|||
self._random = random.Random(self._seed ^ split)
|
||||
|
||||
# mixing because the initial seeds are close to each other
|
||||
for _ in xrange(10):
|
||||
for _ in range(10):
|
||||
self._random.randint(0, 1)
|
||||
|
||||
def getUniformSample(self):
|
||||
|
|
|
@ -49,16 +49,24 @@ which contains two batches of two objects:
|
|||
>>> sc.stop()
|
||||
"""
|
||||
|
||||
import cPickle
|
||||
from itertools import chain, izip, product
|
||||
import sys
|
||||
from itertools import chain, product
|
||||
import marshal
|
||||
import struct
|
||||
import sys
|
||||
import types
|
||||
import collections
|
||||
import zlib
|
||||
import itertools
|
||||
|
||||
if sys.version < '3':
|
||||
import cPickle as pickle
|
||||
protocol = 2
|
||||
from itertools import izip as zip
|
||||
else:
|
||||
import pickle
|
||||
protocol = 3
|
||||
xrange = range
|
||||
|
||||
from pyspark import cloudpickle
|
||||
|
||||
|
||||
|
@ -97,7 +105,7 @@ class Serializer(object):
|
|||
# subclasses should override __eq__ as appropriate.
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, self.__class__)
|
||||
return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
@ -212,10 +220,6 @@ class BatchedSerializer(Serializer):
|
|||
def _load_stream_without_unbatching(self, stream):
|
||||
return self.serializer.load_stream(stream)
|
||||
|
||||
def __eq__(self, other):
|
||||
return (isinstance(other, BatchedSerializer) and
|
||||
other.serializer == self.serializer and other.batchSize == self.batchSize)
|
||||
|
||||
def __repr__(self):
|
||||
return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize)
|
||||
|
||||
|
@ -233,14 +237,14 @@ class FlattenedValuesSerializer(BatchedSerializer):
|
|||
def _batched(self, iterator):
|
||||
n = self.batchSize
|
||||
for key, values in iterator:
|
||||
for i in xrange(0, len(values), n):
|
||||
for i in range(0, len(values), n):
|
||||
yield key, values[i:i + n]
|
||||
|
||||
def load_stream(self, stream):
|
||||
return self.serializer.load_stream(stream)
|
||||
|
||||
def __repr__(self):
|
||||
return "FlattenedValuesSerializer(%d)" % self.batchSize
|
||||
return "FlattenedValuesSerializer(%s, %d)" % (self.serializer, self.batchSize)
|
||||
|
||||
|
||||
class AutoBatchedSerializer(BatchedSerializer):
|
||||
|
@ -270,12 +274,8 @@ class AutoBatchedSerializer(BatchedSerializer):
|
|||
elif size > best * 10 and batch > 1:
|
||||
batch /= 2
|
||||
|
||||
def __eq__(self, other):
|
||||
return (isinstance(other, AutoBatchedSerializer) and
|
||||
other.serializer == self.serializer and other.bestSize == self.bestSize)
|
||||
|
||||
def __repr__(self):
|
||||
return "AutoBatchedSerializer(%s)" % str(self.serializer)
|
||||
return "AutoBatchedSerializer(%s)" % self.serializer
|
||||
|
||||
|
||||
class CartesianDeserializer(FramedSerializer):
|
||||
|
@ -285,6 +285,7 @@ class CartesianDeserializer(FramedSerializer):
|
|||
"""
|
||||
|
||||
def __init__(self, key_ser, val_ser):
|
||||
FramedSerializer.__init__(self)
|
||||
self.key_ser = key_ser
|
||||
self.val_ser = val_ser
|
||||
|
||||
|
@ -293,7 +294,7 @@ class CartesianDeserializer(FramedSerializer):
|
|||
val_stream = self.val_ser._load_stream_without_unbatching(stream)
|
||||
key_is_batched = isinstance(self.key_ser, BatchedSerializer)
|
||||
val_is_batched = isinstance(self.val_ser, BatchedSerializer)
|
||||
for (keys, vals) in izip(key_stream, val_stream):
|
||||
for (keys, vals) in zip(key_stream, val_stream):
|
||||
keys = keys if key_is_batched else [keys]
|
||||
vals = vals if val_is_batched else [vals]
|
||||
yield (keys, vals)
|
||||
|
@ -303,10 +304,6 @@ class CartesianDeserializer(FramedSerializer):
|
|||
for pair in product(keys, vals):
|
||||
yield pair
|
||||
|
||||
def __eq__(self, other):
|
||||
return (isinstance(other, CartesianDeserializer) and
|
||||
self.key_ser == other.key_ser and self.val_ser == other.val_ser)
|
||||
|
||||
def __repr__(self):
|
||||
return "CartesianDeserializer(%s, %s)" % \
|
||||
(str(self.key_ser), str(self.val_ser))
|
||||
|
@ -318,22 +315,14 @@ class PairDeserializer(CartesianDeserializer):
|
|||
Deserializes the JavaRDD zip() of two PythonRDDs.
|
||||
"""
|
||||
|
||||
def __init__(self, key_ser, val_ser):
|
||||
self.key_ser = key_ser
|
||||
self.val_ser = val_ser
|
||||
|
||||
def load_stream(self, stream):
|
||||
for (keys, vals) in self.prepare_keys_values(stream):
|
||||
if len(keys) != len(vals):
|
||||
raise ValueError("Can not deserialize RDD with different number of items"
|
||||
" in pair: (%d, %d)" % (len(keys), len(vals)))
|
||||
for pair in izip(keys, vals):
|
||||
for pair in zip(keys, vals):
|
||||
yield pair
|
||||
|
||||
def __eq__(self, other):
|
||||
return (isinstance(other, PairDeserializer) and
|
||||
self.key_ser == other.key_ser and self.val_ser == other.val_ser)
|
||||
|
||||
def __repr__(self):
|
||||
return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser))
|
||||
|
||||
|
@ -382,8 +371,8 @@ def _hijack_namedtuple():
|
|||
global _old_namedtuple # or it will put in closure
|
||||
|
||||
def _copy_func(f):
|
||||
return types.FunctionType(f.func_code, f.func_globals, f.func_name,
|
||||
f.func_defaults, f.func_closure)
|
||||
return types.FunctionType(f.__code__, f.__globals__, f.__name__,
|
||||
f.__defaults__, f.__closure__)
|
||||
|
||||
_old_namedtuple = _copy_func(collections.namedtuple)
|
||||
|
||||
|
@ -392,15 +381,15 @@ def _hijack_namedtuple():
|
|||
return _hack_namedtuple(cls)
|
||||
|
||||
# replace namedtuple with new one
|
||||
collections.namedtuple.func_globals["_old_namedtuple"] = _old_namedtuple
|
||||
collections.namedtuple.func_globals["_hack_namedtuple"] = _hack_namedtuple
|
||||
collections.namedtuple.func_code = namedtuple.func_code
|
||||
collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple
|
||||
collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple
|
||||
collections.namedtuple.__code__ = namedtuple.__code__
|
||||
collections.namedtuple.__hijack = 1
|
||||
|
||||
# hack the cls already generated by namedtuple
|
||||
# those created in other module can be pickled as normal,
|
||||
# so only hack those in __main__ module
|
||||
for n, o in sys.modules["__main__"].__dict__.iteritems():
|
||||
for n, o in sys.modules["__main__"].__dict__.items():
|
||||
if (type(o) is type and o.__base__ is tuple
|
||||
and hasattr(o, "_fields")
|
||||
and "__reduce__" not in o.__dict__):
|
||||
|
@ -413,7 +402,7 @@ _hijack_namedtuple()
|
|||
class PickleSerializer(FramedSerializer):
|
||||
|
||||
"""
|
||||
Serializes objects using Python's cPickle serializer:
|
||||
Serializes objects using Python's pickle serializer:
|
||||
|
||||
http://docs.python.org/2/library/pickle.html
|
||||
|
||||
|
@ -422,10 +411,14 @@ class PickleSerializer(FramedSerializer):
|
|||
"""
|
||||
|
||||
def dumps(self, obj):
|
||||
return cPickle.dumps(obj, 2)
|
||||
return pickle.dumps(obj, protocol)
|
||||
|
||||
def loads(self, obj):
|
||||
return cPickle.loads(obj)
|
||||
if sys.version >= '3':
|
||||
def loads(self, obj, encoding="bytes"):
|
||||
return pickle.loads(obj, encoding=encoding)
|
||||
else:
|
||||
def loads(self, obj, encoding=None):
|
||||
return pickle.loads(obj)
|
||||
|
||||
|
||||
class CloudPickleSerializer(PickleSerializer):
|
||||
|
@ -454,7 +447,7 @@ class MarshalSerializer(FramedSerializer):
|
|||
class AutoSerializer(FramedSerializer):
|
||||
|
||||
"""
|
||||
Choose marshal or cPickle as serialization protocol automatically
|
||||
Choose marshal or pickle as serialization protocol automatically
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
|
@ -463,19 +456,19 @@ class AutoSerializer(FramedSerializer):
|
|||
|
||||
def dumps(self, obj):
|
||||
if self._type is not None:
|
||||
return 'P' + cPickle.dumps(obj, -1)
|
||||
return b'P' + pickle.dumps(obj, -1)
|
||||
try:
|
||||
return 'M' + marshal.dumps(obj)
|
||||
return b'M' + marshal.dumps(obj)
|
||||
except Exception:
|
||||
self._type = 'P'
|
||||
return 'P' + cPickle.dumps(obj, -1)
|
||||
self._type = b'P'
|
||||
return b'P' + pickle.dumps(obj, -1)
|
||||
|
||||
def loads(self, obj):
|
||||
_type = obj[0]
|
||||
if _type == 'M':
|
||||
if _type == b'M':
|
||||
return marshal.loads(obj[1:])
|
||||
elif _type == 'P':
|
||||
return cPickle.loads(obj[1:])
|
||||
elif _type == b'P':
|
||||
return pickle.loads(obj[1:])
|
||||
else:
|
||||
raise ValueError("invalid sevialization type: %s" % _type)
|
||||
|
||||
|
@ -495,8 +488,8 @@ class CompressedSerializer(FramedSerializer):
|
|||
def loads(self, obj):
|
||||
return self.serializer.loads(zlib.decompress(obj))
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, CompressedSerializer) and self.serializer == other.serializer
|
||||
def __repr__(self):
|
||||
return "CompressedSerializer(%s)" % self.serializer
|
||||
|
||||
|
||||
class UTF8Deserializer(Serializer):
|
||||
|
@ -505,7 +498,7 @@ class UTF8Deserializer(Serializer):
|
|||
Deserializes streams written by String.getBytes.
|
||||
"""
|
||||
|
||||
def __init__(self, use_unicode=False):
|
||||
def __init__(self, use_unicode=True):
|
||||
self.use_unicode = use_unicode
|
||||
|
||||
def loads(self, stream):
|
||||
|
@ -526,13 +519,13 @@ class UTF8Deserializer(Serializer):
|
|||
except EOFError:
|
||||
return
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, UTF8Deserializer) and self.use_unicode == other.use_unicode
|
||||
def __repr__(self):
|
||||
return "UTF8Deserializer(%s)" % self.use_unicode
|
||||
|
||||
|
||||
def read_long(stream):
|
||||
length = stream.read(8)
|
||||
if length == "":
|
||||
if not length:
|
||||
raise EOFError
|
||||
return struct.unpack("!q", length)[0]
|
||||
|
||||
|
@ -547,7 +540,7 @@ def pack_long(value):
|
|||
|
||||
def read_int(stream):
|
||||
length = stream.read(4)
|
||||
if length == "":
|
||||
if not length:
|
||||
raise EOFError
|
||||
return struct.unpack("!i", length)[0]
|
||||
|
||||
|
|
|
@ -21,13 +21,6 @@ An interactive shell.
|
|||
This file is designed to be launched as a PYTHONSTARTUP script.
|
||||
"""
|
||||
|
||||
import sys
|
||||
if sys.version_info[0] != 2:
|
||||
print("Error: Default Python used is Python%s" % sys.version_info.major)
|
||||
print("\tSet env variable PYSPARK_PYTHON to Python2 binary and re-run it.")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
import atexit
|
||||
import os
|
||||
import platform
|
||||
|
@ -53,9 +46,14 @@ atexit.register(lambda: sc.stop())
|
|||
try:
|
||||
# Try to access HiveConf, it will raise exception if Hive is not added
|
||||
sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
|
||||
sqlCtx = sqlContext = HiveContext(sc)
|
||||
sqlContext = HiveContext(sc)
|
||||
except py4j.protocol.Py4JError:
|
||||
sqlCtx = sqlContext = SQLContext(sc)
|
||||
sqlContext = SQLContext(sc)
|
||||
except TypeError:
|
||||
sqlContext = SQLContext(sc)
|
||||
|
||||
# for compatibility
|
||||
sqlCtx = sqlContext
|
||||
|
||||
print("""Welcome to
|
||||
____ __
|
||||
|
|
|
@ -78,8 +78,8 @@ def _get_local_dirs(sub):
|
|||
|
||||
|
||||
# global stats
|
||||
MemoryBytesSpilled = 0L
|
||||
DiskBytesSpilled = 0L
|
||||
MemoryBytesSpilled = 0
|
||||
DiskBytesSpilled = 0
|
||||
|
||||
|
||||
class Aggregator(object):
|
||||
|
@ -126,7 +126,7 @@ class Merger(object):
|
|||
""" Merge the combined items by mergeCombiner """
|
||||
raise NotImplementedError
|
||||
|
||||
def iteritems(self):
|
||||
def items(self):
|
||||
""" Return the merged items ad iterator """
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -156,9 +156,9 @@ class InMemoryMerger(Merger):
|
|||
for k, v in iterator:
|
||||
d[k] = comb(d[k], v) if k in d else v
|
||||
|
||||
def iteritems(self):
|
||||
""" Return the merged items as iterator """
|
||||
return self.data.iteritems()
|
||||
def items(self):
|
||||
""" Return the merged items ad iterator """
|
||||
return iter(self.data.items())
|
||||
|
||||
|
||||
def _compressed_serializer(self, serializer=None):
|
||||
|
@ -208,15 +208,15 @@ class ExternalMerger(Merger):
|
|||
>>> agg = SimpleAggregator(lambda x, y: x + y)
|
||||
>>> merger = ExternalMerger(agg, 10)
|
||||
>>> N = 10000
|
||||
>>> merger.mergeValues(zip(xrange(N), xrange(N)))
|
||||
>>> merger.mergeValues(zip(range(N), range(N)))
|
||||
>>> assert merger.spills > 0
|
||||
>>> sum(v for k,v in merger.iteritems())
|
||||
>>> sum(v for k,v in merger.items())
|
||||
49995000
|
||||
|
||||
>>> merger = ExternalMerger(agg, 10)
|
||||
>>> merger.mergeCombiners(zip(xrange(N), xrange(N)))
|
||||
>>> merger.mergeCombiners(zip(range(N), range(N)))
|
||||
>>> assert merger.spills > 0
|
||||
>>> sum(v for k,v in merger.iteritems())
|
||||
>>> sum(v for k,v in merger.items())
|
||||
49995000
|
||||
"""
|
||||
|
||||
|
@ -335,10 +335,10 @@ class ExternalMerger(Merger):
|
|||
# above limit at the first time.
|
||||
|
||||
# open all the files for writing
|
||||
streams = [open(os.path.join(path, str(i)), 'w')
|
||||
streams = [open(os.path.join(path, str(i)), 'wb')
|
||||
for i in range(self.partitions)]
|
||||
|
||||
for k, v in self.data.iteritems():
|
||||
for k, v in self.data.items():
|
||||
h = self._partition(k)
|
||||
# put one item in batch, make it compatible with load_stream
|
||||
# it will increase the memory if dump them in batch
|
||||
|
@ -354,9 +354,9 @@ class ExternalMerger(Merger):
|
|||
else:
|
||||
for i in range(self.partitions):
|
||||
p = os.path.join(path, str(i))
|
||||
with open(p, "w") as f:
|
||||
with open(p, "wb") as f:
|
||||
# dump items in batch
|
||||
self.serializer.dump_stream(self.pdata[i].iteritems(), f)
|
||||
self.serializer.dump_stream(iter(self.pdata[i].items()), f)
|
||||
self.pdata[i].clear()
|
||||
DiskBytesSpilled += os.path.getsize(p)
|
||||
|
||||
|
@ -364,10 +364,10 @@ class ExternalMerger(Merger):
|
|||
gc.collect() # release the memory as much as possible
|
||||
MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
|
||||
|
||||
def iteritems(self):
|
||||
def items(self):
|
||||
""" Return all merged items as iterator """
|
||||
if not self.pdata and not self.spills:
|
||||
return self.data.iteritems()
|
||||
return iter(self.data.items())
|
||||
return self._external_items()
|
||||
|
||||
def _external_items(self):
|
||||
|
@ -398,7 +398,8 @@ class ExternalMerger(Merger):
|
|||
path = self._get_spill_dir(j)
|
||||
p = os.path.join(path, str(index))
|
||||
# do not check memory during merging
|
||||
self.mergeCombiners(self.serializer.load_stream(open(p)), 0)
|
||||
with open(p, "rb") as f:
|
||||
self.mergeCombiners(self.serializer.load_stream(f), 0)
|
||||
|
||||
# limit the total partitions
|
||||
if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS
|
||||
|
@ -408,7 +409,7 @@ class ExternalMerger(Merger):
|
|||
gc.collect() # release the memory as much as possible
|
||||
return self._recursive_merged_items(index)
|
||||
|
||||
return self.data.iteritems()
|
||||
return self.data.items()
|
||||
|
||||
def _recursive_merged_items(self, index):
|
||||
"""
|
||||
|
@ -426,7 +427,8 @@ class ExternalMerger(Merger):
|
|||
for j in range(self.spills):
|
||||
path = self._get_spill_dir(j)
|
||||
p = os.path.join(path, str(index))
|
||||
m.mergeCombiners(self.serializer.load_stream(open(p)), 0)
|
||||
with open(p, 'rb') as f:
|
||||
m.mergeCombiners(self.serializer.load_stream(f), 0)
|
||||
|
||||
if get_used_memory() > limit:
|
||||
m._spill()
|
||||
|
@ -451,7 +453,7 @@ class ExternalSorter(object):
|
|||
|
||||
>>> sorter = ExternalSorter(1) # 1M
|
||||
>>> import random
|
||||
>>> l = range(1024)
|
||||
>>> l = list(range(1024))
|
||||
>>> random.shuffle(l)
|
||||
>>> sorted(l) == list(sorter.sorted(l))
|
||||
True
|
||||
|
@ -499,9 +501,16 @@ class ExternalSorter(object):
|
|||
# sort them inplace will save memory
|
||||
current_chunk.sort(key=key, reverse=reverse)
|
||||
path = self._get_path(len(chunks))
|
||||
with open(path, 'w') as f:
|
||||
with open(path, 'wb') as f:
|
||||
self.serializer.dump_stream(current_chunk, f)
|
||||
chunks.append(self.serializer.load_stream(open(path)))
|
||||
|
||||
def load(f):
|
||||
for v in self.serializer.load_stream(f):
|
||||
yield v
|
||||
# close the file explicit once we consume all the items
|
||||
# to avoid ResourceWarning in Python3
|
||||
f.close()
|
||||
chunks.append(load(open(path, 'rb')))
|
||||
current_chunk = []
|
||||
gc.collect()
|
||||
limit = self._next_limit()
|
||||
|
@ -527,7 +536,7 @@ class ExternalList(object):
|
|||
ExternalList can have many items which cannot be hold in memory in
|
||||
the same time.
|
||||
|
||||
>>> l = ExternalList(range(100))
|
||||
>>> l = ExternalList(list(range(100)))
|
||||
>>> len(l)
|
||||
100
|
||||
>>> l.append(10)
|
||||
|
@ -555,11 +564,11 @@ class ExternalList(object):
|
|||
def __getstate__(self):
|
||||
if self._file is not None:
|
||||
self._file.flush()
|
||||
f = os.fdopen(os.dup(self._file.fileno()))
|
||||
f.seek(0)
|
||||
serialized = f.read()
|
||||
with os.fdopen(os.dup(self._file.fileno()), "rb") as f:
|
||||
f.seek(0)
|
||||
serialized = f.read()
|
||||
else:
|
||||
serialized = ''
|
||||
serialized = b''
|
||||
return self.values, self.count, serialized
|
||||
|
||||
def __setstate__(self, item):
|
||||
|
@ -575,7 +584,7 @@ class ExternalList(object):
|
|||
if self._file is not None:
|
||||
self._file.flush()
|
||||
# read all items from disks first
|
||||
with os.fdopen(os.dup(self._file.fileno()), 'r') as f:
|
||||
with os.fdopen(os.dup(self._file.fileno()), 'rb') as f:
|
||||
f.seek(0)
|
||||
for v in self._ser.load_stream(f):
|
||||
yield v
|
||||
|
@ -598,11 +607,16 @@ class ExternalList(object):
|
|||
d = dirs[id(self) % len(dirs)]
|
||||
if not os.path.exists(d):
|
||||
os.makedirs(d)
|
||||
p = os.path.join(d, str(id))
|
||||
self._file = open(p, "w+", 65536)
|
||||
p = os.path.join(d, str(id(self)))
|
||||
self._file = open(p, "wb+", 65536)
|
||||
self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024)
|
||||
os.unlink(p)
|
||||
|
||||
def __del__(self):
|
||||
if self._file:
|
||||
self._file.close()
|
||||
self._file = None
|
||||
|
||||
def _spill(self):
|
||||
""" dump the values into disk """
|
||||
global MemoryBytesSpilled, DiskBytesSpilled
|
||||
|
@ -651,33 +665,28 @@ class GroupByKey(object):
|
|||
"""
|
||||
Group a sorted iterator as [(k1, it1), (k2, it2), ...]
|
||||
|
||||
>>> k = [i/3 for i in range(6)]
|
||||
>>> k = [i // 3 for i in range(6)]
|
||||
>>> v = [[i] for i in range(6)]
|
||||
>>> g = GroupByKey(iter(zip(k, v)))
|
||||
>>> g = GroupByKey(zip(k, v))
|
||||
>>> [(k, list(it)) for k, it in g]
|
||||
[(0, [0, 1, 2]), (1, [3, 4, 5])]
|
||||
"""
|
||||
|
||||
def __init__(self, iterator):
|
||||
self.iterator = iter(iterator)
|
||||
self.next_item = None
|
||||
self.iterator = iterator
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def next(self):
|
||||
key, value = self.next_item if self.next_item else next(self.iterator)
|
||||
values = ExternalListOfList([value])
|
||||
try:
|
||||
while True:
|
||||
k, v = next(self.iterator)
|
||||
if k != key:
|
||||
self.next_item = (k, v)
|
||||
break
|
||||
key, values = None, None
|
||||
for k, v in self.iterator:
|
||||
if values is not None and k == key:
|
||||
values.append(v)
|
||||
except StopIteration:
|
||||
self.next_item = None
|
||||
return key, values
|
||||
else:
|
||||
if values is not None:
|
||||
yield (key, values)
|
||||
key = k
|
||||
values = ExternalListOfList([v])
|
||||
if values is not None:
|
||||
yield (key, values)
|
||||
|
||||
|
||||
class ExternalGroupBy(ExternalMerger):
|
||||
|
@ -744,7 +753,7 @@ class ExternalGroupBy(ExternalMerger):
|
|||
# above limit at the first time.
|
||||
|
||||
# open all the files for writing
|
||||
streams = [open(os.path.join(path, str(i)), 'w')
|
||||
streams = [open(os.path.join(path, str(i)), 'wb')
|
||||
for i in range(self.partitions)]
|
||||
|
||||
# If the number of keys is small, then the overhead of sort is small
|
||||
|
@ -756,7 +765,7 @@ class ExternalGroupBy(ExternalMerger):
|
|||
h = self._partition(k)
|
||||
self.serializer.dump_stream([(k, self.data[k])], streams[h])
|
||||
else:
|
||||
for k, v in self.data.iteritems():
|
||||
for k, v in self.data.items():
|
||||
h = self._partition(k)
|
||||
self.serializer.dump_stream([(k, v)], streams[h])
|
||||
|
||||
|
@ -771,14 +780,14 @@ class ExternalGroupBy(ExternalMerger):
|
|||
else:
|
||||
for i in range(self.partitions):
|
||||
p = os.path.join(path, str(i))
|
||||
with open(p, "w") as f:
|
||||
with open(p, "wb") as f:
|
||||
# dump items in batch
|
||||
if self._sorted:
|
||||
# sort by key only (stable)
|
||||
sorted_items = sorted(self.pdata[i].iteritems(), key=operator.itemgetter(0))
|
||||
sorted_items = sorted(self.pdata[i].items(), key=operator.itemgetter(0))
|
||||
self.serializer.dump_stream(sorted_items, f)
|
||||
else:
|
||||
self.serializer.dump_stream(self.pdata[i].iteritems(), f)
|
||||
self.serializer.dump_stream(self.pdata[i].items(), f)
|
||||
self.pdata[i].clear()
|
||||
DiskBytesSpilled += os.path.getsize(p)
|
||||
|
||||
|
@ -792,7 +801,7 @@ class ExternalGroupBy(ExternalMerger):
|
|||
# if the memory can not hold all the partition,
|
||||
# then use sort based merge. Because of compression,
|
||||
# the data on disks will be much smaller than needed memory
|
||||
if (size >> 20) >= self.memory_limit / 10:
|
||||
if size >= self.memory_limit << 17: # * 1M / 8
|
||||
return self._merge_sorted_items(index)
|
||||
|
||||
self.data = {}
|
||||
|
@ -800,15 +809,18 @@ class ExternalGroupBy(ExternalMerger):
|
|||
path = self._get_spill_dir(j)
|
||||
p = os.path.join(path, str(index))
|
||||
# do not check memory during merging
|
||||
self.mergeCombiners(self.serializer.load_stream(open(p)), 0)
|
||||
return self.data.iteritems()
|
||||
with open(p, "rb") as f:
|
||||
self.mergeCombiners(self.serializer.load_stream(f), 0)
|
||||
return self.data.items()
|
||||
|
||||
def _merge_sorted_items(self, index):
|
||||
""" load a partition from disk, then sort and group by key """
|
||||
def load_partition(j):
|
||||
path = self._get_spill_dir(j)
|
||||
p = os.path.join(path, str(index))
|
||||
return self.serializer.load_stream(open(p, 'r', 65536))
|
||||
with open(p, 'rb', 65536) as f:
|
||||
for v in self.serializer.load_stream(f):
|
||||
yield v
|
||||
|
||||
disk_items = [load_partition(j) for j in range(self.spills)]
|
||||
|
||||
|
|
|
@ -37,9 +37,22 @@ Important classes of Spark SQL and DataFrames:
|
|||
- L{types}
|
||||
List of data types available.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
|
||||
# fix the module name conflict for Python 3+
|
||||
import sys
|
||||
from . import _types as types
|
||||
modname = __name__ + '.types'
|
||||
types.__name__ = modname
|
||||
# update the __module__ for all objects, make them picklable
|
||||
for v in types.__dict__.values():
|
||||
if hasattr(v, "__module__") and v.__module__.endswith('._types'):
|
||||
v.__module__ = modname
|
||||
sys.modules[modname] = types
|
||||
del modname, sys
|
||||
|
||||
from pyspark.sql.context import SQLContext, HiveContext
|
||||
from pyspark.sql.types import Row
|
||||
from pyspark.sql.context import SQLContext, HiveContext
|
||||
from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD, DataFrameNaFunctions
|
||||
|
||||
__all__ = [
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
import sys
|
||||
import decimal
|
||||
import datetime
|
||||
import keyword
|
||||
|
@ -25,6 +26,9 @@ import weakref
|
|||
from array import array
|
||||
from operator import itemgetter
|
||||
|
||||
if sys.version >= "3":
|
||||
long = int
|
||||
unicode = str
|
||||
|
||||
__all__ = [
|
||||
"DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
|
||||
|
@ -410,7 +414,7 @@ class UserDefinedType(DataType):
|
|||
split = pyUDT.rfind(".")
|
||||
pyModule = pyUDT[:split]
|
||||
pyClass = pyUDT[split+1:]
|
||||
m = __import__(pyModule, globals(), locals(), [pyClass], -1)
|
||||
m = __import__(pyModule, globals(), locals(), [pyClass])
|
||||
UDT = getattr(m, pyClass)
|
||||
return UDT()
|
||||
|
||||
|
@ -419,10 +423,9 @@ class UserDefinedType(DataType):
|
|||
|
||||
|
||||
_all_primitive_types = dict((v.typeName(), v)
|
||||
for v in globals().itervalues()
|
||||
if type(v) is PrimitiveTypeSingleton and
|
||||
v.__base__ == PrimitiveType)
|
||||
|
||||
for v in list(globals().values())
|
||||
if (type(v) is type or type(v) is PrimitiveTypeSingleton)
|
||||
and v.__base__ == PrimitiveType)
|
||||
|
||||
_all_complex_types = dict((v.typeName(), v)
|
||||
for v in [ArrayType, MapType, StructType])
|
||||
|
@ -486,10 +489,10 @@ _FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)")
|
|||
|
||||
|
||||
def _parse_datatype_json_value(json_value):
|
||||
if type(json_value) is unicode:
|
||||
if not isinstance(json_value, dict):
|
||||
if json_value in _all_primitive_types.keys():
|
||||
return _all_primitive_types[json_value]()
|
||||
elif json_value == u'decimal':
|
||||
elif json_value == 'decimal':
|
||||
return DecimalType()
|
||||
elif _FIXED_DECIMAL.match(json_value):
|
||||
m = _FIXED_DECIMAL.match(json_value)
|
||||
|
@ -511,10 +514,8 @@ _type_mappings = {
|
|||
type(None): NullType,
|
||||
bool: BooleanType,
|
||||
int: LongType,
|
||||
long: LongType,
|
||||
float: DoubleType,
|
||||
str: StringType,
|
||||
unicode: StringType,
|
||||
bytearray: BinaryType,
|
||||
decimal.Decimal: DecimalType,
|
||||
datetime.date: DateType,
|
||||
|
@ -522,6 +523,12 @@ _type_mappings = {
|
|||
datetime.time: TimestampType,
|
||||
}
|
||||
|
||||
if sys.version < "3":
|
||||
_type_mappings.update({
|
||||
unicode: StringType,
|
||||
long: LongType,
|
||||
})
|
||||
|
||||
|
||||
def _infer_type(obj):
|
||||
"""Infer the DataType from obj
|
||||
|
@ -541,7 +548,7 @@ def _infer_type(obj):
|
|||
return dataType()
|
||||
|
||||
if isinstance(obj, dict):
|
||||
for key, value in obj.iteritems():
|
||||
for key, value in obj.items():
|
||||
if key is not None and value is not None:
|
||||
return MapType(_infer_type(key), _infer_type(value), True)
|
||||
else:
|
||||
|
@ -565,10 +572,10 @@ def _infer_schema(row):
|
|||
items = sorted(row.items())
|
||||
|
||||
elif isinstance(row, (tuple, list)):
|
||||
if hasattr(row, "_fields"): # namedtuple
|
||||
items = zip(row._fields, tuple(row))
|
||||
elif hasattr(row, "__fields__"): # Row
|
||||
if hasattr(row, "__fields__"): # Row
|
||||
items = zip(row.__fields__, tuple(row))
|
||||
elif hasattr(row, "_fields"): # namedtuple
|
||||
items = zip(row._fields, tuple(row))
|
||||
else:
|
||||
names = ['_%d' % i for i in range(1, len(row) + 1)]
|
||||
items = zip(names, row)
|
||||
|
@ -647,7 +654,7 @@ def _python_to_sql_converter(dataType):
|
|||
if isinstance(obj, dict):
|
||||
return tuple(c(obj.get(n)) for n, c in zip(names, converters))
|
||||
elif isinstance(obj, tuple):
|
||||
if hasattr(obj, "_fields") or hasattr(obj, "__fields__"):
|
||||
if hasattr(obj, "__fields__") or hasattr(obj, "_fields"):
|
||||
return tuple(c(v) for c, v in zip(converters, obj))
|
||||
elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs
|
||||
d = dict(obj)
|
||||
|
@ -733,12 +740,12 @@ def _create_converter(dataType):
|
|||
|
||||
if isinstance(dataType, ArrayType):
|
||||
conv = _create_converter(dataType.elementType)
|
||||
return lambda row: map(conv, row)
|
||||
return lambda row: [conv(v) for v in row]
|
||||
|
||||
elif isinstance(dataType, MapType):
|
||||
kconv = _create_converter(dataType.keyType)
|
||||
vconv = _create_converter(dataType.valueType)
|
||||
return lambda row: dict((kconv(k), vconv(v)) for k, v in row.iteritems())
|
||||
return lambda row: dict((kconv(k), vconv(v)) for k, v in row.items())
|
||||
|
||||
elif isinstance(dataType, NullType):
|
||||
return lambda x: None
|
||||
|
@ -881,7 +888,7 @@ def _infer_schema_type(obj, dataType):
|
|||
>>> _infer_schema_type(row, schema)
|
||||
StructType...a,ArrayType...b,MapType(StringType,...c,LongType...
|
||||
"""
|
||||
if dataType is NullType():
|
||||
if isinstance(dataType, NullType):
|
||||
return _infer_type(obj)
|
||||
|
||||
if not obj:
|
||||
|
@ -892,7 +899,7 @@ def _infer_schema_type(obj, dataType):
|
|||
return ArrayType(eType, True)
|
||||
|
||||
elif isinstance(dataType, MapType):
|
||||
k, v = obj.iteritems().next()
|
||||
k, v = next(iter(obj.items()))
|
||||
return MapType(_infer_schema_type(k, dataType.keyType),
|
||||
_infer_schema_type(v, dataType.valueType))
|
||||
|
||||
|
@ -935,7 +942,7 @@ def _verify_type(obj, dataType):
|
|||
>>> _verify_type(None, StructType([]))
|
||||
>>> _verify_type("", StringType())
|
||||
>>> _verify_type(0, LongType())
|
||||
>>> _verify_type(range(3), ArrayType(ShortType()))
|
||||
>>> _verify_type(list(range(3)), ArrayType(ShortType()))
|
||||
>>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
|
@ -976,7 +983,7 @@ def _verify_type(obj, dataType):
|
|||
_verify_type(i, dataType.elementType)
|
||||
|
||||
elif isinstance(dataType, MapType):
|
||||
for k, v in obj.iteritems():
|
||||
for k, v in obj.items():
|
||||
_verify_type(k, dataType.keyType)
|
||||
_verify_type(v, dataType.valueType)
|
||||
|
||||
|
@ -1213,6 +1220,8 @@ class Row(tuple):
|
|||
return self[idx]
|
||||
except IndexError:
|
||||
raise AttributeError(item)
|
||||
except ValueError:
|
||||
raise AttributeError(item)
|
||||
|
||||
def __reduce__(self):
|
||||
if hasattr(self, "__fields__"):
|
|
@ -15,14 +15,19 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
import sys
|
||||
import warnings
|
||||
import json
|
||||
from itertools import imap
|
||||
|
||||
if sys.version >= '3':
|
||||
basestring = unicode = str
|
||||
else:
|
||||
from itertools import imap as map
|
||||
|
||||
from py4j.protocol import Py4JError
|
||||
from py4j.java_collections import MapConverter
|
||||
|
||||
from pyspark.rdd import RDD, _prepare_for_python_RDD
|
||||
from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix
|
||||
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
|
||||
from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
|
||||
_infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
|
||||
|
@ -62,31 +67,27 @@ class SQLContext(object):
|
|||
A SQLContext can be used create :class:`DataFrame`, register :class:`DataFrame` as
|
||||
tables, execute SQL over tables, cache tables, and read parquet files.
|
||||
|
||||
When created, :class:`SQLContext` adds a method called ``toDF`` to :class:`RDD`,
|
||||
which could be used to convert an RDD into a DataFrame, it's a shorthand for
|
||||
:func:`SQLContext.createDataFrame`.
|
||||
|
||||
:param sparkContext: The :class:`SparkContext` backing this SQLContext.
|
||||
:param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new
|
||||
SQLContext in the JVM, instead we make all calls to this object.
|
||||
"""
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def __init__(self, sparkContext, sqlContext=None):
|
||||
"""Creates a new SQLContext.
|
||||
|
||||
>>> from datetime import datetime
|
||||
>>> sqlContext = SQLContext(sc)
|
||||
>>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L,
|
||||
>>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1,
|
||||
... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
|
||||
... time=datetime(2014, 8, 1, 14, 1, 5))])
|
||||
>>> df = allTypes.toDF()
|
||||
>>> df.registerTempTable("allTypes")
|
||||
>>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
|
||||
... 'from allTypes where b and i > 0').collect()
|
||||
[Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)]
|
||||
>>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time,
|
||||
... x.row.a, x.list)).collect()
|
||||
[(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
|
||||
[Row(c0=2, c1=2.0, c2=False, c3=2, c4=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
|
||||
>>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
|
||||
[(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
|
||||
"""
|
||||
self._sc = sparkContext
|
||||
self._jsc = self._sc._jsc
|
||||
|
@ -122,6 +123,7 @@ class SQLContext(object):
|
|||
"""Returns a :class:`UDFRegistration` for UDF registration."""
|
||||
return UDFRegistration(self)
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def registerFunction(self, name, f, returnType=StringType()):
|
||||
"""Registers a lambda function as a UDF so it can be used in SQL statements.
|
||||
|
||||
|
@ -147,7 +149,7 @@ class SQLContext(object):
|
|||
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
|
||||
[Row(c0=4)]
|
||||
"""
|
||||
func = lambda _, it: imap(lambda x: f(*x), it)
|
||||
func = lambda _, it: map(lambda x: f(*x), it)
|
||||
ser = AutoBatchedSerializer(PickleSerializer())
|
||||
command = (func, None, ser, ser)
|
||||
pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self)
|
||||
|
@ -185,6 +187,7 @@ class SQLContext(object):
|
|||
schema = rdd.map(_infer_schema).reduce(_merge_type)
|
||||
return schema
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def inferSchema(self, rdd, samplingRatio=None):
|
||||
"""::note: Deprecated in 1.3, use :func:`createDataFrame` instead.
|
||||
"""
|
||||
|
@ -195,6 +198,7 @@ class SQLContext(object):
|
|||
|
||||
return self.createDataFrame(rdd, None, samplingRatio)
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def applySchema(self, rdd, schema):
|
||||
"""::note: Deprecated in 1.3, use :func:`createDataFrame` instead.
|
||||
"""
|
||||
|
@ -208,6 +212,7 @@ class SQLContext(object):
|
|||
|
||||
return self.createDataFrame(rdd, schema)
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def createDataFrame(self, data, schema=None, samplingRatio=None):
|
||||
"""
|
||||
Creates a :class:`DataFrame` from an :class:`RDD` of :class:`tuple`/:class:`list`,
|
||||
|
@ -380,6 +385,7 @@ class SQLContext(object):
|
|||
df = self._ssql_ctx.jsonFile(path, scala_datatype)
|
||||
return DataFrame(df, self)
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
|
||||
"""Loads an RDD storing one JSON object per string as a :class:`DataFrame`.
|
||||
|
||||
|
@ -477,6 +483,7 @@ class SQLContext(object):
|
|||
joptions)
|
||||
return DataFrame(df, self)
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def sql(self, sqlQuery):
|
||||
"""Returns a :class:`DataFrame` representing the result of the given query.
|
||||
|
||||
|
@ -497,6 +504,7 @@ class SQLContext(object):
|
|||
"""
|
||||
return DataFrame(self._ssql_ctx.table(tableName), self)
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def tables(self, dbName=None):
|
||||
"""Returns a :class:`DataFrame` containing names of tables in the given database.
|
||||
|
||||
|
|
|
@ -16,14 +16,19 @@
|
|||
#
|
||||
|
||||
import sys
|
||||
import itertools
|
||||
import warnings
|
||||
import random
|
||||
|
||||
if sys.version >= '3':
|
||||
basestring = unicode = str
|
||||
long = int
|
||||
else:
|
||||
from itertools import imap as map
|
||||
|
||||
from py4j.java_collections import ListConverter, MapConverter
|
||||
|
||||
from pyspark.context import SparkContext
|
||||
from pyspark.rdd import RDD, _load_from_socket
|
||||
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
|
||||
from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
|
||||
from pyspark.storagelevel import StorageLevel
|
||||
from pyspark.traceback_utils import SCCallSiteSync
|
||||
|
@ -65,19 +70,20 @@ class DataFrame(object):
|
|||
self._sc = sql_ctx and sql_ctx._sc
|
||||
self.is_cached = False
|
||||
self._schema = None # initialized lazily
|
||||
self._lazy_rdd = None
|
||||
|
||||
@property
|
||||
def rdd(self):
|
||||
"""Returns the content as an :class:`pyspark.RDD` of :class:`Row`.
|
||||
"""
|
||||
if not hasattr(self, '_lazy_rdd'):
|
||||
if self._lazy_rdd is None:
|
||||
jrdd = self._jdf.javaToPython()
|
||||
rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
|
||||
schema = self.schema
|
||||
|
||||
def applySchema(it):
|
||||
cls = _create_cls(schema)
|
||||
return itertools.imap(cls, it)
|
||||
return map(cls, it)
|
||||
|
||||
self._lazy_rdd = rdd.mapPartitions(applySchema)
|
||||
|
||||
|
@ -89,13 +95,14 @@ class DataFrame(object):
|
|||
"""
|
||||
return DataFrameNaFunctions(self)
|
||||
|
||||
def toJSON(self, use_unicode=False):
|
||||
@ignore_unicode_prefix
|
||||
def toJSON(self, use_unicode=True):
|
||||
"""Converts a :class:`DataFrame` into a :class:`RDD` of string.
|
||||
|
||||
Each row is turned into a JSON document as one element in the returned RDD.
|
||||
|
||||
>>> df.toJSON().first()
|
||||
'{"age":2,"name":"Alice"}'
|
||||
u'{"age":2,"name":"Alice"}'
|
||||
"""
|
||||
rdd = self._jdf.toJSON()
|
||||
return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
|
||||
|
@ -228,7 +235,7 @@ class DataFrame(object):
|
|||
|-- name: string (nullable = true)
|
||||
<BLANKLINE>
|
||||
"""
|
||||
print (self._jdf.schema().treeString())
|
||||
print(self._jdf.schema().treeString())
|
||||
|
||||
def explain(self, extended=False):
|
||||
"""Prints the (logical and physical) plans to the console for debugging purpose.
|
||||
|
@ -250,9 +257,9 @@ class DataFrame(object):
|
|||
== RDD ==
|
||||
"""
|
||||
if extended:
|
||||
print self._jdf.queryExecution().toString()
|
||||
print(self._jdf.queryExecution().toString())
|
||||
else:
|
||||
print self._jdf.queryExecution().executedPlan().toString()
|
||||
print(self._jdf.queryExecution().executedPlan().toString())
|
||||
|
||||
def isLocal(self):
|
||||
"""Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally
|
||||
|
@ -270,7 +277,7 @@ class DataFrame(object):
|
|||
2 Alice
|
||||
5 Bob
|
||||
"""
|
||||
print self._jdf.showString(n).encode('utf8', 'ignore')
|
||||
print(self._jdf.showString(n))
|
||||
|
||||
def __repr__(self):
|
||||
return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
|
||||
|
@ -279,10 +286,11 @@ class DataFrame(object):
|
|||
"""Returns the number of rows in this :class:`DataFrame`.
|
||||
|
||||
>>> df.count()
|
||||
2L
|
||||
2
|
||||
"""
|
||||
return self._jdf.count()
|
||||
return int(self._jdf.count())
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def collect(self):
|
||||
"""Returns all the records as a list of :class:`Row`.
|
||||
|
||||
|
@ -295,6 +303,7 @@ class DataFrame(object):
|
|||
cls = _create_cls(self.schema)
|
||||
return [cls(r) for r in rs]
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def limit(self, num):
|
||||
"""Limits the result count to the number specified.
|
||||
|
||||
|
@ -306,6 +315,7 @@ class DataFrame(object):
|
|||
jdf = self._jdf.limit(num)
|
||||
return DataFrame(jdf, self.sql_ctx)
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def take(self, num):
|
||||
"""Returns the first ``num`` rows as a :class:`list` of :class:`Row`.
|
||||
|
||||
|
@ -314,6 +324,7 @@ class DataFrame(object):
|
|||
"""
|
||||
return self.limit(num).collect()
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def map(self, f):
|
||||
""" Returns a new :class:`RDD` by applying a the ``f`` function to each :class:`Row`.
|
||||
|
||||
|
@ -324,6 +335,7 @@ class DataFrame(object):
|
|||
"""
|
||||
return self.rdd.map(f)
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def flatMap(self, f):
|
||||
""" Returns a new :class:`RDD` by first applying the ``f`` function to each :class:`Row`,
|
||||
and then flattening the results.
|
||||
|
@ -353,7 +365,7 @@ class DataFrame(object):
|
|||
This is a shorthand for ``df.rdd.foreach()``.
|
||||
|
||||
>>> def f(person):
|
||||
... print person.name
|
||||
... print(person.name)
|
||||
>>> df.foreach(f)
|
||||
"""
|
||||
return self.rdd.foreach(f)
|
||||
|
@ -365,7 +377,7 @@ class DataFrame(object):
|
|||
|
||||
>>> def f(people):
|
||||
... for person in people:
|
||||
... print person.name
|
||||
... print(person.name)
|
||||
>>> df.foreachPartition(f)
|
||||
"""
|
||||
return self.rdd.foreachPartition(f)
|
||||
|
@ -412,7 +424,7 @@ class DataFrame(object):
|
|||
"""Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`.
|
||||
|
||||
>>> df.distinct().count()
|
||||
2L
|
||||
2
|
||||
"""
|
||||
return DataFrame(self._jdf.distinct(), self.sql_ctx)
|
||||
|
||||
|
@ -420,10 +432,10 @@ class DataFrame(object):
|
|||
"""Returns a sampled subset of this :class:`DataFrame`.
|
||||
|
||||
>>> df.sample(False, 0.5, 97).count()
|
||||
1L
|
||||
1
|
||||
"""
|
||||
assert fraction >= 0.0, "Negative fraction value: %s" % fraction
|
||||
seed = seed if seed is not None else random.randint(0, sys.maxint)
|
||||
seed = seed if seed is not None else random.randint(0, sys.maxsize)
|
||||
rdd = self._jdf.sample(withReplacement, fraction, long(seed))
|
||||
return DataFrame(rdd, self.sql_ctx)
|
||||
|
||||
|
@ -437,6 +449,7 @@ class DataFrame(object):
|
|||
return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields]
|
||||
|
||||
@property
|
||||
@ignore_unicode_prefix
|
||||
def columns(self):
|
||||
"""Returns all column names as a list.
|
||||
|
||||
|
@ -445,6 +458,7 @@ class DataFrame(object):
|
|||
"""
|
||||
return [f.name for f in self.schema.fields]
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def join(self, other, joinExprs=None, joinType=None):
|
||||
"""Joins with another :class:`DataFrame`, using the given join expression.
|
||||
|
||||
|
@ -470,6 +484,7 @@ class DataFrame(object):
|
|||
jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType)
|
||||
return DataFrame(jdf, self.sql_ctx)
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def sort(self, *cols):
|
||||
"""Returns a new :class:`DataFrame` sorted by the specified column(s).
|
||||
|
||||
|
@ -513,6 +528,7 @@ class DataFrame(object):
|
|||
jdf = self._jdf.describe(self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols))
|
||||
return DataFrame(jdf, self.sql_ctx)
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def head(self, n=None):
|
||||
"""
|
||||
Returns the first ``n`` rows as a list of :class:`Row`,
|
||||
|
@ -528,6 +544,7 @@ class DataFrame(object):
|
|||
return rs[0] if rs else None
|
||||
return self.take(n)
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def first(self):
|
||||
"""Returns the first row as a :class:`Row`.
|
||||
|
||||
|
@ -536,6 +553,7 @@ class DataFrame(object):
|
|||
"""
|
||||
return self.head()
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def __getitem__(self, item):
|
||||
"""Returns the column as a :class:`Column`.
|
||||
|
||||
|
@ -567,6 +585,7 @@ class DataFrame(object):
|
|||
jc = self._jdf.apply(name)
|
||||
return Column(jc)
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def select(self, *cols):
|
||||
"""Projects a set of expressions and returns a new :class:`DataFrame`.
|
||||
|
||||
|
@ -598,6 +617,7 @@ class DataFrame(object):
|
|||
jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr))
|
||||
return DataFrame(jdf, self.sql_ctx)
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def filter(self, condition):
|
||||
"""Filters rows using the given condition.
|
||||
|
||||
|
@ -626,6 +646,7 @@ class DataFrame(object):
|
|||
|
||||
where = filter
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def groupBy(self, *cols):
|
||||
"""Groups the :class:`DataFrame` using the specified columns,
|
||||
so we can run aggregation on them. See :class:`GroupedData`
|
||||
|
@ -775,6 +796,7 @@ class DataFrame(object):
|
|||
cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)
|
||||
return DataFrame(self._jdf.na().fill(value, cols), self.sql_ctx)
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def withColumn(self, colName, col):
|
||||
"""Returns a new :class:`DataFrame` by adding a column.
|
||||
|
||||
|
@ -786,6 +808,7 @@ class DataFrame(object):
|
|||
"""
|
||||
return self.select('*', col.alias(colName))
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def withColumnRenamed(self, existing, new):
|
||||
"""REturns a new :class:`DataFrame` by renaming an existing column.
|
||||
|
||||
|
@ -852,6 +875,7 @@ class GroupedData(object):
|
|||
self._jdf = jdf
|
||||
self.sql_ctx = sql_ctx
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def agg(self, *exprs):
|
||||
"""Compute aggregates and returns the result as a :class:`DataFrame`.
|
||||
|
||||
|
@ -1041,11 +1065,13 @@ class Column(object):
|
|||
__sub__ = _bin_op("minus")
|
||||
__mul__ = _bin_op("multiply")
|
||||
__div__ = _bin_op("divide")
|
||||
__truediv__ = _bin_op("divide")
|
||||
__mod__ = _bin_op("mod")
|
||||
__radd__ = _bin_op("plus")
|
||||
__rsub__ = _reverse_op("minus")
|
||||
__rmul__ = _bin_op("multiply")
|
||||
__rdiv__ = _reverse_op("divide")
|
||||
__rtruediv__ = _reverse_op("divide")
|
||||
__rmod__ = _reverse_op("mod")
|
||||
|
||||
# logistic operators
|
||||
|
@ -1075,6 +1101,7 @@ class Column(object):
|
|||
startswith = _bin_op("startsWith")
|
||||
endswith = _bin_op("endsWith")
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def substr(self, startPos, length):
|
||||
"""
|
||||
Return a :class:`Column` which is a substring of the column
|
||||
|
@ -1097,6 +1124,7 @@ class Column(object):
|
|||
|
||||
__getslice__ = substr
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def inSet(self, *cols):
|
||||
""" A boolean expression that is evaluated to true if the value of this
|
||||
expression is contained by the evaluated values of the arguments.
|
||||
|
@ -1131,6 +1159,7 @@ class Column(object):
|
|||
"""
|
||||
return Column(getattr(self._jc, "as")(alias))
|
||||
|
||||
@ignore_unicode_prefix
|
||||
def cast(self, dataType):
|
||||
""" Convert the column into type `dataType`
|
||||
|
||||
|
|
|
@ -18,8 +18,10 @@
|
|||
"""
|
||||
A collections of builtin functions
|
||||
"""
|
||||
import sys
|
||||
|
||||
from itertools import imap
|
||||
if sys.version < "3":
|
||||
from itertools import imap as map
|
||||
|
||||
from py4j.java_collections import ListConverter
|
||||
|
||||
|
@ -116,7 +118,7 @@ class UserDefinedFunction(object):
|
|||
|
||||
def _create_judf(self):
|
||||
f = self.func # put it in closure `func`
|
||||
func = lambda _, it: imap(lambda x: f(*x), it)
|
||||
func = lambda _, it: map(lambda x: f(*x), it)
|
||||
ser = AutoBatchedSerializer(PickleSerializer())
|
||||
command = (func, None, ser, ser)
|
||||
sc = SparkContext._active_spark_context
|
||||
|
|
|
@ -157,13 +157,13 @@ class SQLTests(ReusedPySparkTestCase):
|
|||
self.assertEqual(4, res[0])
|
||||
|
||||
def test_udf_with_array_type(self):
|
||||
d = [Row(l=range(3), d={"key": range(5)})]
|
||||
d = [Row(l=list(range(3)), d={"key": list(range(5))})]
|
||||
rdd = self.sc.parallelize(d)
|
||||
self.sqlCtx.createDataFrame(rdd).registerTempTable("test")
|
||||
self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
|
||||
self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
|
||||
[(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
|
||||
self.assertEqual(range(3), l1)
|
||||
self.assertEqual(list(range(3)), l1)
|
||||
self.assertEqual(1, l2)
|
||||
|
||||
def test_broadcast_in_udf(self):
|
||||
|
@ -266,7 +266,7 @@ class SQLTests(ReusedPySparkTestCase):
|
|||
|
||||
def test_apply_schema(self):
|
||||
from datetime import date, datetime
|
||||
rdd = self.sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
|
||||
rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0,
|
||||
date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1),
|
||||
{"a": 1}, (2,), [1, 2, 3], None)])
|
||||
schema = StructType([
|
||||
|
@ -309,7 +309,7 @@ class SQLTests(ReusedPySparkTestCase):
|
|||
def test_struct_in_map(self):
|
||||
d = [Row(m={Row(i=1): Row(s="")})]
|
||||
df = self.sc.parallelize(d).toDF()
|
||||
k, v = df.head().m.items()[0]
|
||||
k, v = list(df.head().m.items())[0]
|
||||
self.assertEqual(1, k.i)
|
||||
self.assertEqual("", v.s)
|
||||
|
||||
|
@ -554,6 +554,9 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
|
|||
except py4j.protocol.Py4JError:
|
||||
cls.sqlCtx = None
|
||||
return
|
||||
except TypeError:
|
||||
cls.sqlCtx = None
|
||||
return
|
||||
os.unlink(cls.tempdir.name)
|
||||
_scala_HiveContext =\
|
||||
cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc())
|
||||
|
|
|
@ -31,7 +31,7 @@ except ImportError:
|
|||
class StatCounter(object):
|
||||
|
||||
def __init__(self, values=[]):
|
||||
self.n = 0L # Running count of our values
|
||||
self.n = 0 # Running count of our values
|
||||
self.mu = 0.0 # Running mean of our values
|
||||
self.m2 = 0.0 # Running variance numerator (sum of (x - mean)^2)
|
||||
self.maxValue = float("-inf")
|
||||
|
@ -87,7 +87,7 @@ class StatCounter(object):
|
|||
return copy.deepcopy(self)
|
||||
|
||||
def count(self):
|
||||
return self.n
|
||||
return int(self.n)
|
||||
|
||||
def mean(self):
|
||||
return self.mu
|
||||
|
|
|
@ -14,6 +14,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
@ -157,7 +160,7 @@ class StreamingContext(object):
|
|||
try:
|
||||
jssc = gw.jvm.JavaStreamingContext(checkpointPath)
|
||||
except Exception:
|
||||
print >>sys.stderr, "failed to load StreamingContext from checkpoint"
|
||||
print("failed to load StreamingContext from checkpoint", file=sys.stderr)
|
||||
raise
|
||||
|
||||
jsc = jssc.sparkContext()
|
||||
|
|
|
@ -15,11 +15,15 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
from itertools import chain, ifilter, imap
|
||||
import sys
|
||||
import operator
|
||||
import time
|
||||
from itertools import chain
|
||||
from datetime import datetime
|
||||
|
||||
if sys.version < "3":
|
||||
from itertools import imap as map, ifilter as filter
|
||||
|
||||
from py4j.protocol import Py4JJavaError
|
||||
|
||||
from pyspark import RDD
|
||||
|
@ -76,7 +80,7 @@ class DStream(object):
|
|||
Return a new DStream containing only the elements that satisfy predicate.
|
||||
"""
|
||||
def func(iterator):
|
||||
return ifilter(f, iterator)
|
||||
return filter(f, iterator)
|
||||
return self.mapPartitions(func, True)
|
||||
|
||||
def flatMap(self, f, preservesPartitioning=False):
|
||||
|
@ -85,7 +89,7 @@ class DStream(object):
|
|||
this DStream, and then flattening the results
|
||||
"""
|
||||
def func(s, iterator):
|
||||
return chain.from_iterable(imap(f, iterator))
|
||||
return chain.from_iterable(map(f, iterator))
|
||||
return self.mapPartitionsWithIndex(func, preservesPartitioning)
|
||||
|
||||
def map(self, f, preservesPartitioning=False):
|
||||
|
@ -93,7 +97,7 @@ class DStream(object):
|
|||
Return a new DStream by applying a function to each element of DStream.
|
||||
"""
|
||||
def func(iterator):
|
||||
return imap(f, iterator)
|
||||
return map(f, iterator)
|
||||
return self.mapPartitions(func, preservesPartitioning)
|
||||
|
||||
def mapPartitions(self, f, preservesPartitioning=False):
|
||||
|
@ -150,7 +154,7 @@ class DStream(object):
|
|||
"""
|
||||
Apply a function to each RDD in this DStream.
|
||||
"""
|
||||
if func.func_code.co_argcount == 1:
|
||||
if func.__code__.co_argcount == 1:
|
||||
old_func = func
|
||||
func = lambda t, rdd: old_func(rdd)
|
||||
jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer)
|
||||
|
@ -165,14 +169,14 @@ class DStream(object):
|
|||
"""
|
||||
def takeAndPrint(time, rdd):
|
||||
taken = rdd.take(num + 1)
|
||||
print "-------------------------------------------"
|
||||
print "Time: %s" % time
|
||||
print "-------------------------------------------"
|
||||
print("-------------------------------------------")
|
||||
print("Time: %s" % time)
|
||||
print("-------------------------------------------")
|
||||
for record in taken[:num]:
|
||||
print record
|
||||
print(record)
|
||||
if len(taken) > num:
|
||||
print "..."
|
||||
print
|
||||
print("...")
|
||||
print()
|
||||
|
||||
self.foreachRDD(takeAndPrint)
|
||||
|
||||
|
@ -181,7 +185,7 @@ class DStream(object):
|
|||
Return a new DStream by applying a map function to the value of
|
||||
each key-value pairs in this DStream without changing the key.
|
||||
"""
|
||||
map_values_fn = lambda (k, v): (k, f(v))
|
||||
map_values_fn = lambda kv: (kv[0], f(kv[1]))
|
||||
return self.map(map_values_fn, preservesPartitioning=True)
|
||||
|
||||
def flatMapValues(self, f):
|
||||
|
@ -189,7 +193,7 @@ class DStream(object):
|
|||
Return a new DStream by applying a flatmap function to the value
|
||||
of each key-value pairs in this DStream without changing the key.
|
||||
"""
|
||||
flat_map_fn = lambda (k, v): ((k, x) for x in f(v))
|
||||
flat_map_fn = lambda kv: ((kv[0], x) for x in f(kv[1]))
|
||||
return self.flatMap(flat_map_fn, preservesPartitioning=True)
|
||||
|
||||
def glom(self):
|
||||
|
@ -286,10 +290,10 @@ class DStream(object):
|
|||
`func` can have one argument of `rdd`, or have two arguments of
|
||||
(`time`, `rdd`)
|
||||
"""
|
||||
if func.func_code.co_argcount == 1:
|
||||
if func.__code__.co_argcount == 1:
|
||||
oldfunc = func
|
||||
func = lambda t, rdd: oldfunc(rdd)
|
||||
assert func.func_code.co_argcount == 2, "func should take one or two arguments"
|
||||
assert func.__code__.co_argcount == 2, "func should take one or two arguments"
|
||||
return TransformedDStream(self, func)
|
||||
|
||||
def transformWith(self, func, other, keepSerializer=False):
|
||||
|
@ -300,10 +304,10 @@ class DStream(object):
|
|||
`func` can have two arguments of (`rdd_a`, `rdd_b`) or have three
|
||||
arguments of (`time`, `rdd_a`, `rdd_b`)
|
||||
"""
|
||||
if func.func_code.co_argcount == 2:
|
||||
if func.__code__.co_argcount == 2:
|
||||
oldfunc = func
|
||||
func = lambda t, a, b: oldfunc(a, b)
|
||||
assert func.func_code.co_argcount == 3, "func should take two or three arguments"
|
||||
assert func.__code__.co_argcount == 3, "func should take two or three arguments"
|
||||
jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer, other._jrdd_deserializer)
|
||||
dstream = self._sc._jvm.PythonTransformed2DStream(self._jdstream.dstream(),
|
||||
other._jdstream.dstream(), jfunc)
|
||||
|
@ -460,7 +464,7 @@ class DStream(object):
|
|||
keyed = self.map(lambda x: (1, x))
|
||||
reduced = keyed.reduceByKeyAndWindow(reduceFunc, invReduceFunc,
|
||||
windowDuration, slideDuration, 1)
|
||||
return reduced.map(lambda (k, v): v)
|
||||
return reduced.map(lambda kv: kv[1])
|
||||
|
||||
def countByWindow(self, windowDuration, slideDuration):
|
||||
"""
|
||||
|
@ -489,7 +493,7 @@ class DStream(object):
|
|||
keyed = self.map(lambda x: (x, 1))
|
||||
counted = keyed.reduceByKeyAndWindow(operator.add, operator.sub,
|
||||
windowDuration, slideDuration, numPartitions)
|
||||
return counted.filter(lambda (k, v): v > 0).count()
|
||||
return counted.filter(lambda kv: kv[1] > 0).count()
|
||||
|
||||
def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None):
|
||||
"""
|
||||
|
@ -548,7 +552,8 @@ class DStream(object):
|
|||
def invReduceFunc(t, a, b):
|
||||
b = b.reduceByKey(func, numPartitions)
|
||||
joined = a.leftOuterJoin(b, numPartitions)
|
||||
return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1)
|
||||
return joined.mapValues(lambda kv: invFunc(kv[0], kv[1])
|
||||
if kv[1] is not None else kv[0])
|
||||
|
||||
jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer)
|
||||
if invReduceFunc:
|
||||
|
@ -579,9 +584,9 @@ class DStream(object):
|
|||
g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None))
|
||||
else:
|
||||
g = a.cogroup(b.partitionBy(numPartitions), numPartitions)
|
||||
g = g.mapValues(lambda (va, vb): (list(vb), list(va)[0] if len(va) else None))
|
||||
state = g.mapValues(lambda (vs, s): updateFunc(vs, s))
|
||||
return state.filter(lambda (k, v): v is not None)
|
||||
g = g.mapValues(lambda ab: (list(ab[1]), list(ab[0])[0] if len(ab[0]) else None))
|
||||
state = g.mapValues(lambda vs_s: updateFunc(vs_s[0], vs_s[1]))
|
||||
return state.filter(lambda k_v: k_v[1] is not None)
|
||||
|
||||
jreduceFunc = TransformFunction(self._sc, reduceFunc,
|
||||
self._sc.serializer, self._jrdd_deserializer)
|
||||
|
|
|
@ -67,10 +67,10 @@ class KafkaUtils(object):
|
|||
.loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
|
||||
helper = helperClass.newInstance()
|
||||
jstream = helper.createStream(ssc._jssc, jparam, jtopics, jlevel)
|
||||
except Py4JJavaError, e:
|
||||
except Py4JJavaError as e:
|
||||
# TODO: use --jar once it also work on driver
|
||||
if 'ClassNotFoundException' in str(e.java_exception):
|
||||
print """
|
||||
print("""
|
||||
________________________________________________________________________________________________
|
||||
|
||||
Spark Streaming's Kafka libraries not found in class path. Try one of the following.
|
||||
|
@ -88,8 +88,8 @@ ________________________________________________________________________________
|
|||
|
||||
________________________________________________________________________________________________
|
||||
|
||||
""" % (ssc.sparkContext.version, ssc.sparkContext.version)
|
||||
""" % (ssc.sparkContext.version, ssc.sparkContext.version))
|
||||
raise e
|
||||
ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
|
||||
stream = DStream(jstream, ssc, ser)
|
||||
return stream.map(lambda (k, v): (keyDecoder(k), valueDecoder(v)))
|
||||
return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
|
||||
|
|
|
@ -22,6 +22,7 @@ import operator
|
|||
import unittest
|
||||
import tempfile
|
||||
import struct
|
||||
from functools import reduce
|
||||
|
||||
from py4j.java_collections import MapConverter
|
||||
|
||||
|
@ -51,7 +52,7 @@ class PySparkStreamingTestCase(unittest.TestCase):
|
|||
while len(result) < n and time.time() - start_time < self.timeout:
|
||||
time.sleep(0.01)
|
||||
if len(result) < n:
|
||||
print "timeout after", self.timeout
|
||||
print("timeout after", self.timeout)
|
||||
|
||||
def _take(self, dstream, n):
|
||||
"""
|
||||
|
@ -131,7 +132,7 @@ class BasicOperationTests(PySparkStreamingTestCase):
|
|||
|
||||
def func(dstream):
|
||||
return dstream.map(str)
|
||||
expected = map(lambda x: map(str, x), input)
|
||||
expected = [list(map(str, x)) for x in input]
|
||||
self._test_func(input, func, expected)
|
||||
|
||||
def test_flatMap(self):
|
||||
|
@ -140,8 +141,8 @@ class BasicOperationTests(PySparkStreamingTestCase):
|
|||
|
||||
def func(dstream):
|
||||
return dstream.flatMap(lambda x: (x, x * 2))
|
||||
expected = map(lambda x: list(chain.from_iterable((map(lambda y: [y, y * 2], x)))),
|
||||
input)
|
||||
expected = [list(chain.from_iterable((map(lambda y: [y, y * 2], x))))
|
||||
for x in input]
|
||||
self._test_func(input, func, expected)
|
||||
|
||||
def test_filter(self):
|
||||
|
@ -150,7 +151,7 @@ class BasicOperationTests(PySparkStreamingTestCase):
|
|||
|
||||
def func(dstream):
|
||||
return dstream.filter(lambda x: x % 2 == 0)
|
||||
expected = map(lambda x: filter(lambda y: y % 2 == 0, x), input)
|
||||
expected = [[y for y in x if y % 2 == 0] for x in input]
|
||||
self._test_func(input, func, expected)
|
||||
|
||||
def test_count(self):
|
||||
|
@ -159,7 +160,7 @@ class BasicOperationTests(PySparkStreamingTestCase):
|
|||
|
||||
def func(dstream):
|
||||
return dstream.count()
|
||||
expected = map(lambda x: [len(x)], input)
|
||||
expected = [[len(x)] for x in input]
|
||||
self._test_func(input, func, expected)
|
||||
|
||||
def test_reduce(self):
|
||||
|
@ -168,7 +169,7 @@ class BasicOperationTests(PySparkStreamingTestCase):
|
|||
|
||||
def func(dstream):
|
||||
return dstream.reduce(operator.add)
|
||||
expected = map(lambda x: [reduce(operator.add, x)], input)
|
||||
expected = [[reduce(operator.add, x)] for x in input]
|
||||
self._test_func(input, func, expected)
|
||||
|
||||
def test_reduceByKey(self):
|
||||
|
@ -185,27 +186,27 @@ class BasicOperationTests(PySparkStreamingTestCase):
|
|||
def test_mapValues(self):
|
||||
"""Basic operation test for DStream.mapValues."""
|
||||
input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)],
|
||||
[("", 4), (1, 1), (2, 2), (3, 3)],
|
||||
[(0, 4), (1, 1), (2, 2), (3, 3)],
|
||||
[(1, 1), (2, 1), (3, 1), (4, 1)]]
|
||||
|
||||
def func(dstream):
|
||||
return dstream.mapValues(lambda x: x + 10)
|
||||
expected = [[("a", 12), ("b", 12), ("c", 11), ("d", 11)],
|
||||
[("", 14), (1, 11), (2, 12), (3, 13)],
|
||||
[(0, 14), (1, 11), (2, 12), (3, 13)],
|
||||
[(1, 11), (2, 11), (3, 11), (4, 11)]]
|
||||
self._test_func(input, func, expected, sort=True)
|
||||
|
||||
def test_flatMapValues(self):
|
||||
"""Basic operation test for DStream.flatMapValues."""
|
||||
input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)],
|
||||
[("", 4), (1, 1), (2, 1), (3, 1)],
|
||||
[(0, 4), (1, 1), (2, 1), (3, 1)],
|
||||
[(1, 1), (2, 1), (3, 1), (4, 1)]]
|
||||
|
||||
def func(dstream):
|
||||
return dstream.flatMapValues(lambda x: (x, x + 10))
|
||||
expected = [[("a", 2), ("a", 12), ("b", 2), ("b", 12),
|
||||
("c", 1), ("c", 11), ("d", 1), ("d", 11)],
|
||||
[("", 4), ("", 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)],
|
||||
[(0, 4), (0, 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)],
|
||||
[(1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11), (4, 1), (4, 11)]]
|
||||
self._test_func(input, func, expected)
|
||||
|
||||
|
@ -233,7 +234,7 @@ class BasicOperationTests(PySparkStreamingTestCase):
|
|||
|
||||
def test_countByValue(self):
|
||||
"""Basic operation test for DStream.countByValue."""
|
||||
input = [range(1, 5) * 2, range(5, 7) + range(5, 9), ["a", "a", "b", ""]]
|
||||
input = [list(range(1, 5)) * 2, list(range(5, 7)) + list(range(5, 9)), ["a", "a", "b", ""]]
|
||||
|
||||
def func(dstream):
|
||||
return dstream.countByValue()
|
||||
|
@ -285,7 +286,7 @@ class BasicOperationTests(PySparkStreamingTestCase):
|
|||
def func(d1, d2):
|
||||
return d1.union(d2)
|
||||
|
||||
expected = [range(6), range(6), range(6)]
|
||||
expected = [list(range(6)), list(range(6)), list(range(6))]
|
||||
self._test_func(input1, func, expected, input2=input2)
|
||||
|
||||
def test_cogroup(self):
|
||||
|
@ -424,7 +425,7 @@ class StreamingContextTests(PySparkStreamingTestCase):
|
|||
duration = 0.1
|
||||
|
||||
def _add_input_stream(self):
|
||||
inputs = map(lambda x: range(1, x), range(101))
|
||||
inputs = [range(1, x) for x in range(101)]
|
||||
stream = self.ssc.queueStream(inputs)
|
||||
self._collect(stream, 1, block=False)
|
||||
|
||||
|
@ -441,7 +442,7 @@ class StreamingContextTests(PySparkStreamingTestCase):
|
|||
self.ssc.stop()
|
||||
|
||||
def test_queue_stream(self):
|
||||
input = [range(i + 1) for i in range(3)]
|
||||
input = [list(range(i + 1)) for i in range(3)]
|
||||
dstream = self.ssc.queueStream(input)
|
||||
result = self._collect(dstream, 3)
|
||||
self.assertEqual(input, result)
|
||||
|
@ -457,13 +458,13 @@ class StreamingContextTests(PySparkStreamingTestCase):
|
|||
with open(os.path.join(d, name), "w") as f:
|
||||
f.writelines(["%d\n" % i for i in range(10)])
|
||||
self.wait_for(result, 2)
|
||||
self.assertEqual([range(10), range(10)], result)
|
||||
self.assertEqual([list(range(10)), list(range(10))], result)
|
||||
|
||||
def test_binary_records_stream(self):
|
||||
d = tempfile.mkdtemp()
|
||||
self.ssc = StreamingContext(self.sc, self.duration)
|
||||
dstream = self.ssc.binaryRecordsStream(d, 10).map(
|
||||
lambda v: struct.unpack("10b", str(v)))
|
||||
lambda v: struct.unpack("10b", bytes(v)))
|
||||
result = self._collect(dstream, 2, block=False)
|
||||
self.ssc.start()
|
||||
for name in ('a', 'b'):
|
||||
|
@ -471,10 +472,10 @@ class StreamingContextTests(PySparkStreamingTestCase):
|
|||
with open(os.path.join(d, name), "wb") as f:
|
||||
f.write(bytearray(range(10)))
|
||||
self.wait_for(result, 2)
|
||||
self.assertEqual([range(10), range(10)], map(lambda v: list(v[0]), result))
|
||||
self.assertEqual([list(range(10)), list(range(10))], [list(v[0]) for v in result])
|
||||
|
||||
def test_union(self):
|
||||
input = [range(i + 1) for i in range(3)]
|
||||
input = [list(range(i + 1)) for i in range(3)]
|
||||
dstream = self.ssc.queueStream(input)
|
||||
dstream2 = self.ssc.queueStream(input)
|
||||
dstream3 = self.ssc.union(dstream, dstream2)
|
||||
|
|
|
@ -91,9 +91,9 @@ class TransformFunctionSerializer(object):
|
|||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
def loads(self, bytes):
|
||||
def loads(self, data):
|
||||
try:
|
||||
f, deserializers = self.serializer.loads(str(bytes))
|
||||
f, deserializers = self.serializer.loads(bytes(data))
|
||||
return TransformFunction(self.ctx, f, *deserializers)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
@ -116,7 +116,7 @@ def rddToFileName(prefix, suffix, timestamp):
|
|||
"""
|
||||
if isinstance(timestamp, datetime):
|
||||
seconds = time.mktime(timestamp.timetuple())
|
||||
timestamp = long(seconds * 1000) + timestamp.microsecond / 1000
|
||||
timestamp = int(seconds * 1000) + timestamp.microsecond // 1000
|
||||
if suffix is None:
|
||||
return prefix + "-" + str(timestamp)
|
||||
else:
|
||||
|
|
|
@ -19,8 +19,8 @@
|
|||
Unit tests for PySpark; additional tests are implemented as doctests in
|
||||
individual modules.
|
||||
"""
|
||||
|
||||
from array import array
|
||||
from fileinput import input
|
||||
from glob import glob
|
||||
import os
|
||||
import re
|
||||
|
@ -45,6 +45,9 @@ if sys.version_info[:2] <= (2, 6):
|
|||
sys.exit(1)
|
||||
else:
|
||||
import unittest
|
||||
if sys.version_info[0] >= 3:
|
||||
xrange = range
|
||||
basestring = str
|
||||
|
||||
|
||||
from pyspark.conf import SparkConf
|
||||
|
@ -52,7 +55,9 @@ from pyspark.context import SparkContext
|
|||
from pyspark.rdd import RDD
|
||||
from pyspark.files import SparkFiles
|
||||
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
|
||||
CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer
|
||||
CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer, \
|
||||
PairDeserializer, CartesianDeserializer, AutoBatchedSerializer, AutoSerializer, \
|
||||
FlattenedValuesSerializer
|
||||
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
|
||||
from pyspark import shuffle
|
||||
from pyspark.profiler import BasicProfiler
|
||||
|
@ -81,7 +86,7 @@ class MergerTests(unittest.TestCase):
|
|||
def setUp(self):
|
||||
self.N = 1 << 12
|
||||
self.l = [i for i in xrange(self.N)]
|
||||
self.data = zip(self.l, self.l)
|
||||
self.data = list(zip(self.l, self.l))
|
||||
self.agg = Aggregator(lambda x: [x],
|
||||
lambda x, y: x.append(y) or x,
|
||||
lambda x, y: x.extend(y) or x)
|
||||
|
@ -89,45 +94,45 @@ class MergerTests(unittest.TestCase):
|
|||
def test_in_memory(self):
|
||||
m = InMemoryMerger(self.agg)
|
||||
m.mergeValues(self.data)
|
||||
self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
|
||||
self.assertEqual(sum(sum(v) for k, v in m.items()),
|
||||
sum(xrange(self.N)))
|
||||
|
||||
m = InMemoryMerger(self.agg)
|
||||
m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data))
|
||||
self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
|
||||
m.mergeCombiners(map(lambda x_y: (x_y[0], [x_y[1]]), self.data))
|
||||
self.assertEqual(sum(sum(v) for k, v in m.items()),
|
||||
sum(xrange(self.N)))
|
||||
|
||||
def test_small_dataset(self):
|
||||
m = ExternalMerger(self.agg, 1000)
|
||||
m.mergeValues(self.data)
|
||||
self.assertEqual(m.spills, 0)
|
||||
self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
|
||||
self.assertEqual(sum(sum(v) for k, v in m.items()),
|
||||
sum(xrange(self.N)))
|
||||
|
||||
m = ExternalMerger(self.agg, 1000)
|
||||
m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data))
|
||||
m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), self.data))
|
||||
self.assertEqual(m.spills, 0)
|
||||
self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
|
||||
self.assertEqual(sum(sum(v) for k, v in m.items()),
|
||||
sum(xrange(self.N)))
|
||||
|
||||
def test_medium_dataset(self):
|
||||
m = ExternalMerger(self.agg, 30)
|
||||
m = ExternalMerger(self.agg, 20)
|
||||
m.mergeValues(self.data)
|
||||
self.assertTrue(m.spills >= 1)
|
||||
self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
|
||||
self.assertEqual(sum(sum(v) for k, v in m.items()),
|
||||
sum(xrange(self.N)))
|
||||
|
||||
m = ExternalMerger(self.agg, 10)
|
||||
m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data * 3))
|
||||
m.mergeCombiners(map(lambda x_y2: (x_y2[0], [x_y2[1]]), self.data * 3))
|
||||
self.assertTrue(m.spills >= 1)
|
||||
self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
|
||||
self.assertEqual(sum(sum(v) for k, v in m.items()),
|
||||
sum(xrange(self.N)) * 3)
|
||||
|
||||
def test_huge_dataset(self):
|
||||
m = ExternalMerger(self.agg, 10, partitions=3)
|
||||
m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10))
|
||||
m = ExternalMerger(self.agg, 5, partitions=3)
|
||||
m.mergeCombiners(map(lambda k_v: (k_v[0], [str(k_v[1])]), self.data * 10))
|
||||
self.assertTrue(m.spills >= 1)
|
||||
self.assertEqual(sum(len(v) for k, v in m.iteritems()),
|
||||
self.assertEqual(sum(len(v) for k, v in m.items()),
|
||||
self.N * 10)
|
||||
m._cleanup()
|
||||
|
||||
|
@ -144,55 +149,55 @@ class MergerTests(unittest.TestCase):
|
|||
self.assertEqual(1, len(list(gen_gs(1))))
|
||||
self.assertEqual(2, len(list(gen_gs(2))))
|
||||
self.assertEqual(100, len(list(gen_gs(100))))
|
||||
self.assertEqual(range(1, 101), [k for k, _ in gen_gs(100)])
|
||||
self.assertTrue(all(range(k) == list(vs) for k, vs in gen_gs(100)))
|
||||
self.assertEqual(list(range(1, 101)), [k for k, _ in gen_gs(100)])
|
||||
self.assertTrue(all(list(range(k)) == list(vs) for k, vs in gen_gs(100)))
|
||||
|
||||
for k, vs in gen_gs(50002, 10000):
|
||||
self.assertEqual(k, len(vs))
|
||||
self.assertEqual(range(k), list(vs))
|
||||
self.assertEqual(list(range(k)), list(vs))
|
||||
|
||||
ser = PickleSerializer()
|
||||
l = ser.loads(ser.dumps(list(gen_gs(50002, 30000))))
|
||||
for k, vs in l:
|
||||
self.assertEqual(k, len(vs))
|
||||
self.assertEqual(range(k), list(vs))
|
||||
self.assertEqual(list(range(k)), list(vs))
|
||||
|
||||
|
||||
class SorterTests(unittest.TestCase):
|
||||
def test_in_memory_sort(self):
|
||||
l = range(1024)
|
||||
l = list(range(1024))
|
||||
random.shuffle(l)
|
||||
sorter = ExternalSorter(1024)
|
||||
self.assertEquals(sorted(l), list(sorter.sorted(l)))
|
||||
self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
|
||||
self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
|
||||
self.assertEquals(sorted(l, key=lambda x: -x, reverse=True),
|
||||
list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
|
||||
self.assertEqual(sorted(l), list(sorter.sorted(l)))
|
||||
self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
|
||||
self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
|
||||
self.assertEqual(sorted(l, key=lambda x: -x, reverse=True),
|
||||
list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
|
||||
|
||||
def test_external_sort(self):
|
||||
l = range(1024)
|
||||
l = list(range(1024))
|
||||
random.shuffle(l)
|
||||
sorter = ExternalSorter(1)
|
||||
self.assertEquals(sorted(l), list(sorter.sorted(l)))
|
||||
self.assertEqual(sorted(l), list(sorter.sorted(l)))
|
||||
self.assertGreater(shuffle.DiskBytesSpilled, 0)
|
||||
last = shuffle.DiskBytesSpilled
|
||||
self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
|
||||
self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
|
||||
self.assertGreater(shuffle.DiskBytesSpilled, last)
|
||||
last = shuffle.DiskBytesSpilled
|
||||
self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
|
||||
self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
|
||||
self.assertGreater(shuffle.DiskBytesSpilled, last)
|
||||
last = shuffle.DiskBytesSpilled
|
||||
self.assertEquals(sorted(l, key=lambda x: -x, reverse=True),
|
||||
list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
|
||||
self.assertEqual(sorted(l, key=lambda x: -x, reverse=True),
|
||||
list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
|
||||
self.assertGreater(shuffle.DiskBytesSpilled, last)
|
||||
|
||||
def test_external_sort_in_rdd(self):
|
||||
conf = SparkConf().set("spark.python.worker.memory", "1m")
|
||||
sc = SparkContext(conf=conf)
|
||||
l = range(10240)
|
||||
l = list(range(10240))
|
||||
random.shuffle(l)
|
||||
rdd = sc.parallelize(l, 10)
|
||||
self.assertEquals(sorted(l), rdd.sortBy(lambda x: x).collect())
|
||||
rdd = sc.parallelize(l, 2)
|
||||
self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect())
|
||||
sc.stop()
|
||||
|
||||
|
||||
|
@ -200,11 +205,11 @@ class SerializationTestCase(unittest.TestCase):
|
|||
|
||||
def test_namedtuple(self):
|
||||
from collections import namedtuple
|
||||
from cPickle import dumps, loads
|
||||
from pickle import dumps, loads
|
||||
P = namedtuple("P", "x y")
|
||||
p1 = P(1, 3)
|
||||
p2 = loads(dumps(p1, 2))
|
||||
self.assertEquals(p1, p2)
|
||||
self.assertEqual(p1, p2)
|
||||
|
||||
def test_itemgetter(self):
|
||||
from operator import itemgetter
|
||||
|
@ -246,7 +251,7 @@ class SerializationTestCase(unittest.TestCase):
|
|||
ser = CloudPickleSerializer()
|
||||
out1 = sys.stderr
|
||||
out2 = ser.loads(ser.dumps(out1))
|
||||
self.assertEquals(out1, out2)
|
||||
self.assertEqual(out1, out2)
|
||||
|
||||
def test_func_globals(self):
|
||||
|
||||
|
@ -263,19 +268,36 @@ class SerializationTestCase(unittest.TestCase):
|
|||
def foo():
|
||||
sys.exit(0)
|
||||
|
||||
self.assertTrue("exit" in foo.func_code.co_names)
|
||||
self.assertTrue("exit" in foo.__code__.co_names)
|
||||
ser.dumps(foo)
|
||||
|
||||
def test_compressed_serializer(self):
|
||||
ser = CompressedSerializer(PickleSerializer())
|
||||
from StringIO import StringIO
|
||||
try:
|
||||
from StringIO import StringIO
|
||||
except ImportError:
|
||||
from io import BytesIO as StringIO
|
||||
io = StringIO()
|
||||
ser.dump_stream(["abc", u"123", range(5)], io)
|
||||
io.seek(0)
|
||||
self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io)))
|
||||
ser.dump_stream(range(1000), io)
|
||||
io.seek(0)
|
||||
self.assertEqual(["abc", u"123", range(5)] + range(1000), list(ser.load_stream(io)))
|
||||
self.assertEqual(["abc", u"123", range(5)] + list(range(1000)), list(ser.load_stream(io)))
|
||||
io.close()
|
||||
|
||||
def test_hash_serializer(self):
|
||||
hash(NoOpSerializer())
|
||||
hash(UTF8Deserializer())
|
||||
hash(PickleSerializer())
|
||||
hash(MarshalSerializer())
|
||||
hash(AutoSerializer())
|
||||
hash(BatchedSerializer(PickleSerializer()))
|
||||
hash(AutoBatchedSerializer(MarshalSerializer()))
|
||||
hash(PairDeserializer(NoOpSerializer(), UTF8Deserializer()))
|
||||
hash(CartesianDeserializer(NoOpSerializer(), UTF8Deserializer()))
|
||||
hash(CompressedSerializer(PickleSerializer()))
|
||||
hash(FlattenedValuesSerializer(PickleSerializer()))
|
||||
|
||||
|
||||
class PySparkTestCase(unittest.TestCase):
|
||||
|
@ -340,7 +362,7 @@ class CheckpointTests(ReusedPySparkTestCase):
|
|||
self.assertTrue(flatMappedRDD.getCheckpointFile() is not None)
|
||||
recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(),
|
||||
flatMappedRDD._jrdd_deserializer)
|
||||
self.assertEquals([1, 2, 3, 4], recovered.collect())
|
||||
self.assertEqual([1, 2, 3, 4], recovered.collect())
|
||||
|
||||
|
||||
class AddFileTests(PySparkTestCase):
|
||||
|
@ -356,8 +378,7 @@ class AddFileTests(PySparkTestCase):
|
|||
def func(x):
|
||||
from userlibrary import UserClass
|
||||
return UserClass().hello()
|
||||
self.assertRaises(Exception,
|
||||
self.sc.parallelize(range(2)).map(func).first)
|
||||
self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first)
|
||||
log4j.LogManager.getRootLogger().setLevel(old_level)
|
||||
|
||||
# Add the file, so the job should now succeed:
|
||||
|
@ -372,7 +393,7 @@ class AddFileTests(PySparkTestCase):
|
|||
download_path = SparkFiles.get("hello.txt")
|
||||
self.assertNotEqual(path, download_path)
|
||||
with open(download_path) as test_file:
|
||||
self.assertEquals("Hello World!\n", test_file.readline())
|
||||
self.assertEqual("Hello World!\n", test_file.readline())
|
||||
|
||||
def test_add_py_file_locally(self):
|
||||
# To ensure that we're actually testing addPyFile's effects, check that
|
||||
|
@ -381,7 +402,7 @@ class AddFileTests(PySparkTestCase):
|
|||
from userlibrary import UserClass
|
||||
self.assertRaises(ImportError, func)
|
||||
path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
|
||||
self.sc.addFile(path)
|
||||
self.sc.addPyFile(path)
|
||||
from userlibrary import UserClass
|
||||
self.assertEqual("Hello World!", UserClass().hello())
|
||||
|
||||
|
@ -391,7 +412,7 @@ class AddFileTests(PySparkTestCase):
|
|||
def func():
|
||||
from userlib import UserClass
|
||||
self.assertRaises(ImportError, func)
|
||||
path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1-py2.7.egg")
|
||||
path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1.zip")
|
||||
self.sc.addPyFile(path)
|
||||
from userlib import UserClass
|
||||
self.assertEqual("Hello World from inside a package!", UserClass().hello())
|
||||
|
@ -427,8 +448,9 @@ class RDDTests(ReusedPySparkTestCase):
|
|||
tempFile = tempfile.NamedTemporaryFile(delete=True)
|
||||
tempFile.close()
|
||||
data.saveAsTextFile(tempFile.name)
|
||||
raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*")))
|
||||
self.assertEqual(x, unicode(raw_contents.strip(), "utf-8"))
|
||||
raw_contents = b''.join(open(p, 'rb').read()
|
||||
for p in glob(tempFile.name + "/part-0000*"))
|
||||
self.assertEqual(x, raw_contents.strip().decode("utf-8"))
|
||||
|
||||
def test_save_as_textfile_with_utf8(self):
|
||||
x = u"\u00A1Hola, mundo!"
|
||||
|
@ -436,19 +458,20 @@ class RDDTests(ReusedPySparkTestCase):
|
|||
tempFile = tempfile.NamedTemporaryFile(delete=True)
|
||||
tempFile.close()
|
||||
data.saveAsTextFile(tempFile.name)
|
||||
raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*")))
|
||||
self.assertEqual(x, unicode(raw_contents.strip(), "utf-8"))
|
||||
raw_contents = b''.join(open(p, 'rb').read()
|
||||
for p in glob(tempFile.name + "/part-0000*"))
|
||||
self.assertEqual(x, raw_contents.strip().decode('utf8'))
|
||||
|
||||
def test_transforming_cartesian_result(self):
|
||||
# Regression test for SPARK-1034
|
||||
rdd1 = self.sc.parallelize([1, 2])
|
||||
rdd2 = self.sc.parallelize([3, 4])
|
||||
cart = rdd1.cartesian(rdd2)
|
||||
result = cart.map(lambda (x, y): x + y).collect()
|
||||
result = cart.map(lambda x_y3: x_y3[0] + x_y3[1]).collect()
|
||||
|
||||
def test_transforming_pickle_file(self):
|
||||
# Regression test for SPARK-2601
|
||||
data = self.sc.parallelize(["Hello", "World!"])
|
||||
data = self.sc.parallelize([u"Hello", u"World!"])
|
||||
tempFile = tempfile.NamedTemporaryFile(delete=True)
|
||||
tempFile.close()
|
||||
data.saveAsPickleFile(tempFile.name)
|
||||
|
@ -461,13 +484,13 @@ class RDDTests(ReusedPySparkTestCase):
|
|||
a = self.sc.textFile(path)
|
||||
result = a.cartesian(a).collect()
|
||||
(x, y) = result[0]
|
||||
self.assertEqual("Hello World!", x.strip())
|
||||
self.assertEqual("Hello World!", y.strip())
|
||||
self.assertEqual(u"Hello World!", x.strip())
|
||||
self.assertEqual(u"Hello World!", y.strip())
|
||||
|
||||
def test_deleting_input_files(self):
|
||||
# Regression test for SPARK-1025
|
||||
tempFile = tempfile.NamedTemporaryFile(delete=False)
|
||||
tempFile.write("Hello World!")
|
||||
tempFile.write(b"Hello World!")
|
||||
tempFile.close()
|
||||
data = self.sc.textFile(tempFile.name)
|
||||
filtered_data = data.filter(lambda x: True)
|
||||
|
@ -510,21 +533,21 @@ class RDDTests(ReusedPySparkTestCase):
|
|||
jon = Person(1, "Jon", "Doe")
|
||||
jane = Person(2, "Jane", "Doe")
|
||||
theDoes = self.sc.parallelize([jon, jane])
|
||||
self.assertEquals([jon, jane], theDoes.collect())
|
||||
self.assertEqual([jon, jane], theDoes.collect())
|
||||
|
||||
def test_large_broadcast(self):
|
||||
N = 100000
|
||||
data = [[float(i) for i in range(300)] for i in range(N)]
|
||||
bdata = self.sc.broadcast(data) # 270MB
|
||||
m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
|
||||
self.assertEquals(N, m)
|
||||
self.assertEqual(N, m)
|
||||
|
||||
def test_multiple_broadcasts(self):
|
||||
N = 1 << 21
|
||||
b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM
|
||||
r = range(1 << 15)
|
||||
r = list(range(1 << 15))
|
||||
random.shuffle(r)
|
||||
s = str(r)
|
||||
s = str(r).encode()
|
||||
checksum = hashlib.md5(s).hexdigest()
|
||||
b2 = self.sc.broadcast(s)
|
||||
r = list(set(self.sc.parallelize(range(10), 10).map(
|
||||
|
@ -535,7 +558,7 @@ class RDDTests(ReusedPySparkTestCase):
|
|||
self.assertEqual(checksum, csum)
|
||||
|
||||
random.shuffle(r)
|
||||
s = str(r)
|
||||
s = str(r).encode()
|
||||
checksum = hashlib.md5(s).hexdigest()
|
||||
b2 = self.sc.broadcast(s)
|
||||
r = list(set(self.sc.parallelize(range(10), 10).map(
|
||||
|
@ -549,7 +572,7 @@ class RDDTests(ReusedPySparkTestCase):
|
|||
N = 1000000
|
||||
data = [float(i) for i in xrange(N)]
|
||||
rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data))
|
||||
self.assertEquals(N, rdd.first())
|
||||
self.assertEqual(N, rdd.first())
|
||||
# regression test for SPARK-6886
|
||||
self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count())
|
||||
|
||||
|
@ -590,15 +613,15 @@ class RDDTests(ReusedPySparkTestCase):
|
|||
# same total number of items, but different distributions
|
||||
a = self.sc.parallelize([2, 3], 2).flatMap(range)
|
||||
b = self.sc.parallelize([3, 2], 2).flatMap(range)
|
||||
self.assertEquals(a.count(), b.count())
|
||||
self.assertEqual(a.count(), b.count())
|
||||
self.assertRaises(Exception, lambda: a.zip(b).count())
|
||||
|
||||
def test_count_approx_distinct(self):
|
||||
rdd = self.sc.parallelize(range(1000))
|
||||
self.assertTrue(950 < rdd.countApproxDistinct(0.04) < 1050)
|
||||
self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.04) < 1050)
|
||||
self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.04) < 1050)
|
||||
self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.04) < 1050)
|
||||
self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050)
|
||||
self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050)
|
||||
self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050)
|
||||
self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.03) < 1050)
|
||||
|
||||
rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7)
|
||||
self.assertTrue(18 < rdd.countApproxDistinct() < 22)
|
||||
|
@ -612,59 +635,59 @@ class RDDTests(ReusedPySparkTestCase):
|
|||
def test_histogram(self):
|
||||
# empty
|
||||
rdd = self.sc.parallelize([])
|
||||
self.assertEquals([0], rdd.histogram([0, 10])[1])
|
||||
self.assertEquals([0, 0], rdd.histogram([0, 4, 10])[1])
|
||||
self.assertEqual([0], rdd.histogram([0, 10])[1])
|
||||
self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1])
|
||||
self.assertRaises(ValueError, lambda: rdd.histogram(1))
|
||||
|
||||
# out of range
|
||||
rdd = self.sc.parallelize([10.01, -0.01])
|
||||
self.assertEquals([0], rdd.histogram([0, 10])[1])
|
||||
self.assertEquals([0, 0], rdd.histogram((0, 4, 10))[1])
|
||||
self.assertEqual([0], rdd.histogram([0, 10])[1])
|
||||
self.assertEqual([0, 0], rdd.histogram((0, 4, 10))[1])
|
||||
|
||||
# in range with one bucket
|
||||
rdd = self.sc.parallelize(range(1, 5))
|
||||
self.assertEquals([4], rdd.histogram([0, 10])[1])
|
||||
self.assertEquals([3, 1], rdd.histogram([0, 4, 10])[1])
|
||||
self.assertEqual([4], rdd.histogram([0, 10])[1])
|
||||
self.assertEqual([3, 1], rdd.histogram([0, 4, 10])[1])
|
||||
|
||||
# in range with one bucket exact match
|
||||
self.assertEquals([4], rdd.histogram([1, 4])[1])
|
||||
self.assertEqual([4], rdd.histogram([1, 4])[1])
|
||||
|
||||
# out of range with two buckets
|
||||
rdd = self.sc.parallelize([10.01, -0.01])
|
||||
self.assertEquals([0, 0], rdd.histogram([0, 5, 10])[1])
|
||||
self.assertEqual([0, 0], rdd.histogram([0, 5, 10])[1])
|
||||
|
||||
# out of range with two uneven buckets
|
||||
rdd = self.sc.parallelize([10.01, -0.01])
|
||||
self.assertEquals([0, 0], rdd.histogram([0, 4, 10])[1])
|
||||
self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1])
|
||||
|
||||
# in range with two buckets
|
||||
rdd = self.sc.parallelize([1, 2, 3, 5, 6])
|
||||
self.assertEquals([3, 2], rdd.histogram([0, 5, 10])[1])
|
||||
self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1])
|
||||
|
||||
# in range with two bucket and None
|
||||
rdd = self.sc.parallelize([1, 2, 3, 5, 6, None, float('nan')])
|
||||
self.assertEquals([3, 2], rdd.histogram([0, 5, 10])[1])
|
||||
self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1])
|
||||
|
||||
# in range with two uneven buckets
|
||||
rdd = self.sc.parallelize([1, 2, 3, 5, 6])
|
||||
self.assertEquals([3, 2], rdd.histogram([0, 5, 11])[1])
|
||||
self.assertEqual([3, 2], rdd.histogram([0, 5, 11])[1])
|
||||
|
||||
# mixed range with two uneven buckets
|
||||
rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01])
|
||||
self.assertEquals([4, 3], rdd.histogram([0, 5, 11])[1])
|
||||
self.assertEqual([4, 3], rdd.histogram([0, 5, 11])[1])
|
||||
|
||||
# mixed range with four uneven buckets
|
||||
rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1])
|
||||
self.assertEquals([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1])
|
||||
self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1])
|
||||
|
||||
# mixed range with uneven buckets and NaN
|
||||
rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0,
|
||||
199.0, 200.0, 200.1, None, float('nan')])
|
||||
self.assertEquals([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1])
|
||||
self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1])
|
||||
|
||||
# out of range with infinite buckets
|
||||
rdd = self.sc.parallelize([10.01, -0.01, float('nan'), float("inf")])
|
||||
self.assertEquals([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1])
|
||||
self.assertEqual([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1])
|
||||
|
||||
# invalid buckets
|
||||
self.assertRaises(ValueError, lambda: rdd.histogram([]))
|
||||
|
@ -674,25 +697,25 @@ class RDDTests(ReusedPySparkTestCase):
|
|||
|
||||
# without buckets
|
||||
rdd = self.sc.parallelize(range(1, 5))
|
||||
self.assertEquals(([1, 4], [4]), rdd.histogram(1))
|
||||
self.assertEqual(([1, 4], [4]), rdd.histogram(1))
|
||||
|
||||
# without buckets single element
|
||||
rdd = self.sc.parallelize([1])
|
||||
self.assertEquals(([1, 1], [1]), rdd.histogram(1))
|
||||
self.assertEqual(([1, 1], [1]), rdd.histogram(1))
|
||||
|
||||
# without bucket no range
|
||||
rdd = self.sc.parallelize([1] * 4)
|
||||
self.assertEquals(([1, 1], [4]), rdd.histogram(1))
|
||||
self.assertEqual(([1, 1], [4]), rdd.histogram(1))
|
||||
|
||||
# without buckets basic two
|
||||
rdd = self.sc.parallelize(range(1, 5))
|
||||
self.assertEquals(([1, 2.5, 4], [2, 2]), rdd.histogram(2))
|
||||
self.assertEqual(([1, 2.5, 4], [2, 2]), rdd.histogram(2))
|
||||
|
||||
# without buckets with more requested than elements
|
||||
rdd = self.sc.parallelize([1, 2])
|
||||
buckets = [1 + 0.2 * i for i in range(6)]
|
||||
hist = [1, 0, 0, 0, 1]
|
||||
self.assertEquals((buckets, hist), rdd.histogram(5))
|
||||
self.assertEqual((buckets, hist), rdd.histogram(5))
|
||||
|
||||
# invalid RDDs
|
||||
rdd = self.sc.parallelize([1, float('inf')])
|
||||
|
@ -702,15 +725,8 @@ class RDDTests(ReusedPySparkTestCase):
|
|||
|
||||
# string
|
||||
rdd = self.sc.parallelize(["ab", "ac", "b", "bd", "ef"], 2)
|
||||
self.assertEquals([2, 2], rdd.histogram(["a", "b", "c"])[1])
|
||||
self.assertEquals((["ab", "ef"], [5]), rdd.histogram(1))
|
||||
self.assertRaises(TypeError, lambda: rdd.histogram(2))
|
||||
|
||||
# mixed RDD
|
||||
rdd = self.sc.parallelize([1, 4, "ab", "ac", "b"], 2)
|
||||
self.assertEquals([1, 1], rdd.histogram([0, 4, 10])[1])
|
||||
self.assertEquals([2, 1], rdd.histogram(["a", "b", "c"])[1])
|
||||
self.assertEquals(([1, "b"], [5]), rdd.histogram(1))
|
||||
self.assertEqual([2, 2], rdd.histogram(["a", "b", "c"])[1])
|
||||
self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1))
|
||||
self.assertRaises(TypeError, lambda: rdd.histogram(2))
|
||||
|
||||
def test_repartitionAndSortWithinPartitions(self):
|
||||
|
@ -718,31 +734,31 @@ class RDDTests(ReusedPySparkTestCase):
|
|||
|
||||
repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2)
|
||||
partitions = repartitioned.glom().collect()
|
||||
self.assertEquals(partitions[0], [(0, 5), (0, 8), (2, 6)])
|
||||
self.assertEquals(partitions[1], [(1, 3), (3, 8), (3, 8)])
|
||||
self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)])
|
||||
self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)])
|
||||
|
||||
def test_distinct(self):
|
||||
rdd = self.sc.parallelize((1, 2, 3)*10, 10)
|
||||
self.assertEquals(rdd.getNumPartitions(), 10)
|
||||
self.assertEquals(rdd.distinct().count(), 3)
|
||||
self.assertEqual(rdd.getNumPartitions(), 10)
|
||||
self.assertEqual(rdd.distinct().count(), 3)
|
||||
result = rdd.distinct(5)
|
||||
self.assertEquals(result.getNumPartitions(), 5)
|
||||
self.assertEquals(result.count(), 3)
|
||||
self.assertEqual(result.getNumPartitions(), 5)
|
||||
self.assertEqual(result.count(), 3)
|
||||
|
||||
def test_external_group_by_key(self):
|
||||
self.sc._conf.set("spark.python.worker.memory", "5m")
|
||||
self.sc._conf.set("spark.python.worker.memory", "1m")
|
||||
N = 200001
|
||||
kv = self.sc.parallelize(range(N)).map(lambda x: (x % 3, x))
|
||||
gkv = kv.groupByKey().cache()
|
||||
self.assertEqual(3, gkv.count())
|
||||
filtered = gkv.filter(lambda (k, vs): k == 1)
|
||||
filtered = gkv.filter(lambda kv: kv[0] == 1)
|
||||
self.assertEqual(1, filtered.count())
|
||||
self.assertEqual([(1, N/3)], filtered.mapValues(len).collect())
|
||||
self.assertEqual([(N/3, N/3)],
|
||||
self.assertEqual([(1, N // 3)], filtered.mapValues(len).collect())
|
||||
self.assertEqual([(N // 3, N // 3)],
|
||||
filtered.values().map(lambda x: (len(x), len(list(x)))).collect())
|
||||
result = filtered.collect()[0][1]
|
||||
self.assertEqual(N/3, len(result))
|
||||
self.assertTrue(isinstance(result.data, shuffle.ExternalList))
|
||||
self.assertEqual(N // 3, len(result))
|
||||
self.assertTrue(isinstance(result.data, shuffle.ExternalListOfList))
|
||||
|
||||
def test_sort_on_empty_rdd(self):
|
||||
self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect())
|
||||
|
@ -767,7 +783,7 @@ class RDDTests(ReusedPySparkTestCase):
|
|||
rdd = RDD(jrdd, self.sc, UTF8Deserializer())
|
||||
self.assertEqual([u"a", None, u"b"], rdd.collect())
|
||||
rdd = RDD(jrdd, self.sc, NoOpSerializer())
|
||||
self.assertEqual(["a", None, "b"], rdd.collect())
|
||||
self.assertEqual([b"a", None, b"b"], rdd.collect())
|
||||
|
||||
def test_multiple_python_java_RDD_conversions(self):
|
||||
# Regression test for SPARK-5361
|
||||
|
@ -813,14 +829,14 @@ class RDDTests(ReusedPySparkTestCase):
|
|||
self.sc.setJobGroup("test3", "test", True)
|
||||
d = sorted(parted.cogroup(parted).collect())
|
||||
self.assertEqual(10, len(d))
|
||||
self.assertEqual([[0], [0]], map(list, d[0][1]))
|
||||
self.assertEqual([[0], [0]], list(map(list, d[0][1])))
|
||||
jobId = tracker.getJobIdsForGroup("test3")[0]
|
||||
self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds))
|
||||
|
||||
self.sc.setJobGroup("test4", "test", True)
|
||||
d = sorted(parted.cogroup(rdd).collect())
|
||||
self.assertEqual(10, len(d))
|
||||
self.assertEqual([[0], [0]], map(list, d[0][1]))
|
||||
self.assertEqual([[0], [0]], list(map(list, d[0][1])))
|
||||
jobId = tracker.getJobIdsForGroup("test4")[0]
|
||||
self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds))
|
||||
|
||||
|
@ -906,6 +922,7 @@ class InputFormatTests(ReusedPySparkTestCase):
|
|||
ReusedPySparkTestCase.tearDownClass()
|
||||
shutil.rmtree(cls.tempdir.name)
|
||||
|
||||
@unittest.skipIf(sys.version >= "3", "serialize array of byte")
|
||||
def test_sequencefiles(self):
|
||||
basepath = self.tempdir.name
|
||||
ints = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfint/",
|
||||
|
@ -954,15 +971,16 @@ class InputFormatTests(ReusedPySparkTestCase):
|
|||
en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)]
|
||||
self.assertEqual(nulls, en)
|
||||
|
||||
maps = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfmap/",
|
||||
"org.apache.hadoop.io.IntWritable",
|
||||
"org.apache.hadoop.io.MapWritable").collect())
|
||||
maps = self.sc.sequenceFile(basepath + "/sftestdata/sfmap/",
|
||||
"org.apache.hadoop.io.IntWritable",
|
||||
"org.apache.hadoop.io.MapWritable").collect()
|
||||
em = [(1, {}),
|
||||
(1, {3.0: u'bb'}),
|
||||
(2, {1.0: u'aa'}),
|
||||
(2, {1.0: u'cc'}),
|
||||
(3, {2.0: u'dd'})]
|
||||
self.assertEqual(maps, em)
|
||||
for v in maps:
|
||||
self.assertTrue(v in em)
|
||||
|
||||
# arrays get pickled to tuples by default
|
||||
tuples = sorted(self.sc.sequenceFile(
|
||||
|
@ -1089,8 +1107,8 @@ class InputFormatTests(ReusedPySparkTestCase):
|
|||
def test_binary_files(self):
|
||||
path = os.path.join(self.tempdir.name, "binaryfiles")
|
||||
os.mkdir(path)
|
||||
data = "short binary data"
|
||||
with open(os.path.join(path, "part-0000"), 'w') as f:
|
||||
data = b"short binary data"
|
||||
with open(os.path.join(path, "part-0000"), 'wb') as f:
|
||||
f.write(data)
|
||||
[(p, d)] = self.sc.binaryFiles(path).collect()
|
||||
self.assertTrue(p.endswith("part-0000"))
|
||||
|
@ -1103,7 +1121,7 @@ class InputFormatTests(ReusedPySparkTestCase):
|
|||
for i in range(100):
|
||||
f.write('%04d' % i)
|
||||
result = self.sc.binaryRecords(path, 4).map(int).collect()
|
||||
self.assertEqual(range(100), result)
|
||||
self.assertEqual(list(range(100)), result)
|
||||
|
||||
|
||||
class OutputFormatTests(ReusedPySparkTestCase):
|
||||
|
@ -1115,6 +1133,7 @@ class OutputFormatTests(ReusedPySparkTestCase):
|
|||
def tearDown(self):
|
||||
shutil.rmtree(self.tempdir.name, ignore_errors=True)
|
||||
|
||||
@unittest.skipIf(sys.version >= "3", "serialize array of byte")
|
||||
def test_sequencefiles(self):
|
||||
basepath = self.tempdir.name
|
||||
ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
|
||||
|
@ -1155,8 +1174,9 @@ class OutputFormatTests(ReusedPySparkTestCase):
|
|||
(2, {1.0: u'cc'}),
|
||||
(3, {2.0: u'dd'})]
|
||||
self.sc.parallelize(em).saveAsSequenceFile(basepath + "/sfmap/")
|
||||
maps = sorted(self.sc.sequenceFile(basepath + "/sfmap/").collect())
|
||||
self.assertEqual(maps, em)
|
||||
maps = self.sc.sequenceFile(basepath + "/sfmap/").collect()
|
||||
for v in maps:
|
||||
self.assertTrue(v, em)
|
||||
|
||||
def test_oldhadoop(self):
|
||||
basepath = self.tempdir.name
|
||||
|
@ -1168,12 +1188,13 @@ class OutputFormatTests(ReusedPySparkTestCase):
|
|||
"org.apache.hadoop.mapred.SequenceFileOutputFormat",
|
||||
"org.apache.hadoop.io.IntWritable",
|
||||
"org.apache.hadoop.io.MapWritable")
|
||||
result = sorted(self.sc.hadoopFile(
|
||||
result = self.sc.hadoopFile(
|
||||
basepath + "/oldhadoop/",
|
||||
"org.apache.hadoop.mapred.SequenceFileInputFormat",
|
||||
"org.apache.hadoop.io.IntWritable",
|
||||
"org.apache.hadoop.io.MapWritable").collect())
|
||||
self.assertEqual(result, dict_data)
|
||||
"org.apache.hadoop.io.MapWritable").collect()
|
||||
for v in result:
|
||||
self.assertTrue(v, dict_data)
|
||||
|
||||
conf = {
|
||||
"mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat",
|
||||
|
@ -1183,12 +1204,13 @@ class OutputFormatTests(ReusedPySparkTestCase):
|
|||
}
|
||||
self.sc.parallelize(dict_data).saveAsHadoopDataset(conf)
|
||||
input_conf = {"mapred.input.dir": basepath + "/olddataset/"}
|
||||
old_dataset = sorted(self.sc.hadoopRDD(
|
||||
result = self.sc.hadoopRDD(
|
||||
"org.apache.hadoop.mapred.SequenceFileInputFormat",
|
||||
"org.apache.hadoop.io.IntWritable",
|
||||
"org.apache.hadoop.io.MapWritable",
|
||||
conf=input_conf).collect())
|
||||
self.assertEqual(old_dataset, dict_data)
|
||||
conf=input_conf).collect()
|
||||
for v in result:
|
||||
self.assertTrue(v, dict_data)
|
||||
|
||||
def test_newhadoop(self):
|
||||
basepath = self.tempdir.name
|
||||
|
@ -1223,6 +1245,7 @@ class OutputFormatTests(ReusedPySparkTestCase):
|
|||
conf=input_conf).collect())
|
||||
self.assertEqual(new_dataset, data)
|
||||
|
||||
@unittest.skipIf(sys.version >= "3", "serialize of array")
|
||||
def test_newhadoop_with_array(self):
|
||||
basepath = self.tempdir.name
|
||||
# use custom ArrayWritable types and converters to handle arrays
|
||||
|
@ -1303,7 +1326,7 @@ class OutputFormatTests(ReusedPySparkTestCase):
|
|||
basepath = self.tempdir.name
|
||||
x = range(1, 5)
|
||||
y = range(1001, 1005)
|
||||
data = zip(x, y)
|
||||
data = list(zip(x, y))
|
||||
rdd = self.sc.parallelize(x).zip(self.sc.parallelize(y))
|
||||
rdd.saveAsSequenceFile(basepath + "/reserialize/sequence")
|
||||
result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect())
|
||||
|
@ -1354,7 +1377,7 @@ class DaemonTests(unittest.TestCase):
|
|||
sock = socket(AF_INET, SOCK_STREAM)
|
||||
sock.connect(('127.0.0.1', port))
|
||||
# send a split index of -1 to shutdown the worker
|
||||
sock.send("\xFF\xFF\xFF\xFF")
|
||||
sock.send(b"\xFF\xFF\xFF\xFF")
|
||||
sock.close()
|
||||
return True
|
||||
|
||||
|
@ -1395,7 +1418,6 @@ class DaemonTests(unittest.TestCase):
|
|||
|
||||
|
||||
class WorkerTests(PySparkTestCase):
|
||||
|
||||
def test_cancel_task(self):
|
||||
temp = tempfile.NamedTemporaryFile(delete=True)
|
||||
temp.close()
|
||||
|
@ -1410,7 +1432,7 @@ class WorkerTests(PySparkTestCase):
|
|||
|
||||
# start job in background thread
|
||||
def run():
|
||||
self.sc.parallelize(range(1)).foreach(sleep)
|
||||
self.sc.parallelize(range(1), 1).foreach(sleep)
|
||||
import threading
|
||||
t = threading.Thread(target=run)
|
||||
t.daemon = True
|
||||
|
@ -1419,7 +1441,8 @@ class WorkerTests(PySparkTestCase):
|
|||
daemon_pid, worker_pid = 0, 0
|
||||
while True:
|
||||
if os.path.exists(path):
|
||||
data = open(path).read().split(' ')
|
||||
with open(path) as f:
|
||||
data = f.read().split(' ')
|
||||
daemon_pid, worker_pid = map(int, data)
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
@ -1455,7 +1478,7 @@ class WorkerTests(PySparkTestCase):
|
|||
|
||||
def test_after_jvm_exception(self):
|
||||
tempFile = tempfile.NamedTemporaryFile(delete=False)
|
||||
tempFile.write("Hello World!")
|
||||
tempFile.write(b"Hello World!")
|
||||
tempFile.close()
|
||||
data = self.sc.textFile(tempFile.name, 1)
|
||||
filtered_data = data.filter(lambda x: True)
|
||||
|
@ -1577,12 +1600,12 @@ class SparkSubmitTests(unittest.TestCase):
|
|||
|from pyspark import SparkContext
|
||||
|
|
||||
|sc = SparkContext()
|
||||
|print sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect()
|
||||
|print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect())
|
||||
""")
|
||||
proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE)
|
||||
out, err = proc.communicate()
|
||||
self.assertEqual(0, proc.returncode)
|
||||
self.assertIn("[2, 4, 6]", out)
|
||||
self.assertIn("[2, 4, 6]", out.decode('utf-8'))
|
||||
|
||||
def test_script_with_local_functions(self):
|
||||
"""Submit and test a single script file calling a global function"""
|
||||
|
@ -1593,12 +1616,12 @@ class SparkSubmitTests(unittest.TestCase):
|
|||
| return x * 3
|
||||
|
|
||||
|sc = SparkContext()
|
||||
|print sc.parallelize([1, 2, 3]).map(foo).collect()
|
||||
|print(sc.parallelize([1, 2, 3]).map(foo).collect())
|
||||
""")
|
||||
proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE)
|
||||
out, err = proc.communicate()
|
||||
self.assertEqual(0, proc.returncode)
|
||||
self.assertIn("[3, 6, 9]", out)
|
||||
self.assertIn("[3, 6, 9]", out.decode('utf-8'))
|
||||
|
||||
def test_module_dependency(self):
|
||||
"""Submit and test a script with a dependency on another module"""
|
||||
|
@ -1607,7 +1630,7 @@ class SparkSubmitTests(unittest.TestCase):
|
|||
|from mylib import myfunc
|
||||
|
|
||||
|sc = SparkContext()
|
||||
|print sc.parallelize([1, 2, 3]).map(myfunc).collect()
|
||||
|print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
|
||||
""")
|
||||
zip = self.createFileInZip("mylib.py", """
|
||||
|def myfunc(x):
|
||||
|
@ -1617,7 +1640,7 @@ class SparkSubmitTests(unittest.TestCase):
|
|||
stdout=subprocess.PIPE)
|
||||
out, err = proc.communicate()
|
||||
self.assertEqual(0, proc.returncode)
|
||||
self.assertIn("[2, 3, 4]", out)
|
||||
self.assertIn("[2, 3, 4]", out.decode('utf-8'))
|
||||
|
||||
def test_module_dependency_on_cluster(self):
|
||||
"""Submit and test a script with a dependency on another module on a cluster"""
|
||||
|
@ -1626,7 +1649,7 @@ class SparkSubmitTests(unittest.TestCase):
|
|||
|from mylib import myfunc
|
||||
|
|
||||
|sc = SparkContext()
|
||||
|print sc.parallelize([1, 2, 3]).map(myfunc).collect()
|
||||
|print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
|
||||
""")
|
||||
zip = self.createFileInZip("mylib.py", """
|
||||
|def myfunc(x):
|
||||
|
@ -1637,7 +1660,7 @@ class SparkSubmitTests(unittest.TestCase):
|
|||
stdout=subprocess.PIPE)
|
||||
out, err = proc.communicate()
|
||||
self.assertEqual(0, proc.returncode)
|
||||
self.assertIn("[2, 3, 4]", out)
|
||||
self.assertIn("[2, 3, 4]", out.decode('utf-8'))
|
||||
|
||||
def test_package_dependency(self):
|
||||
"""Submit and test a script with a dependency on a Spark Package"""
|
||||
|
@ -1646,14 +1669,14 @@ class SparkSubmitTests(unittest.TestCase):
|
|||
|from mylib import myfunc
|
||||
|
|
||||
|sc = SparkContext()
|
||||
|print sc.parallelize([1, 2, 3]).map(myfunc).collect()
|
||||
|print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
|
||||
""")
|
||||
self.create_spark_package("a:mylib:0.1")
|
||||
proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories",
|
||||
"file:" + self.programDir, script], stdout=subprocess.PIPE)
|
||||
out, err = proc.communicate()
|
||||
self.assertEqual(0, proc.returncode)
|
||||
self.assertIn("[2, 3, 4]", out)
|
||||
self.assertIn("[2, 3, 4]", out.decode('utf-8'))
|
||||
|
||||
def test_package_dependency_on_cluster(self):
|
||||
"""Submit and test a script with a dependency on a Spark Package on a cluster"""
|
||||
|
@ -1662,7 +1685,7 @@ class SparkSubmitTests(unittest.TestCase):
|
|||
|from mylib import myfunc
|
||||
|
|
||||
|sc = SparkContext()
|
||||
|print sc.parallelize([1, 2, 3]).map(myfunc).collect()
|
||||
|print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
|
||||
""")
|
||||
self.create_spark_package("a:mylib:0.1")
|
||||
proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories",
|
||||
|
@ -1670,7 +1693,7 @@ class SparkSubmitTests(unittest.TestCase):
|
|||
"local-cluster[1,1,512]", script], stdout=subprocess.PIPE)
|
||||
out, err = proc.communicate()
|
||||
self.assertEqual(0, proc.returncode)
|
||||
self.assertIn("[2, 3, 4]", out)
|
||||
self.assertIn("[2, 3, 4]", out.decode('utf-8'))
|
||||
|
||||
def test_single_script_on_cluster(self):
|
||||
"""Submit and test a single script on a cluster"""
|
||||
|
@ -1681,7 +1704,7 @@ class SparkSubmitTests(unittest.TestCase):
|
|||
| return x * 2
|
||||
|
|
||||
|sc = SparkContext()
|
||||
|print sc.parallelize([1, 2, 3]).map(foo).collect()
|
||||
|print(sc.parallelize([1, 2, 3]).map(foo).collect())
|
||||
""")
|
||||
# this will fail if you have different spark.executor.memory
|
||||
# in conf/spark-defaults.conf
|
||||
|
@ -1690,7 +1713,7 @@ class SparkSubmitTests(unittest.TestCase):
|
|||
stdout=subprocess.PIPE)
|
||||
out, err = proc.communicate()
|
||||
self.assertEqual(0, proc.returncode)
|
||||
self.assertIn("[2, 4, 6]", out)
|
||||
self.assertIn("[2, 4, 6]", out.decode('utf-8'))
|
||||
|
||||
|
||||
class ContextTests(unittest.TestCase):
|
||||
|
@ -1765,7 +1788,7 @@ class SciPyTests(PySparkTestCase):
|
|||
def test_serialize(self):
|
||||
from scipy.special import gammaln
|
||||
x = range(1, 5)
|
||||
expected = map(gammaln, x)
|
||||
expected = list(map(gammaln, x))
|
||||
observed = self.sc.parallelize(x).map(gammaln).collect()
|
||||
self.assertEqual(expected, observed)
|
||||
|
||||
|
@ -1786,11 +1809,11 @@ class NumPyTests(PySparkTestCase):
|
|||
|
||||
if __name__ == "__main__":
|
||||
if not _have_scipy:
|
||||
print "NOTE: Skipping SciPy tests as it does not seem to be installed"
|
||||
print("NOTE: Skipping SciPy tests as it does not seem to be installed")
|
||||
if not _have_numpy:
|
||||
print "NOTE: Skipping NumPy tests as it does not seem to be installed"
|
||||
print("NOTE: Skipping NumPy tests as it does not seem to be installed")
|
||||
unittest.main()
|
||||
if not _have_scipy:
|
||||
print "NOTE: SciPy tests were skipped as it does not seem to be installed"
|
||||
print("NOTE: SciPy tests were skipped as it does not seem to be installed")
|
||||
if not _have_numpy:
|
||||
print "NOTE: NumPy tests were skipped as it does not seem to be installed"
|
||||
print("NOTE: NumPy tests were skipped as it does not seem to be installed")
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
"""
|
||||
Worker that receives input from Piped RDD.
|
||||
"""
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
@ -37,9 +38,9 @@ utf8_deserializer = UTF8Deserializer()
|
|||
|
||||
def report_times(outfile, boot, init, finish):
|
||||
write_int(SpecialLengths.TIMING_DATA, outfile)
|
||||
write_long(1000 * boot, outfile)
|
||||
write_long(1000 * init, outfile)
|
||||
write_long(1000 * finish, outfile)
|
||||
write_long(int(1000 * boot), outfile)
|
||||
write_long(int(1000 * init), outfile)
|
||||
write_long(int(1000 * finish), outfile)
|
||||
|
||||
|
||||
def add_path(path):
|
||||
|
@ -72,6 +73,9 @@ def main(infile, outfile):
|
|||
for _ in range(num_python_includes):
|
||||
filename = utf8_deserializer.loads(infile)
|
||||
add_path(os.path.join(spark_files_dir, filename))
|
||||
if sys.version > '3':
|
||||
import importlib
|
||||
importlib.invalidate_caches()
|
||||
|
||||
# fetch names and values of broadcast variables
|
||||
num_broadcast_variables = read_int(infile)
|
||||
|
@ -106,14 +110,14 @@ def main(infile, outfile):
|
|||
except Exception:
|
||||
try:
|
||||
write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
|
||||
write_with_length(traceback.format_exc(), outfile)
|
||||
write_with_length(traceback.format_exc().encode("utf-8"), outfile)
|
||||
except IOError:
|
||||
# JVM close the socket
|
||||
pass
|
||||
except Exception:
|
||||
# Write the error to stderr if it happened while serializing
|
||||
print >> sys.stderr, "PySpark worker failed with exception:"
|
||||
print >> sys.stderr, traceback.format_exc()
|
||||
print("PySpark worker failed with exception:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
exit(-1)
|
||||
finish_time = time.time()
|
||||
report_times(outfile, boot_time, init_time, finish_time)
|
||||
|
|
|
@ -66,7 +66,7 @@ function run_core_tests() {
|
|||
|
||||
function run_sql_tests() {
|
||||
echo "Run sql tests ..."
|
||||
run_test "pyspark/sql/types.py"
|
||||
run_test "pyspark/sql/_types.py"
|
||||
run_test "pyspark/sql/context.py"
|
||||
run_test "pyspark/sql/dataframe.py"
|
||||
run_test "pyspark/sql/functions.py"
|
||||
|
@ -136,6 +136,19 @@ run_mllib_tests
|
|||
run_ml_tests
|
||||
run_streaming_tests
|
||||
|
||||
# Try to test with Python 3
|
||||
if [ $(which python3.4) ]; then
|
||||
export PYSPARK_PYTHON="python3.4"
|
||||
echo "Testing with Python3.4 version:"
|
||||
$PYSPARK_PYTHON --version
|
||||
|
||||
run_core_tests
|
||||
run_sql_tests
|
||||
run_mllib_tests
|
||||
run_ml_tests
|
||||
run_streaming_tests
|
||||
fi
|
||||
|
||||
# Try to test with PyPy
|
||||
if [ $(which pypy) ]; then
|
||||
export PYSPARK_PYTHON="pypy"
|
||||
|
|
Binary file not shown.
BIN
python/test_support/userlib-0.1.zip
Normal file
BIN
python/test_support/userlib-0.1.zip
Normal file
Binary file not shown.
Loading…
Reference in a new issue