[SPARK-34545][SQL] Fix issues with valueCompare feature of pyrolite
### What changes were proposed in this pull request?
pyrolite 4.21 introduced value comparison during object memoization and serialization and enabled it by default (`valueCompare=true`): https://github.com/irmen/Pyrolite/blob/pyrolite-4.21/java/src/main/java/net/razorvine/pickle/Pickler.java#L112-L122

This change has an undesired effect when we serialize a row (actually a `GenericRowWithSchema`) to be passed to Python: https://github.com/apache/spark/blob/branch-3.0/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala#L60. A simple example is that
```
new GenericRowWithSchema(Array(1.0, 1.0), StructType(Seq(StructField("_1", DoubleType), StructField("_2", DoubleType))))
```
and
```
new GenericRowWithSchema(Array(1, 1), StructType(Seq(StructField("_1", IntegerType), StructField("_2", IntegerType))))
```
are currently equal, so the second instance is replaced with a short memo reference to the first one during serialization.

### Why are the changes needed?
The above can cause nasty issues like the one described in https://issues.apache.org/jira/browse/SPARK-34545:
```
>>> from pyspark.sql.functions import udf
>>> from pyspark.sql.types import *
>>>
>>> def udf1(data_type):
...     def u1(e):
...         return e[0]
...     return udf(u1, data_type)
...
>>> df = spark.createDataFrame([((1.0, 1.0), (1, 1))], ['c1', 'c2'])
>>>
>>> df = df.withColumn("c3", udf1(DoubleType())("c1"))
>>> df = df.withColumn("c4", udf1(IntegerType())("c2"))
>>>
>>> df.select("c3").show()
+---+
| c3|
+---+
|1.0|
+---+
>>> df.select("c4").show()
+---+
| c4|
+---+
|  1|
+---+
>>> df.select("c3", "c4").show()
+---+----+
| c3|  c4|
+---+----+
|1.0|null|
+---+----+
```
This happens because, during serialization from the JVM to Python, `GenericRowWithSchema(1.0, 1.0)` (`c1`) is memoized first, and when `GenericRowWithSchema(1, 1)` (`c2`) comes next it is replaced with a memo reference to `c1` (instead of being serialized out) as the two rows are `equals()`. The Python function then runs, but the return type of `c4` is expected to be `IntegerType`, and when a different type (`DoubleType`) comes back from Python it is discarded: https://github.com/apache/spark/blob/branch-3.0/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala#L108-L113

After this PR:
```
>>> df.select("c3", "c4").show()
+---+---+
| c3| c4|
+---+---+
|1.0|  1|
+---+---+
```

### Does this PR introduce _any_ user-facing change?
Yes, it fixes a correctness issue.

### How was this patch tested?
Added a new unit test + manual tests.

Closes #31682 from peter-toth/SPARK-34545-fix-row-comparison.

Authored-by: Peter Toth <peter.toth@gmail.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
parent 9ec8696f11
commit ab8a9a0ceb
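For illustration (not part of the patch), a minimal Scala sketch of the row equality described in the commit message, assuming a Spark build on the classpath (e.g. pasted into `spark-shell`):

```scala
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._

// Two rows with the same numeric values but different value types and schemas.
val doubleRow = new GenericRowWithSchema(
  Array(1.0, 1.0),
  StructType(Seq(StructField("_1", DoubleType), StructField("_2", DoubleType))))
val intRow = new GenericRowWithSchema(
  Array(1, 1),
  StructType(Seq(StructField("_1", IntegerType), StructField("_2", IntegerType))))

// Row equality compares only the values (not the schema), and Scala's boxed numeric
// equality treats 1.0 and 1 as equal, so the rows compare equal and a value-comparing
// memo cache considers the second row a duplicate of the first.
println(doubleRow == intRow)  // true
```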
```diff
@@ -78,7 +78,8 @@ private[spark] object SerDeUtil extends Logging {
    * Choose batch size based on size of objects
    */
   private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
-    private val pickle = new Pickler()
+    private val pickle = new Pickler(/* useMemo = */ true,
+      /* valueCompare = */ false)
     private var batch = 1
     private val buffer = new mutable.ArrayBuffer[Any]
 
```
```diff
@@ -131,7 +132,8 @@ private[spark] object SerDeUtil extends Logging {
   }
 
   private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = {
-    val pickle = new Pickler
+    val pickle = new Pickler(/* useMemo = */ true,
+      /* valueCompare = */ false)
     val kt = Try {
       pickle.dumps(t._1)
     }
```
```diff
@@ -182,7 +184,8 @@ private[spark] object SerDeUtil extends Logging {
     if (batchSize == 0) {
       new AutoBatchedPickler(cleaned)
     } else {
-      val pickle = new Pickler
+      val pickle = new Pickler(/* useMemo = */ true,
+        /* valueCompare = */ false)
       cleaned.grouped(batchSize).map(batched => pickle.dumps(batched.asJava))
     }
   }
```
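The three `SerDeUtil` hunks above make the same change: keep memoization but force reference-only memo lookups. A hedged sketch of what the constructor arguments mean, based on the pyrolite 4.21 `Pickler` defaults linked in the commit message (variable names here are illustrative only):

```scala
import net.razorvine.pickle.Pickler

// What the patch uses: keep the memo (repeated references are still deduplicated),
// but look objects up by reference identity only, so equal-but-distinct objects
// are each written out in full.
val referencePickler = new Pickler(/* useMemo = */ true, /* valueCompare = */ false)

// What `new Pickler()` amounts to on pyrolite 4.21+: memoization with value-based
// lookups, which is the behaviour that collapses the two rows from the description.
val valuePickler = new Pickler(/* useMemo = */ true, /* valueCompare = */ true)

// Turning the memo off entirely would also avoid the collapse, but then repeated
// references to the same object would be re-serialized every time.
val noMemoPickler = new Pickler(/* useMemo = */ false, /* valueCompare = */ false)
```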
```diff
@@ -1313,8 +1313,10 @@ private[spark] abstract class SerDeBase {
   def dumps(obj: AnyRef): Array[Byte] = {
     obj match {
       // Pickler in Python side cannot deserialize Scala Array normally. See SPARK-12834.
-      case array: Array[_] => new Pickler().dumps(array.toSeq.asJava)
-      case _ => new Pickler().dumps(obj)
+      case array: Array[_] => new Pickler(/* useMemo = */ true,
+        /* valueCompare = */ false).dumps(array.toSeq.asJava)
+      case _ => new Pickler(/* useMemo = */ true,
+        /* valueCompare = */ false).dumps(obj)
     }
   }
 
```
```diff
@@ -674,6 +674,17 @@ class UDFTests(ReusedSQLTestCase):
         self.assertEqual(df.select(udf(func)("id"))._jdf.queryExecution()
                          .withCachedData().getClass().getSimpleName(), 'InMemoryRelation')
 
+    # SPARK-34545
+    def test_udf_input_serialization_valuecompare_disabled(self):
+        def f(e):
+            return e[0]
+
+        df = self.spark.createDataFrame([((1.0, 1.0), (1, 1))], ['c1', 'c2'])
+        result = df.select("*", udf(f, DoubleType())("c1").alias('c3'),
+                           udf(f, IntegerType())("c2").alias('c4'))
+        self.assertEqual(result.collect(),
+                         [Row(c1=Row(_1=1.0, _2=1.0), c2=Row(_1=1, _2=1), c3=1.0, c4=1)])
+
 
 class UDFInitializationTests(unittest.TestCase):
     def tearDown(self):
```
```diff
@@ -46,7 +46,18 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
     val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)
 
-    // enable memo iff we serialize the row with schema (schema and class should be memorized)
-    val pickle = new Pickler(needConversion)
+    // pyrolite 4.21+ can look up objects in its cache by value, but the `GenericRowWithSchema`
+    // objects that we pass from the JVM to Python don't define their `equals()` to take the type
+    // of the values or the schema of the row into account. This causes rows like
+    // `GenericRowWithSchema(Array(1.0, 1.0),
+    //   StructType(Seq(StructField("_1", DoubleType), StructField("_2", DoubleType))))`
+    // and
+    // `GenericRowWithSchema(Array(1, 1),
+    //   StructType(Seq(StructField("_1", IntegerType), StructField("_2", IntegerType))))`
+    // to be `equals()`, so we need to disable this feature explicitly (`valueCompare=false`).
+    // Please note that the cache by reference is still enabled depending on `needConversion`.
+    val pickle = new Pickler(/* useMemo = */ needConversion,
+      /* valueCompare = */ false)
     // Input iterator to Python: input rows are grouped so we send them in batches to Python.
     // For each row, add it to the queue.
     val inputIterator = iter.map { row =>
```
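As the new comment notes, reference-based memoization still works once `valueCompare` is disabled. A small standalone sketch of that distinction (illustrative names; assumes pyrolite 4.21+ on the classpath, and the size comparison describes the expected outcome rather than anything the patch asserts):

```scala
import java.util.Arrays
import net.razorvine.pickle.Pickler

// The configuration used by the patch: memo on, value comparison off.
val pickler = new Pickler(/* useMemo = */ true, /* valueCompare = */ false)

// One list instance referenced twice vs. two distinct-but-equal list instances.
val shared = new java.util.ArrayList[String](Arrays.asList("repeated", "payload"))
val copy = new java.util.ArrayList[String](Arrays.asList("repeated", "payload"))

val sameReferenceTwice = pickler.dumps(Arrays.asList(shared, shared))
val twoEqualObjects = pickler.dumps(Arrays.asList(shared, copy))

// With value comparison disabled, the two equal-but-distinct lists are both written out
// in full; the repeated reference can still be collapsed into a memo reference, which
// typically makes the first payload the shorter of the two.
println(s"${sameReferenceTwice.length} vs ${twoEqualObjects.length}")
```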