[SPARK-6949] [SQL] [PySpark] Support Date/Timestamp in Column expression
This PR enables auto_convert in JavaGateway, so that we can register a converter for given types, for example, date and datetime. There are two bugs related to auto_convert, see [1] and [2]; we work around them in this PR. [1] https://github.com/bartdag/py4j/issues/160 [2] https://github.com/bartdag/py4j/issues/161 cc rxin JoshRosen Author: Davies Liu <davies@databricks.com> Closes #5570 from davies/py4j_date and squashes the following commits: eb4fa53 [Davies Liu] fix tests in python 3 d17d634 [Davies Liu] rollback changes in mllib 2e7566d [Davies Liu] convert tuple into ArrayList ceb3779 [Davies Liu] Update rdd.py 3c373f3 [Davies Liu] support date and datetime by auto_convert cb094ff [Davies Liu] enable auto convert
This commit is contained in:
parent
8136810dfa
commit
ab9128fb7e
|
@ -23,8 +23,6 @@ import sys
|
|||
from threading import Lock
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
from py4j.java_collections import ListConverter
|
||||
|
||||
from pyspark import accumulators
|
||||
from pyspark.accumulators import Accumulator
|
||||
from pyspark.broadcast import Broadcast
|
||||
|
@ -643,7 +641,6 @@ class SparkContext(object):
|
|||
rdds = [x._reserialize() for x in rdds]
|
||||
first = rdds[0]._jrdd
|
||||
rest = [x._jrdd for x in rdds[1:]]
|
||||
rest = ListConverter().convert(rest, self._gateway._gateway_client)
|
||||
return RDD(self._jsc.union(first, rest), self, rdds[0]._jrdd_deserializer)
|
||||
|
||||
def broadcast(self, value):
|
||||
|
@ -846,13 +843,12 @@ class SparkContext(object):
|
|||
"""
|
||||
if partitions is None:
|
||||
partitions = range(rdd._jrdd.partitions().size())
|
||||
javaPartitions = ListConverter().convert(partitions, self._gateway._gateway_client)
|
||||
|
||||
# Implementation note: This is implemented as a mapPartitions followed
|
||||
# by runJob() in order to avoid having to pass a Python lambda into
|
||||
# SparkContext#runJob.
|
||||
mappedRDD = rdd.mapPartitions(partitionFunc)
|
||||
port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions,
|
||||
port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions,
|
||||
allowLocal)
|
||||
return list(_load_from_socket(port, mappedRDD._jrdd_deserializer))
|
||||
|
||||
|
|
|
@ -17,17 +17,30 @@
|
|||
|
||||
import atexit
|
||||
import os
|
||||
import sys
|
||||
import select
|
||||
import signal
|
||||
import shlex
|
||||
import socket
|
||||
import platform
|
||||
from subprocess import Popen, PIPE
|
||||
|
||||
if sys.version >= '3':
|
||||
xrange = range
|
||||
|
||||
from py4j.java_gateway import java_import, JavaGateway, GatewayClient
|
||||
from py4j.java_collections import ListConverter
|
||||
|
||||
from pyspark.serializers import read_int
|
||||
|
||||
|
||||
# Patch ListConverter; otherwise it would convert a bytearray into a Java ArrayList
|
||||
def can_convert_list(self, obj):
    """Tell whether ``obj`` should be auto-converted to a Java list.

    Only ``list``, ``tuple`` and ``xrange`` instances qualify; every other
    type (for example ``bytearray``) is rejected.
    """
    convertible_types = (list, tuple, xrange)
    return isinstance(obj, convertible_types)
|
||||
|
||||
ListConverter.can_convert = can_convert_list
|
||||
|
||||
|
||||
def launch_gateway():
|
||||
if "PYSPARK_GATEWAY_PORT" in os.environ:
|
||||
gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
|
||||
|
@ -92,7 +105,7 @@ def launch_gateway():
|
|||
atexit.register(killChild)
|
||||
|
||||
# Connect to the gateway
|
||||
gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False)
|
||||
gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True)
|
||||
|
||||
# Import the classes used by PySpark
|
||||
java_import(gateway.jvm, "org.apache.spark.SparkConf")
|
||||
|
|
|
@ -2267,6 +2267,9 @@ def _prepare_for_python_RDD(sc, command, obj=None):
|
|||
# The broadcast will have same life cycle as created PythonRDD
|
||||
broadcast = sc.broadcast(pickled_command)
|
||||
pickled_command = ser.dumps(broadcast)
|
||||
# There is a bug in py4j.java_gateway.JavaClass with auto_convert
|
||||
# https://github.com/bartdag/py4j/issues/161
|
||||
# TODO: use auto_convert once py4j fixes the bug
|
||||
broadcast_vars = ListConverter().convert(
|
||||
[x._jbroadcast for x in sc._pickled_broadcast_vars],
|
||||
sc._gateway._gateway_client)
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
import sys
|
||||
import decimal
|
||||
import time
|
||||
import datetime
|
||||
import keyword
|
||||
import warnings
|
||||
|
@ -30,6 +31,9 @@ if sys.version >= "3":
|
|||
long = int
|
||||
unicode = str
|
||||
|
||||
from py4j.protocol import register_input_converter
|
||||
from py4j.java_gateway import JavaClass
|
||||
|
||||
__all__ = [
|
||||
"DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
|
||||
"TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType",
|
||||
|
@ -1237,6 +1241,29 @@ class Row(tuple):
|
|||
return "<Row(%s)>" % ", ".join(self)
|
||||
|
||||
|
||||
class DateConverter(object):
    """Py4J input converter from a Python ``datetime.date`` to ``java.sql.Date``."""

    @staticmethod
    def _to_iso(obj):
        """Return ``obj`` as a zero-padded ``yyyy-mm-dd`` string.

        The string is built from the year/month/day fields rather than with
        ``strftime``: ``strftime`` raises for years before 1900 on Python 2
        and does not zero-pad years below 1000 on some platforms.
        """
        return "%04d-%02d-%02d" % (obj.year, obj.month, obj.day)

    def can_convert(self, obj):
        """Return True when ``obj`` is a ``datetime.date`` (subclasses included)."""
        return isinstance(obj, datetime.date)

    def convert(self, obj, gateway_client):
        """Create a ``java.sql.Date`` for ``obj`` through ``gateway_client``."""
        Date = JavaClass("java.sql.Date", gateway_client)
        return Date.valueOf(self._to_iso(obj))
|
||||
|
||||
|
||||
class DatetimeConverter(object):
    """Py4J input converter from ``datetime.datetime`` to ``java.sql.Timestamp``."""

    @staticmethod
    def _epoch_seconds(obj):
        """Return the whole seconds since the Unix epoch represented by ``obj``.

        Timezone-aware values are converted through UTC. Naive values are
        interpreted in the local timezone by ``time.mktime``.
        """
        if obj.utcoffset() is not None:
            # mktime() would read the wall-clock fields as local time, which
            # is wrong for an aware datetime; go through UTC instead.
            import calendar
            return int(calendar.timegm(obj.utctimetuple()))
        return int(time.mktime(obj.timetuple()))

    def can_convert(self, obj):
        """Return True when ``obj`` is a ``datetime.datetime``."""
        return isinstance(obj, datetime.datetime)

    def convert(self, obj, gateway_client):
        """Create a ``java.sql.Timestamp`` for ``obj`` through ``gateway_client``."""
        Timestamp = JavaClass("java.sql.Timestamp", gateway_client)
        jtimestamp = Timestamp(self._epoch_seconds(obj) * 1000)
        # Set the fractional second separately so that microseconds are kept
        # instead of being truncated to milliseconds.
        jtimestamp.setNanos(obj.microsecond * 1000)
        return jtimestamp
|
||||
|
||||
|
||||
# datetime is a subclass of date, so DatetimeConverter must be registered first
|
||||
register_input_converter(DatetimeConverter())
|
||||
register_input_converter(DateConverter())
|
||||
|
||||
|
||||
def _test():
|
||||
import doctest
|
||||
from pyspark.context import SparkContext
|
||||
|
|
|
@ -25,7 +25,6 @@ else:
|
|||
from itertools import imap as map
|
||||
|
||||
from py4j.protocol import Py4JError
|
||||
from py4j.java_collections import MapConverter
|
||||
|
||||
from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix
|
||||
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
|
||||
|
@ -442,15 +441,13 @@ class SQLContext(object):
|
|||
if source is None:
|
||||
source = self.getConf("spark.sql.sources.default",
|
||||
"org.apache.spark.sql.parquet")
|
||||
joptions = MapConverter().convert(options,
|
||||
self._sc._gateway._gateway_client)
|
||||
if schema is None:
|
||||
df = self._ssql_ctx.load(source, joptions)
|
||||
df = self._ssql_ctx.load(source, options)
|
||||
else:
|
||||
if not isinstance(schema, StructType):
|
||||
raise TypeError("schema should be StructType")
|
||||
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
|
||||
df = self._ssql_ctx.load(source, scala_datatype, joptions)
|
||||
df = self._ssql_ctx.load(source, scala_datatype, options)
|
||||
return DataFrame(df, self)
|
||||
|
||||
def createExternalTable(self, tableName, path=None, source=None,
|
||||
|
@ -471,16 +468,14 @@ class SQLContext(object):
|
|||
if source is None:
|
||||
source = self.getConf("spark.sql.sources.default",
|
||||
"org.apache.spark.sql.parquet")
|
||||
joptions = MapConverter().convert(options,
|
||||
self._sc._gateway._gateway_client)
|
||||
if schema is None:
|
||||
df = self._ssql_ctx.createExternalTable(tableName, source, joptions)
|
||||
df = self._ssql_ctx.createExternalTable(tableName, source, options)
|
||||
else:
|
||||
if not isinstance(schema, StructType):
|
||||
raise TypeError("schema should be StructType")
|
||||
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
|
||||
df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype,
|
||||
joptions)
|
||||
options)
|
||||
return DataFrame(df, self)
|
||||
|
||||
@ignore_unicode_prefix
|
||||
|
|
|
@ -25,8 +25,6 @@ if sys.version >= '3':
|
|||
else:
|
||||
from itertools import imap as map
|
||||
|
||||
from py4j.java_collections import ListConverter, MapConverter
|
||||
|
||||
from pyspark.context import SparkContext
|
||||
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
|
||||
from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
|
||||
|
@ -186,9 +184,7 @@ class DataFrame(object):
|
|||
source = self.sql_ctx.getConf("spark.sql.sources.default",
|
||||
"org.apache.spark.sql.parquet")
|
||||
jmode = self._java_save_mode(mode)
|
||||
joptions = MapConverter().convert(options,
|
||||
self.sql_ctx._sc._gateway._gateway_client)
|
||||
self._jdf.saveAsTable(tableName, source, jmode, joptions)
|
||||
self._jdf.saveAsTable(tableName, source, jmode, options)
|
||||
|
||||
def save(self, path=None, source=None, mode="error", **options):
|
||||
"""Saves the contents of the :class:`DataFrame` to a data source.
|
||||
|
@ -211,9 +207,7 @@ class DataFrame(object):
|
|||
source = self.sql_ctx.getConf("spark.sql.sources.default",
|
||||
"org.apache.spark.sql.parquet")
|
||||
jmode = self._java_save_mode(mode)
|
||||
joptions = MapConverter().convert(options,
|
||||
self._sc._gateway._gateway_client)
|
||||
self._jdf.save(source, jmode, joptions)
|
||||
self._jdf.save(source, jmode, options)
|
||||
|
||||
@property
|
||||
def schema(self):
|
||||
|
@ -819,7 +813,6 @@ class DataFrame(object):
|
|||
value = float(value)
|
||||
|
||||
if isinstance(value, dict):
|
||||
value = MapConverter().convert(value, self.sql_ctx._sc._gateway._gateway_client)
|
||||
return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
|
||||
elif subset is None:
|
||||
return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
|
||||
|
@ -932,9 +925,7 @@ class GroupedData(object):
|
|||
"""
|
||||
assert exprs, "exprs should not be empty"
|
||||
if len(exprs) == 1 and isinstance(exprs[0], dict):
|
||||
jmap = MapConverter().convert(exprs[0],
|
||||
self.sql_ctx._sc._gateway._gateway_client)
|
||||
jdf = self._jdf.agg(jmap)
|
||||
jdf = self._jdf.agg(exprs[0])
|
||||
else:
|
||||
# Columns
|
||||
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
|
||||
|
@ -1040,8 +1031,7 @@ def _to_seq(sc, cols, converter=None):
|
|||
"""
|
||||
if converter:
|
||||
cols = [converter(c) for c in cols]
|
||||
jcols = ListConverter().convert(cols, sc._gateway._gateway_client)
|
||||
return sc._jvm.PythonUtils.toSeq(jcols)
|
||||
return sc._jvm.PythonUtils.toSeq(cols)
|
||||
|
||||
|
||||
def _unary_op(name, doc="unary operator"):
|
||||
|
|
|
@ -26,6 +26,7 @@ import shutil
|
|||
import tempfile
|
||||
import pickle
|
||||
import functools
|
||||
import datetime
|
||||
|
||||
import py4j
|
||||
|
||||
|
@ -464,6 +465,16 @@ class SQLTests(ReusedPySparkTestCase):
|
|||
self.assertEqual(_infer_type(2**61), LongType())
|
||||
self.assertEqual(_infer_type(2**71), LongType())
|
||||
|
||||
def test_filter_with_datetime(self):
    """Check that date and datetime values can be used in Column filters."""
    timestamp_value = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000)
    date_value = timestamp_value.date()
    # A single-row DataFrame holding one date column and one timestamp column.
    rows = [Row(date=date_value, time=timestamp_value)]
    df = self.sqlCtx.createDataFrame(rows)

    # Equality against the stored values keeps the row.
    same_date = df.filter(df.date == date_value)
    self.assertEqual(1, same_date.count())
    same_time = df.filter(df.time == timestamp_value)
    self.assertEqual(1, same_time.count())

    # Strictly-greater comparisons against the stored values drop the row.
    later_date = df.filter(df.date > date_value)
    self.assertEqual(0, later_date.count())
    later_time = df.filter(df.time > timestamp_value)
    self.assertEqual(0, later_time.count())
|
||||
|
||||
def test_dropna(self):
|
||||
schema = StructType([
|
||||
StructField("name", StringType(), True),
|
||||
|
|
|
@ -20,7 +20,6 @@ from __future__ import print_function
|
|||
import os
|
||||
import sys
|
||||
|
||||
from py4j.java_collections import ListConverter
|
||||
from py4j.java_gateway import java_import, JavaObject
|
||||
|
||||
from pyspark import RDD, SparkConf
|
||||
|
@ -305,9 +304,7 @@ class StreamingContext(object):
|
|||
rdds = [self._sc.parallelize(input) for input in rdds]
|
||||
self._check_serializers(rdds)
|
||||
|
||||
jrdds = ListConverter().convert([r._jrdd for r in rdds],
|
||||
SparkContext._gateway._gateway_client)
|
||||
queue = self._jvm.PythonDStream.toRDDQueue(jrdds)
|
||||
queue = self._jvm.PythonDStream.toRDDQueue([r._jrdd for r in rdds])
|
||||
if default:
|
||||
default = default._reserialize(rdds[0]._jrdd_deserializer)
|
||||
jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd)
|
||||
|
@ -322,8 +319,7 @@ class StreamingContext(object):
|
|||
the transform function parameter will be the same as the order
|
||||
of corresponding DStreams in the list.
|
||||
"""
|
||||
jdstreams = ListConverter().convert([d._jdstream for d in dstreams],
|
||||
SparkContext._gateway._gateway_client)
|
||||
jdstreams = [d._jdstream for d in dstreams]
|
||||
# change the final serializer to sc.serializer
|
||||
func = TransformFunction(self._sc,
|
||||
lambda t, *rdds: transformFunc(rdds).map(lambda x: x),
|
||||
|
@ -346,6 +342,5 @@ class StreamingContext(object):
|
|||
if len(set(s._slideDuration for s in dstreams)) > 1:
|
||||
raise ValueError("All DStreams should have same slide duration")
|
||||
first = dstreams[0]
|
||||
jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]],
|
||||
SparkContext._gateway._gateway_client)
|
||||
jrest = [d._jdstream for d in dstreams[1:]]
|
||||
return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer)
|
||||
|
|
|
@ -15,8 +15,7 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
from py4j.java_collections import MapConverter
|
||||
from py4j.java_gateway import java_import, Py4JError, Py4JJavaError
|
||||
from py4j.java_gateway import Py4JJavaError
|
||||
|
||||
from pyspark.storagelevel import StorageLevel
|
||||
from pyspark.serializers import PairDeserializer, NoOpSerializer
|
||||
|
@ -57,8 +56,6 @@ class KafkaUtils(object):
|
|||
})
|
||||
if not isinstance(topics, dict):
|
||||
raise TypeError("topics should be dict")
|
||||
jtopics = MapConverter().convert(topics, ssc.sparkContext._gateway._gateway_client)
|
||||
jparam = MapConverter().convert(kafkaParams, ssc.sparkContext._gateway._gateway_client)
|
||||
jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
|
||||
|
||||
try:
|
||||
|
@ -66,7 +63,7 @@ class KafkaUtils(object):
|
|||
helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\
|
||||
.loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
|
||||
helper = helperClass.newInstance()
|
||||
jstream = helper.createStream(ssc._jssc, jparam, jtopics, jlevel)
|
||||
jstream = helper.createStream(ssc._jssc, kafkaParams, topics, jlevel)
|
||||
except Py4JJavaError as e:
|
||||
# TODO: use --jar once it also work on driver
|
||||
if 'ClassNotFoundException' in str(e.java_exception):
|
||||
|
|
|
@ -24,8 +24,6 @@ import tempfile
|
|||
import struct
|
||||
from functools import reduce
|
||||
|
||||
from py4j.java_collections import MapConverter
|
||||
|
||||
from pyspark.context import SparkConf, SparkContext, RDD
|
||||
from pyspark.streaming.context import StreamingContext
|
||||
from pyspark.streaming.kafka import KafkaUtils
|
||||
|
@ -581,11 +579,9 @@ class KafkaStreamTests(PySparkStreamingTestCase):
|
|||
"""Test the Python Kafka stream API."""
|
||||
topic = "topic1"
|
||||
sendData = {"a": 3, "b": 5, "c": 10}
|
||||
jSendData = MapConverter().convert(sendData,
|
||||
self.ssc.sparkContext._gateway._gateway_client)
|
||||
|
||||
self._kafkaTestUtils.createTopic(topic)
|
||||
self._kafkaTestUtils.sendMessages(topic, jSendData)
|
||||
self._kafkaTestUtils.sendMessages(topic, sendData)
|
||||
|
||||
stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(),
|
||||
"test-streaming-consumer", {topic: 1},
|
||||
|
|
Loading…
Reference in a new issue