diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index a33f6dcf31..24905f1c97 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -475,24 +475,30 @@ class SparkContext(object):
                 return xrange(getStart(split), getStart(split + 1), step)
 
             return self.parallelize([], numSlices).mapPartitionsWithIndex(f)
-        # Calling the Java parallelize() method with an ArrayList is too slow,
-        # because it sends O(n) Py4J commands. As an alternative, serialized
-        # objects are written to a file and loaded through textFile().
+
+        # Make sure we distribute data evenly if it's smaller than self.batchSize
+        if "__len__" not in dir(c):
+            c = list(c)  # Make it a list so we can compute its length
+        batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
+        serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
+        jrdd = self._serialize_to_jvm(c, numSlices, serializer)
+        return RDD(jrdd, self, serializer)
+
+    def _serialize_to_jvm(self, data, parallelism, serializer):
+        """
+        Calling the Java parallelize() method with an ArrayList is too slow,
+        because it sends O(n) Py4J commands. As an alternative, serialized
+        objects are written to a file and loaded through textFile().
+        """
         tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
         try:
-            # Make sure we distribute data evenly if it's smaller than self.batchSize
-            if "__len__" not in dir(c):
-                c = list(c)  # Make it a list so we can compute its length
-            batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
-            serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
-            serializer.dump_stream(c, tempFile)
+            serializer.dump_stream(data, tempFile)
             tempFile.close()
             readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
-            jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
+            return readRDDFromFile(self._jsc, tempFile.name, parallelism)
         finally:
             # readRDDFromFile eagerily reads the file so we can delete right after.
             os.unlink(tempFile.name)
-        return RDD(jrdd, self, serializer)
 
     def pickleFile(self, name, minPartitions=None):
         """
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 3c783ae541..3e704fe9bf 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -121,6 +121,7 @@ def launch_gateway(conf=None):
     java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
     # TODO(davies): move into sql
     java_import(gateway.jvm, "org.apache.spark.sql.*")
+    java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
     java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
     java_import(gateway.jvm, "scala.Tuple2")
 
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index d7979f095d..e0afdafbfc 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -214,6 +214,13 @@ class ArrowSerializer(FramedSerializer):
 
 
 def _create_batch(series):
+    """
+    Create an Arrow record batch from the given pandas.Series or list of Series, with optional type.
+
+    :param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
+    :return: Arrow RecordBatch
+    """
+
     from pyspark.sql.types import _check_series_convert_timestamps_internal
     import pyarrow as pa
     # Make input conform to [(series1, type1), (series2, type2), ...]
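For orientation, the sketch below shows the general shape of the conversion `_create_batch` performs: each pandas.Series becomes an Arrow array (optionally coerced to a target type), and the arrays are assembled into a single RecordBatch. The column names and sample data here are illustrative only, not part of the patch.

```python
# Rough sketch of the pandas -> Arrow conversion done by _create_batch (illustrative data/names).
import pandas as pd
import pyarrow as pa

s_int = pd.Series([1, 2, None])
s_str = pd.Series([u"a", u"b", u"c"])

# Null positions are passed explicitly as a mask, as _create_batch does.
arrays = [pa.Array.from_pandas(s, mask=s.isnull()) for s in (s_int, s_str)]
batch = pa.RecordBatch.from_arrays(arrays, ["x", "y"])
print(batch.num_rows, batch.schema)
```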
@@ -229,7 +236,8 @@ def _create_batch(series):
             # NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680
             return _check_series_convert_timestamps_internal(s.fillna(0))\
                 .values.astype('datetime64[us]', copy=False)
-        elif t == pa.date32():
+        # NOTE: can not compare None with pyarrow.DataType(), fixed with Arrow >= 0.7.1
+        elif t is not None and t == pa.date32():
             # TODO: this converts the series to Python objects, possibly avoid with Arrow >= 0.8
             return s.dt.date
         elif t is None or s.dtype == t.to_pandas_dtype():
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index d1d0b8b8fe..589365b083 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -25,7 +25,7 @@ if sys.version >= '3':
     basestring = unicode = str
     xrange = range
 else:
-    from itertools import imap as map
+    from itertools import izip as zip, imap as map
 
 from pyspark import since
 from pyspark.rdd import RDD, ignore_unicode_prefix
@@ -417,12 +417,12 @@ class SparkSession(object):
             data = [schema.toInternal(row) for row in data]
         return self._sc.parallelize(data), schema
 
-    def _get_numpy_record_dtypes(self, rec):
+    def _get_numpy_record_dtype(self, rec):
         """
         Used when converting a pandas.DataFrame to Spark using to_records(), this will correct
-        the dtypes of records so they can be properly loaded into Spark.
-        :param rec: a numpy record to check dtypes
-        :return corrected dtypes for a numpy.record or None if no correction needed
+        the dtypes of fields in a record so they can be properly loaded into Spark.
+        :param rec: a numpy record to check field dtypes
+        :return corrected dtype for a numpy.record or None if no correction needed
         """
         import numpy as np
         cur_dtypes = rec.dtype
@@ -438,28 +438,70 @@ class SparkSession(object):
                 curr_type = 'datetime64[us]'
                 has_rec_fix = True
             record_type_list.append((str(col_names[i]), curr_type))
-        return record_type_list if has_rec_fix else None
+        return np.dtype(record_type_list) if has_rec_fix else None
 
-    def _convert_from_pandas(self, pdf, schema):
+    def _convert_from_pandas(self, pdf):
         """
         Convert a pandas.DataFrame to list of records that can be used to make a DataFrame
-        :return tuple of list of records and schema
+        :return list of records
         """
-        # If no schema supplied by user then get the names of columns only
-        if schema is None:
-            schema = [str(x) for x in pdf.columns]
-
         # Convert pandas.DataFrame to list of numpy records
         np_records = pdf.to_records(index=False)
 
         # Check if any columns need to be fixed for Spark to infer properly
         if len(np_records) > 0:
-            record_type_list = self._get_numpy_record_dtypes(np_records[0])
-            if record_type_list is not None:
-                return [r.astype(record_type_list).tolist() for r in np_records], schema
+            record_dtype = self._get_numpy_record_dtype(np_records[0])
+            if record_dtype is not None:
+                return [r.astype(record_dtype).tolist() for r in np_records]
 
         # Convert list of numpy records to python lists
-        return [r.tolist() for r in np_records], schema
+        return [r.tolist() for r in np_records]
+
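The `np.dtype` now returned above matters because `DataFrame.to_records()` produces `datetime64[ns]` fields that Spark's type inference cannot consume directly. A standalone illustration of the cast applied per record (hypothetical column name, not part of the patch):

```python
# Illustrative only: casting a record's datetime64[ns] field to datetime64[us]
# before turning it into plain Python objects, as the non-Arrow path does.
import numpy as np
import pandas as pd

pdf = pd.DataFrame({"t": pd.to_datetime(["2017-10-01 12:00:00"])})
rec = pdf.to_records(index=False)[0]          # field 't' has dtype datetime64[ns]
fixed_dtype = np.dtype([("t", "datetime64[us]")])
row = rec.astype(fixed_dtype).tolist()        # plain datetime.datetime value instead of ns units
print(row)
```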
+ """ + from pyspark.serializers import ArrowSerializer, _create_batch + from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType + from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype + + # Determine arrow types to coerce data when creating batches + if isinstance(schema, StructType): + arrow_types = [to_arrow_type(f.dataType) for f in schema.fields] + elif isinstance(schema, DataType): + raise ValueError("Single data type %s is not supported with Arrow" % str(schema)) + else: + # Any timestamps must be coerced to be compatible with Spark + arrow_types = [to_arrow_type(TimestampType()) + if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None + for t in pdf.dtypes] + + # Slice the DataFrame to be batched + step = -(-len(pdf) // self.sparkContext.defaultParallelism) # round int up + pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step)) + + # Create Arrow record batches + batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)]) + for pdf_slice in pdf_slices] + + # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing) + if isinstance(schema, (list, tuple)): + struct = from_arrow_schema(batches[0].schema) + for i, name in enumerate(schema): + struct.fields[i].name = name + struct.names[i] = name + schema = struct + + # Create the Spark DataFrame directly from the Arrow data and schema + jrdd = self._sc._serialize_to_jvm(batches, len(batches), ArrowSerializer()) + jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame( + jrdd, schema.json(), self._wrapped._jsqlContext) + df = DataFrame(jdf, self._wrapped) + df._schema = schema + return df @since(2.0) @ignore_unicode_prefix @@ -557,7 +599,19 @@ class SparkSession(object): except Exception: has_pandas = False if has_pandas and isinstance(data, pandas.DataFrame): - data, schema = self._convert_from_pandas(data, schema) + + # If no schema supplied by user then get the names of columns only + if schema is None: + schema = [str(x) for x in data.columns] + + if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \ + and len(data) > 0: + try: + return self._create_from_pandas_with_arrow(data, schema) + except Exception as e: + warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e)) + # Fallback to create DataFrame without arrow if raise some exception + data = self._convert_from_pandas(data) if isinstance(schema, StructType): verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True @@ -576,7 +630,7 @@ class SparkSession(object): verify_func(obj) return obj, else: - if isinstance(schema, list): + if isinstance(schema, (list, tuple)): schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema] prepare = lambda obj: obj diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4819f629c5..6356d938db 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3127,9 +3127,9 @@ class ArrowTests(ReusedSQLTestCase): StructField("5_double_t", DoubleType(), True), StructField("6_date_t", DateType(), True), StructField("7_timestamp_t", TimestampType(), True)]) - cls.data = [("a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), - ("b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), - ("c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + cls.data = [(u"a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), + (u"b", 
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 4819f629c5..6356d938db 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3127,9 +3127,9 @@ class ArrowTests(ReusedSQLTestCase):
             StructField("5_double_t", DoubleType(), True),
             StructField("6_date_t", DateType(), True),
             StructField("7_timestamp_t", TimestampType(), True)])
-        cls.data = [("a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
-                    ("b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
-                    ("c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
+        cls.data = [(u"a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
+                    (u"b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
+                    (u"c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
 
     @classmethod
     def tearDownClass(cls):
@@ -3145,6 +3145,17 @@ class ArrowTests(ReusedSQLTestCase):
                ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes)))
         self.assertTrue(df_without.equals(df_with_arrow), msg=msg)
 
+    def create_pandas_data_frame(self):
+        import pandas as pd
+        import numpy as np
+        data_dict = {}
+        for j, name in enumerate(self.schema.names):
+            data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
+        # need to convert these to numpy types first
+        data_dict["2_int_t"] = np.int32(data_dict["2_int_t"])
+        data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
+        return pd.DataFrame(data=data_dict)
+
     def test_unsupported_datatype(self):
         schema = StructType([StructField("decimal", DecimalType(), True)])
         df = self.spark.createDataFrame([(None,)], schema=schema)
@@ -3161,21 +3172,15 @@ class ArrowTests(ReusedSQLTestCase):
     def test_toPandas_arrow_toggle(self):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
         self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
-        pdf = df.toPandas()
-        self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+        try:
+            pdf = df.toPandas()
+        finally:
+            self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
         pdf_arrow = df.toPandas()
         self.assertFramesEqual(pdf_arrow, pdf)
 
     def test_pandas_round_trip(self):
-        import pandas as pd
-        import numpy as np
-        data_dict = {}
-        for j, name in enumerate(self.schema.names):
-            data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
-        # need to convert these to numpy types first
-        data_dict["2_int_t"] = np.int32(data_dict["2_int_t"])
-        data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
-        pdf = pd.DataFrame(data=data_dict)
+        pdf = self.create_pandas_data_frame()
         df = self.spark.createDataFrame(self.data, schema=self.schema)
         pdf_arrow = df.toPandas()
         self.assertFramesEqual(pdf_arrow, pdf)
@@ -3187,6 +3192,62 @@ class ArrowTests(ReusedSQLTestCase):
         self.assertEqual(pdf.columns[0], "i")
         self.assertTrue(pdf.empty)
 
+    def test_createDataFrame_toggle(self):
+        pdf = self.create_pandas_data_frame()
+        self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
+        try:
+            df_no_arrow = self.spark.createDataFrame(pdf)
+        finally:
+            self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+        df_arrow = self.spark.createDataFrame(pdf)
+        self.assertEquals(df_no_arrow.collect(), df_arrow.collect())
+
+    def test_createDataFrame_with_schema(self):
+        pdf = self.create_pandas_data_frame()
+        df = self.spark.createDataFrame(pdf, schema=self.schema)
+        self.assertEquals(self.schema, df.schema)
+        pdf_arrow = df.toPandas()
+        self.assertFramesEqual(pdf_arrow, pdf)
+
+    def test_createDataFrame_with_incorrect_schema(self):
+        pdf = self.create_pandas_data_frame()
+        wrong_schema = StructType(list(reversed(self.schema)))
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(TypeError, ".*field.*can.not.accept.*type"):
+                self.spark.createDataFrame(pdf, schema=wrong_schema)
+
+    def test_createDataFrame_with_names(self):
+        pdf = self.create_pandas_data_frame()
+        # Test that schema as a list of column names gets applied
+        df = self.spark.createDataFrame(pdf, schema=list('abcdefg'))
+        self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
+        # Test that schema as tuple of column names gets applied
+        df = self.spark.createDataFrame(pdf, schema=tuple('abcdefg'))
+        self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
+
+    def test_createDataFrame_with_single_data_type(self):
+        import pandas as pd
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(TypeError, ".*IntegerType.*tuple"):
+                self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")
+
+    def test_createDataFrame_does_not_modify_input(self):
+        # Some series get converted for Spark to consume, this makes sure input is unchanged
+        pdf = self.create_pandas_data_frame()
+        # Use a nanosecond value to make sure it is not truncated
+        pdf.ix[0, '7_timestamp_t'] = 1
+        # Integers with nulls will get NaNs filled with 0 and will be casted
+        pdf.ix[1, '2_int_t'] = None
+        pdf_copy = pdf.copy(deep=True)
+        self.spark.createDataFrame(pdf, schema=self.schema)
+        self.assertTrue(pdf.equals(pdf_copy))
+
+    def test_schema_conversion_roundtrip(self):
+        from pyspark.sql.types import from_arrow_schema, to_arrow_schema
+        arrow_schema = to_arrow_schema(self.schema)
+        schema_rt = from_arrow_schema(arrow_schema)
+        self.assertEquals(self.schema, schema_rt)
+
 
 @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
 class VectorizedUDFTests(ReusedSQLTestCase):
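The `test_schema_conversion_roundtrip` case above relies on two helpers added to `types.py` below; a small sketch of that round trip with an illustrative schema:

```python
# Sketch of the Spark <-> Arrow schema round trip covered by the new helpers below.
from pyspark.sql.types import (StructType, StructField, IntegerType, StringType,
                               to_arrow_schema, from_arrow_schema)

schema = StructType([StructField("a", IntegerType(), True),
                     StructField("b", StringType(), True)])
assert from_arrow_schema(to_arrow_schema(schema)) == schema
```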
+ """ + return StructType( + [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable) + for field in arrow_schema]) + + def _check_dataframe_localize_timestamps(pdf): """ Convert timezone aware timestamps to timezone-naive in local time diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 4d5ce0bb60..b33760b1ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.api.python +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions.ExpressionInfo import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.types.DataType private[sql] object PythonSQLUtils { @@ -29,4 +32,19 @@ private[sql] object PythonSQLUtils { def listBuiltinFunctionInfos(): Array[ExpressionInfo] = { FunctionRegistry.functionSet.flatMap(f => FunctionRegistry.builtin.lookupFunction(f)).toArray } + + /** + * Python Callable function to convert ArrowPayloads into a [[DataFrame]]. + * + * @param payloadRDD A JavaRDD of ArrowPayloads. + * @param schemaString JSON Formatted Schema for ArrowPayloads. + * @param sqlContext The active [[SQLContext]]. + * @return The converted [[DataFrame]]. + */ + def arrowPayloadToDataFrame( + payloadRDD: JavaRDD[Array[Byte]], + schemaString: String, + sqlContext: SQLContext): DataFrame = { + ArrowConverters.toDataFrame(payloadRDD, schemaString, sqlContext) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 05ea1517fc..3cafb344ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -29,6 +29,8 @@ import org.apache.arrow.vector.schema.ArrowRecordBatch import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel import org.apache.spark.TaskContext +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.sql.types._ @@ -204,4 +206,16 @@ private[sql] object ArrowConverters { reader.close() } } + + private[sql] def toDataFrame( + payloadRDD: JavaRDD[Array[Byte]], + schemaString: String, + sqlContext: SQLContext): DataFrame = { + val rdd = payloadRDD.rdd.mapPartitions { iter => + val context = TaskContext.get() + ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + } + val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] + sqlContext.internalCreateDataFrame(rdd, schema) + } }