From 74d8d3d928cc9a7386b68588ac89ae042847d146 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 8 Jul 2015 18:22:53 -0700 Subject: [PATCH] [SPARK-8450] [SQL] [PYSARK] cleanup type converter for Python DataFrame This PR fixes the converter for Python DataFrame, especially for DecimalType Closes #7106 Author: Davies Liu Closes #7131 from davies/decimal_python and squashes the following commits: 4d3c234 [Davies Liu] Merge branch 'master' of github.com:apache/spark into decimal_python 20531d6 [Davies Liu] Merge branch 'master' of github.com:apache/spark into decimal_python 7d73168 [Davies Liu] fix conflit 6cdd86a [Davies Liu] Merge branch 'master' of github.com:apache/spark into decimal_python 7104e97 [Davies Liu] improve type infer 9cd5a21 [Davies Liu] run python tests with SPARK_PREPEND_CLASSES 829a05b [Davies Liu] fix UDT in python c99e8c5 [Davies Liu] fix mima c46814a [Davies Liu] convert decimal for Python DataFrames --- .../apache/spark/mllib/linalg/Matrices.scala | 10 +- .../apache/spark/mllib/linalg/Vectors.scala | 16 +--- project/MimaExcludes.scala | 5 +- python/pyspark/sql/tests.py | 13 +++ python/pyspark/sql/types.py | 4 + python/run-tests.py | 3 +- .../org/apache/spark/sql/DataFrame.scala | 4 +- .../org/apache/spark/sql/SQLContext.scala | 28 +----- .../spark/sql/execution/pythonUDFs.scala | 95 ++++++++++--------- 9 files changed, 84 insertions(+), 94 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 75e7004464..0df0766340 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -24,9 +24,9 @@ import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHash import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.Row -import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ /** * Trait for a local matrix. @@ -147,7 +147,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { )) } - override def serialize(obj: Any): Row = { + override def serialize(obj: Any): InternalRow = { val row = new GenericMutableRow(7) obj match { case sm: SparseMatrix => @@ -173,9 +173,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { override def deserialize(datum: Any): Matrix = { datum match { - // TODO: something wrong with UDT serialization, should never happen. - case m: Matrix => m - case row: Row => + case row: InternalRow => require(row.length == 7, s"MatrixUDT.deserialize given row with length ${row.length} but requires length == 7") val tpe = row.getByte(0) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index c9c27425d2..e048b01d92 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -28,7 +28,7 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.util.NumericParser -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.types._ @@ -175,7 +175,7 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true))) } - override def serialize(obj: Any): Row = { + override def serialize(obj: Any): InternalRow = { obj match { case SparseVector(size, indices, values) => val row = new GenericMutableRow(4) @@ -191,17 +191,12 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { row.setNullAt(2) row.update(3, values.toSeq) row - // TODO: There are bugs in UDT serialization because we don't have a clear separation between - // TODO: internal SQL types and language specific types (including UDT). UDT serialize and - // TODO: deserialize may get called twice. See SPARK-7186. - case row: Row => - row } } override def deserialize(datum: Any): Vector = { datum match { - case row: Row => + case row: InternalRow => require(row.length == 4, s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4") val tpe = row.getByte(0) @@ -215,11 +210,6 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { val values = row.getAs[Iterable[Double]](3).toArray new DenseVector(values) } - // TODO: There are bugs in UDT serialization because we don't have a clear separation between - // TODO: internal SQL types and language specific types (including UDT). UDT serialize and - // TODO: deserialize may get called twice. See SPARK-7186. - case v: Vector => - v } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 57a86bf8de..821aadd477 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -63,7 +63,10 @@ object MimaExcludes { // SQL execution is considered private. excludePackage("org.apache.spark.sql.execution"), // Parquet support is considered private. - excludePackage("org.apache.spark.sql.parquet") + excludePackage("org.apache.spark.sql.parquet"), + // local function inside a method + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1") ) ++ Seq( // SPARK-8479 Add numNonzeros and numActives to Matrix. ProblemFilters.exclude[MissingMethodProblem]( diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 333378c7f1..66827d4885 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -700,6 +700,19 @@ class SQLTests(ReusedPySparkTestCase): self.assertTrue(now - now1 < datetime.timedelta(0.001)) self.assertTrue(now - utcnow1 < datetime.timedelta(0.001)) + def test_decimal(self): + from decimal import Decimal + schema = StructType([StructField("decimal", DecimalType(10, 5))]) + df = self.sqlCtx.createDataFrame([(Decimal("3.14159"),)], schema) + row = df.select(df.decimal + 1).first() + self.assertEqual(row[0], Decimal("4.14159")) + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.write.parquet(tmpPath) + df2 = self.sqlCtx.read.parquet(tmpPath) + row = df2.first() + self.assertEqual(row[0], Decimal("3.14159")) + def test_dropna(self): schema = StructType([ StructField("name", StringType(), True), diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 160df40d65..7e64cb0b54 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1069,6 +1069,10 @@ def _verify_type(obj, dataType): if obj is None: return + # StringType can work with any types + if isinstance(dataType, StringType): + return + if isinstance(dataType, UserDefinedType): if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): raise ValueError("%r is not an instance of type %r" % (obj, dataType)) diff --git a/python/run-tests.py b/python/run-tests.py index 7638854def..cc56077937 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -72,7 +72,8 @@ LOGGER = logging.getLogger() def run_individual_python_test(test_name, pyspark_python): - env = {'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)} + env = dict(os.environ) + env.update({'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)}) LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name) start_time = time.time() try: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index eeefc85255..d9f987ae02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1549,8 +1549,8 @@ class DataFrame private[sql]( * Converts a JavaRDD to a PythonRDD. */ protected[sql] def javaToPython: JavaRDD[Array[Byte]] = { - val fieldTypes = schema.fields.map(_.dataType) - val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() + val structType = schema // capture it for closure + val jrdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType)).toJavaRDD() SerDeUtil.javaToPython(jrdd) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 079f31ab8f..477dea9164 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1044,33 +1044,7 @@ class SQLContext(@transient val sparkContext: SparkContext) rdd: RDD[Array[Any]], schema: StructType): DataFrame = { - def needsConversion(dataType: DataType): Boolean = dataType match { - case ByteType => true - case ShortType => true - case LongType => true - case FloatType => true - case DateType => true - case TimestampType => true - case StringType => true - case ArrayType(_, _) => true - case MapType(_, _, _) => true - case StructType(_) => true - case udt: UserDefinedType[_] => needsConversion(udt.sqlType) - case other => false - } - - val convertedRdd = if (schema.fields.exists(f => needsConversion(f.dataType))) { - rdd.map(m => m.zip(schema.fields).map { - case (value, field) => EvaluatePython.fromJava(value, field.dataType) - }) - } else { - rdd - } - - val rowRdd = convertedRdd.mapPartitions { iter => - iter.map { m => new GenericInternalRow(m): InternalRow} - } - + val rowRdd = rdd.map(r => EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow]) DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 6946e798b7..1c8130b07c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -24,20 +24,19 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle.{Pickler, Unpickler} -import org.apache.spark.{Accumulator, Logging => SparkLogging} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.{PythonBroadcast, PythonRDD} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.{Accumulator, Logging => SparkLogging} /** * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]]. @@ -125,59 +124,86 @@ object EvaluatePython { new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)()) /** - * Helper for converting a Scala object to a java suitable for pyspark serialization. + * Helper for converting from Catalyst type to java type suitable for Pyrolite. */ def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { case (null, _) => null - case (row: Row, struct: StructType) => + case (row: InternalRow, struct: StructType) => val fields = struct.fields.map(field => field.dataType) - row.toSeq.zip(fields).map { - case (obj, dataType) => toJava(obj, dataType) - }.toArray + rowToArray(row, fields) case (seq: Seq[Any], array: ArrayType) => seq.map(x => toJava(x, array.elementType)).asJava - case (list: JList[_], array: ArrayType) => - list.map(x => toJava(x, array.elementType)).asJava - case (arr, array: ArrayType) if arr.getClass.isArray => - arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType)) case (obj: Map[_, _], mt: MapType) => obj.map { case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType)) }.asJava - case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType) + case (ud, udt: UserDefinedType[_]) => toJava(ud, udt.sqlType) case (date: Int, DateType) => DateTimeUtils.toJavaDate(date) case (t: Long, TimestampType) => DateTimeUtils.toJavaTimestamp(t) + + case (d: Decimal, _) => d.toJavaBigDecimal + case (s: UTF8String, StringType) => s.toString - // Pyrolite can handle Timestamp and Decimal case (other, _) => other } /** * Convert Row into Java Array (for pickled into Python) */ - def rowToArray(row: Row, fields: Seq[DataType]): Array[Any] = { + def rowToArray(row: InternalRow, fields: Seq[DataType]): Array[Any] = { // TODO: this is slow! row.toSeq.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray } - // Converts value to the type specified by the data type. - // Because Python does not have data types for TimestampType, FloatType, ShortType, and - // ByteType, we need to explicitly convert values in columns of these data types to the desired - // JVM data types. + /** + * Converts `obj` to the type specified by the data type, or returns null if the type of obj is + * unexpected. Because Python doesn't enforce the type. + */ def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { - // TODO: We should check nullable case (null, _) => null + case (c: Boolean, BooleanType) => c + + case (c: Int, ByteType) => c.toByte + case (c: Long, ByteType) => c.toByte + + case (c: Int, ShortType) => c.toShort + case (c: Long, ShortType) => c.toShort + + case (c: Int, IntegerType) => c + case (c: Long, IntegerType) => c.toInt + + case (c: Int, LongType) => c.toLong + case (c: Long, LongType) => c + + case (c: Double, FloatType) => c.toFloat + + case (c: Double, DoubleType) => c + + case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c) + + case (c: Int, DateType) => c + + case (c: Long, TimestampType) => c + + case (c: String, StringType) => UTF8String.fromString(c) + case (c, StringType) => + // If we get here, c is not a string. Call toString on it. + UTF8String.fromString(c.toString) + + case (c: String, BinaryType) => c.getBytes("utf-8") + case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c + case (c: java.util.List[_], ArrayType(elementType, _)) => - c.map { e => fromJava(e, elementType)}: Seq[Any] + c.map { e => fromJava(e, elementType)}.toSeq case (c, ArrayType(elementType, _)) if c.getClass.isArray => - c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)): Seq[Any] + c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)).toSeq case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map { case (key, value) => (fromJava(key, keyType), fromJava(value, valueType)) @@ -188,30 +214,11 @@ object EvaluatePython { case (e, f) => fromJava(e, f.dataType) }) - case (c: java.util.Calendar, DateType) => - DateTimeUtils.fromJavaDate(new java.sql.Date(c.getTimeInMillis)) + case (_, udt: UserDefinedType[_]) => fromJava(obj, udt.sqlType) - case (c: java.util.Calendar, TimestampType) => - c.getTimeInMillis * 10000L - case (t: java.sql.Timestamp, TimestampType) => - DateTimeUtils.fromJavaTimestamp(t) - - case (_, udt: UserDefinedType[_]) => - fromJava(obj, udt.sqlType) - - case (c: Int, ByteType) => c.toByte - case (c: Long, ByteType) => c.toByte - case (c: Int, ShortType) => c.toShort - case (c: Long, ShortType) => c.toShort - case (c: Long, IntegerType) => c.toInt - case (c: Int, LongType) => c.toLong - case (c: Double, FloatType) => c.toFloat - case (c: String, StringType) => UTF8String.fromString(c) - case (c, StringType) => - // If we get here, c is not a string. Call toString on it. - UTF8String.fromString(c.toString) - - case (c, _) => c + // all other unexpected type should be null, or we will have runtime exception + // TODO(davies): we could improve this by try to cast the object to expected type + case (c, _) => null } }