[SPARK-34545][SQL] Fix issues with valueCompare feature of pyrolite

### What changes were proposed in this pull request?

pyrolite 4.21 introduced and enabled value comparison by default (`valueCompare=true`) during object memoization and serialization: https://github.com/irmen/Pyrolite/blob/pyrolite-4.21/java/src/main/java/net/razorvine/pickle/Pickler.java#L112-L122
This change has an undesired effect when we serialize a row (actually a `GenericRowWithSchema`) to be passed to Python: https://github.com/apache/spark/blob/branch-3.0/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala#L60. A simple example is that
```
new GenericRowWithSchema(Array(1.0, 1.0), StructType(Seq(StructField("_1", DoubleType), StructField("_2", DoubleType))))
```
and
```
new GenericRowWithSchema(Array(1, 1), StructType(Seq(StructField("_1", IntegerType), StructField("_2", IntegerType))))
```
are currently considered equal, and during serialization the second instance is replaced with a short memo reference to the first one.
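That equality can be reproduced directly (a minimal sketch; the imports are the standard Spark ones and are not part of this patch):

```
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._

val doubles = new GenericRowWithSchema(Array(1.0, 1.0),
  StructType(Seq(StructField("_1", DoubleType), StructField("_2", DoubleType))))
val ints = new GenericRowWithSchema(Array(1, 1),
  StructType(Seq(StructField("_1", IntegerType), StructField("_2", IntegerType))))

// Row.equals() compares only the values (1.0 == 1 under numeric equality) and ignores the
// schema, so pyrolite's value-based memo lookup treats the two rows as the same object.
assert(doubles == ints)
```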

### Why are the changes needed?
The above can cause nasty issues like the one described in https://issues.apache.org/jira/browse/SPARK-34545:

```
>>> from pyspark.sql.functions import udf
>>> from pyspark.sql.types import *
>>>
>>> def udf1(data_type):
        def u1(e):
            return e[0]
        return udf(u1, data_type)
>>>
>>> df = spark.createDataFrame([((1.0, 1.0), (1, 1))], ['c1', 'c2'])
>>>
>>> df = df.withColumn("c3", udf1(DoubleType())("c1"))
>>> df = df.withColumn("c4", udf1(IntegerType())("c2"))
>>>
>>> df.select("c3").show()
+---+
| c3|
+---+
|1.0|
+---+

>>> df.select("c4").show()
+---+
| c4|
+---+
|  1|
+---+

>>> df.select("c3", "c4").show()
+---+----+
| c3|  c4|
+---+----+
|1.0|null|
+---+----+
```
This is because during serialization from the JVM to Python, `GenericRowWithSchema(1.0, 1.0)` (`c1`) is memoized first, and when `GenericRowWithSchema(1, 1)` (`c2`) comes next it is replaced with a short memo reference to `c1` (instead of being serialized out) because the two rows are `equal()`. The Python function then runs, but since the declared return type of `c4` is `IntegerType`, the `DoubleType` value that comes back from Python is discarded: https://github.com/apache/spark/blob/branch-3.0/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala#L108-L113
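Roughly speaking, the conversion on the Python-to-JVM return path maps any value that does not match the declared return type to null. A simplified sketch (not the exact code at the link above) of that behaviour:

```
// Simplified sketch: the converter derived from the UDF's declared return type accepts
// only matching values and turns everything else into null.
def nullSafeConvert(input: Any)(f: PartialFunction[Any, Any]): Any =
  if (input == null) null else f.applyOrElse(input, (_: Any) => null)

// Hypothetical converter for an IntegerType column.
val toInteger: Any => Any = obj => nullSafeConvert(obj) {
  case c: Int => c
}

toInteger(1)    // 1
toInteger(1.0)  // null: a Double returned for an IntegerType column is silently dropped
```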

After this PR:
```
>>> df.select("c3", "c4").show()
+---+---+
| c3| c4|
+---+---+
|1.0|  1|
+---+---+
```

### Does this PR introduce _any_ user-facing change?
Yes, fixes a correctness issue.

### How was this patch tested?
Added new UT + manual tests.

Closes #31682 from peter-toth/SPARK-34545-fix-row-comparison.

Authored-by: Peter Toth <peter.toth@gmail.com>
Signed-off-by: Sean Owen <srowen@gmail.com>

```
@@ -78,7 +78,8 @@ private[spark] object SerDeUtil extends Logging {
    * Choose batch size based on size of objects
    */
   private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
-    private val pickle = new Pickler()
+    private val pickle = new Pickler(/* useMemo = */ true,
+      /* valueCompare = */ false)
     private var batch = 1
     private val buffer = new mutable.ArrayBuffer[Any]
@@ -131,7 +132,8 @@
   }

   private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = {
-    val pickle = new Pickler
+    val pickle = new Pickler(/* useMemo = */ true,
+      /* valueCompare = */ false)
     val kt = Try {
       pickle.dumps(t._1)
     }
@@ -182,7 +184,8 @@
       if (batchSize == 0) {
         new AutoBatchedPickler(cleaned)
       } else {
-        val pickle = new Pickler
+        val pickle = new Pickler(/* useMemo = */ true,
+          /* valueCompare = */ false)
         cleaned.grouped(batchSize).map(batched => pickle.dumps(batched.asJava))
       }
     }
```

```
@@ -1313,8 +1313,10 @@ private[spark] abstract class SerDeBase {
   def dumps(obj: AnyRef): Array[Byte] = {
     obj match {
       // Pickler in Python side cannot deserialize Scala Array normally. See SPARK-12834.
-      case array: Array[_] => new Pickler().dumps(array.toSeq.asJava)
-      case _ => new Pickler().dumps(obj)
+      case array: Array[_] => new Pickler(/* useMemo = */ true,
+        /* valueCompare = */ false).dumps(array.toSeq.asJava)
+      case _ => new Pickler(/* useMemo = */ true,
+        /* valueCompare = */ false).dumps(obj)
     }
   }
```

```
@@ -674,6 +674,17 @@ class UDFTests(ReusedSQLTestCase):
         self.assertEqual(df.select(udf(func)("id"))._jdf.queryExecution()
                          .withCachedData().getClass().getSimpleName(), 'InMemoryRelation')

+    # SPARK-34545
+    def test_udf_input_serialization_valuecompare_disabled(self):
+        def f(e):
+            return e[0]
+
+        df = self.spark.createDataFrame([((1.0, 1.0), (1, 1))], ['c1', 'c2'])
+        result = df.select("*", udf(f, DoubleType())("c1").alias('c3'),
+                           udf(f, IntegerType())("c2").alias('c4'))
+        self.assertEqual(result.collect(),
+                         [Row(c1=Row(_1=1.0, _2=1.0), c2=Row(_1=1, _2=1), c3=1.0, c4=1)])
+

 class UDFInitializationTests(unittest.TestCase):
     def tearDown(self):
```

```
@@ -46,7 +46,18 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
     val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)

     // enable memo iff we serialize the row with schema (schema and class should be memorized)
-    val pickle = new Pickler(needConversion)
+    // pyrolite 4.21+ can lookup objects in its cache by value, but `GenericRowWithSchema` objects,
+    // that we pass from JVM to Python, don't define their `equals()` to take the type of the
+    // values or the schema of the row into account. This causes like
+    // `GenericRowWithSchema(Array(1.0, 1.0),
+    //   StructType(Seq(StructField("_1", DoubleType), StructField("_2", DoubleType))))`
+    // and
+    // `GenericRowWithSchema(Array(1, 1),
+    //   StructType(Seq(StructField("_1", IntegerType), StructField("_2", IntegerType))))`
+    // to be `equal()` and so we need to disable this feature explicitly (`valueCompare=false`).
+    // Please note that cache by reference is still enabled depending on `needConversion`.
+    val pickle = new Pickler(/* useMemo = */ needConversion,
+      /* valueCompare = */ false)
     // Input iterator to Python: input rows are grouped so we send them in batches to Python.
     // For each row, add it to the queue.
     val inputIterator = iter.map { row =>
```