[SPARK-12404][SQL] Ensure objects passed to StaticInvoke is Serializable
Now `StaticInvoke` receives `Any` as an object and `StaticInvoke` can be serialized but sometimes the object passed is not serializable. For example, the following code raises Exception because `RowEncoder#extractorsFor` invoked indirectly makes `StaticInvoke`. ``` case class TimestampContainer(timestamp: java.sql.Timestamp) val rdd = sc.parallelize(1 to 2).map(_ => TimestampContainer(System.currentTimeMillis)) val df = rdd.toDF val ds = df.as[TimestampContainer] val rdd2 = ds.rdd <----------------- invokes extractorsFor indirectly ``` I'll add test cases. Author: Kousuke Saruta <sarutak@oss.nttdata.co.jp> Author: Michael Armbrust <michael@databricks.com> Closes #10357 from sarutak/SPARK-12404.
This commit is contained in:
parent
41ee7c57ab
commit
6eba655259
|
@ -194,7 +194,7 @@ object JavaTypeInference {
|
|||
|
||||
case c if c == classOf[java.sql.Date] =>
|
||||
StaticInvoke(
|
||||
DateTimeUtils,
|
||||
DateTimeUtils.getClass,
|
||||
ObjectType(c),
|
||||
"toJavaDate",
|
||||
getPath :: Nil,
|
||||
|
@ -202,7 +202,7 @@ object JavaTypeInference {
|
|||
|
||||
case c if c == classOf[java.sql.Timestamp] =>
|
||||
StaticInvoke(
|
||||
DateTimeUtils,
|
||||
DateTimeUtils.getClass,
|
||||
ObjectType(c),
|
||||
"toJavaTimestamp",
|
||||
getPath :: Nil,
|
||||
|
@ -276,7 +276,7 @@ object JavaTypeInference {
|
|||
ObjectType(classOf[Array[Any]]))
|
||||
|
||||
StaticInvoke(
|
||||
ArrayBasedMapData,
|
||||
ArrayBasedMapData.getClass,
|
||||
ObjectType(classOf[JMap[_, _]]),
|
||||
"toJavaMap",
|
||||
keyData :: valueData :: Nil)
|
||||
|
@ -341,21 +341,21 @@ object JavaTypeInference {
|
|||
|
||||
case c if c == classOf[java.sql.Timestamp] =>
|
||||
StaticInvoke(
|
||||
DateTimeUtils,
|
||||
DateTimeUtils.getClass,
|
||||
TimestampType,
|
||||
"fromJavaTimestamp",
|
||||
inputObject :: Nil)
|
||||
|
||||
case c if c == classOf[java.sql.Date] =>
|
||||
StaticInvoke(
|
||||
DateTimeUtils,
|
||||
DateTimeUtils.getClass,
|
||||
DateType,
|
||||
"fromJavaDate",
|
||||
inputObject :: Nil)
|
||||
|
||||
case c if c == classOf[java.math.BigDecimal] =>
|
||||
StaticInvoke(
|
||||
Decimal,
|
||||
Decimal.getClass,
|
||||
DecimalType.SYSTEM_DEFAULT,
|
||||
"apply",
|
||||
inputObject :: Nil)
|
||||
|
|
|
@ -223,7 +223,7 @@ object ScalaReflection extends ScalaReflection {
|
|||
|
||||
case t if t <:< localTypeOf[java.sql.Date] =>
|
||||
StaticInvoke(
|
||||
DateTimeUtils,
|
||||
DateTimeUtils.getClass,
|
||||
ObjectType(classOf[java.sql.Date]),
|
||||
"toJavaDate",
|
||||
getPath :: Nil,
|
||||
|
@ -231,7 +231,7 @@ object ScalaReflection extends ScalaReflection {
|
|||
|
||||
case t if t <:< localTypeOf[java.sql.Timestamp] =>
|
||||
StaticInvoke(
|
||||
DateTimeUtils,
|
||||
DateTimeUtils.getClass,
|
||||
ObjectType(classOf[java.sql.Timestamp]),
|
||||
"toJavaTimestamp",
|
||||
getPath :: Nil,
|
||||
|
@ -287,7 +287,7 @@ object ScalaReflection extends ScalaReflection {
|
|||
ObjectType(classOf[Array[Any]]))
|
||||
|
||||
StaticInvoke(
|
||||
scala.collection.mutable.WrappedArray,
|
||||
scala.collection.mutable.WrappedArray.getClass,
|
||||
ObjectType(classOf[Seq[_]]),
|
||||
"make",
|
||||
arrayData :: Nil)
|
||||
|
@ -315,7 +315,7 @@ object ScalaReflection extends ScalaReflection {
|
|||
ObjectType(classOf[Array[Any]]))
|
||||
|
||||
StaticInvoke(
|
||||
ArrayBasedMapData,
|
||||
ArrayBasedMapData.getClass,
|
||||
ObjectType(classOf[Map[_, _]]),
|
||||
"toScalaMap",
|
||||
keyData :: valueData :: Nil)
|
||||
|
@ -548,28 +548,28 @@ object ScalaReflection extends ScalaReflection {
|
|||
|
||||
case t if t <:< localTypeOf[java.sql.Timestamp] =>
|
||||
StaticInvoke(
|
||||
DateTimeUtils,
|
||||
DateTimeUtils.getClass,
|
||||
TimestampType,
|
||||
"fromJavaTimestamp",
|
||||
inputObject :: Nil)
|
||||
|
||||
case t if t <:< localTypeOf[java.sql.Date] =>
|
||||
StaticInvoke(
|
||||
DateTimeUtils,
|
||||
DateTimeUtils.getClass,
|
||||
DateType,
|
||||
"fromJavaDate",
|
||||
inputObject :: Nil)
|
||||
|
||||
case t if t <:< localTypeOf[BigDecimal] =>
|
||||
StaticInvoke(
|
||||
Decimal,
|
||||
Decimal.getClass,
|
||||
DecimalType.SYSTEM_DEFAULT,
|
||||
"apply",
|
||||
inputObject :: Nil)
|
||||
|
||||
case t if t <:< localTypeOf[java.math.BigDecimal] =>
|
||||
StaticInvoke(
|
||||
Decimal,
|
||||
Decimal.getClass,
|
||||
DecimalType.SYSTEM_DEFAULT,
|
||||
"apply",
|
||||
inputObject :: Nil)
|
||||
|
|
|
@ -61,21 +61,21 @@ object RowEncoder {
|
|||
|
||||
case TimestampType =>
|
||||
StaticInvoke(
|
||||
DateTimeUtils,
|
||||
DateTimeUtils.getClass,
|
||||
TimestampType,
|
||||
"fromJavaTimestamp",
|
||||
inputObject :: Nil)
|
||||
|
||||
case DateType =>
|
||||
StaticInvoke(
|
||||
DateTimeUtils,
|
||||
DateTimeUtils.getClass,
|
||||
DateType,
|
||||
"fromJavaDate",
|
||||
inputObject :: Nil)
|
||||
|
||||
case _: DecimalType =>
|
||||
StaticInvoke(
|
||||
Decimal,
|
||||
Decimal.getClass,
|
||||
DecimalType.SYSTEM_DEFAULT,
|
||||
"apply",
|
||||
inputObject :: Nil)
|
||||
|
@ -172,14 +172,14 @@ object RowEncoder {
|
|||
|
||||
case TimestampType =>
|
||||
StaticInvoke(
|
||||
DateTimeUtils,
|
||||
DateTimeUtils.getClass,
|
||||
ObjectType(classOf[java.sql.Timestamp]),
|
||||
"toJavaTimestamp",
|
||||
input :: Nil)
|
||||
|
||||
case DateType =>
|
||||
StaticInvoke(
|
||||
DateTimeUtils,
|
||||
DateTimeUtils.getClass,
|
||||
ObjectType(classOf[java.sql.Date]),
|
||||
"toJavaDate",
|
||||
input :: Nil)
|
||||
|
@ -197,7 +197,7 @@ object RowEncoder {
|
|||
"array",
|
||||
ObjectType(classOf[Array[_]]))
|
||||
StaticInvoke(
|
||||
scala.collection.mutable.WrappedArray,
|
||||
scala.collection.mutable.WrappedArray.getClass,
|
||||
ObjectType(classOf[Seq[_]]),
|
||||
"make",
|
||||
arrayData :: Nil)
|
||||
|
@ -210,7 +210,7 @@ object RowEncoder {
|
|||
val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType))
|
||||
|
||||
StaticInvoke(
|
||||
ArrayBasedMapData,
|
||||
ArrayBasedMapData.getClass,
|
||||
ObjectType(classOf[Map[_, _]]),
|
||||
"toScalaMap",
|
||||
keyData :: valueData :: Nil)
|
||||
|
|
|
@ -42,16 +42,14 @@ import org.apache.spark.sql.types._
|
|||
* of calling the function.
|
||||
*/
|
||||
case class StaticInvoke(
|
||||
staticObject: Any,
|
||||
staticObject: Class[_],
|
||||
dataType: DataType,
|
||||
functionName: String,
|
||||
arguments: Seq[Expression] = Nil,
|
||||
propagateNull: Boolean = true) extends Expression {
|
||||
|
||||
val objectName = staticObject match {
|
||||
case c: Class[_] => c.getName
|
||||
case other => other.getClass.getName.stripSuffix("$")
|
||||
}
|
||||
val objectName = staticObject.getName.stripSuffix("$")
|
||||
|
||||
override def nullable: Boolean = true
|
||||
override def children: Seq[Expression] = arguments
|
||||
|
||||
|
|
|
@ -39,6 +39,7 @@ import org.apache.spark.sql.expressions.Aggregator;
|
|||
import org.apache.spark.sql.test.TestSQLContext;
|
||||
import org.apache.spark.sql.catalyst.encoders.OuterScopes;
|
||||
import org.apache.spark.sql.catalyst.expressions.GenericRow;
|
||||
import org.apache.spark.sql.types.DecimalType;
|
||||
import org.apache.spark.sql.types.StructType;
|
||||
|
||||
import static org.apache.spark.sql.functions.*;
|
||||
|
@ -608,6 +609,44 @@ public class JavaDatasetSuite implements Serializable {
|
|||
}
|
||||
}
|
||||
|
||||
public class SimpleJavaBean2 implements Serializable {
|
||||
private Timestamp a;
|
||||
private Date b;
|
||||
private java.math.BigDecimal c;
|
||||
|
||||
public Timestamp getA() { return a; }
|
||||
|
||||
public void setA(Timestamp a) { this.a = a; }
|
||||
|
||||
public Date getB() { return b; }
|
||||
|
||||
public void setB(Date b) { this.b = b; }
|
||||
|
||||
public java.math.BigDecimal getC() { return c; }
|
||||
|
||||
public void setC(java.math.BigDecimal c) { this.c = c; }
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
|
||||
SimpleJavaBean that = (SimpleJavaBean) o;
|
||||
|
||||
if (!a.equals(that.a)) return false;
|
||||
if (!b.equals(that.b)) return false;
|
||||
return c.equals(that.c);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
int result = a.hashCode();
|
||||
result = 31 * result + b.hashCode();
|
||||
result = 31 * result + c.hashCode();
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
public class NestedJavaBean implements Serializable {
|
||||
private SimpleJavaBean a;
|
||||
|
||||
|
@ -689,4 +728,17 @@ public class JavaDatasetSuite implements Serializable {
|
|||
.as(Encoders.bean(SimpleJavaBean.class));
|
||||
Assert.assertEquals(data, ds3.collectAsList());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testJavaBeanEncoder2() {
|
||||
// This is a regression test of SPARK-12404
|
||||
OuterScopes.addOuterScope(this);
|
||||
SimpleJavaBean2 obj = new SimpleJavaBean2();
|
||||
obj.setA(new Timestamp(0));
|
||||
obj.setB(new Date(0));
|
||||
obj.setC(java.math.BigDecimal.valueOf(1));
|
||||
Dataset<SimpleJavaBean2> ds =
|
||||
context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class));
|
||||
ds.collect();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.sql
|
||||
|
||||
import java.io.{ObjectInput, ObjectOutput, Externalizable}
|
||||
import java.sql.{Date, Timestamp}
|
||||
|
||||
import scala.language.postfixOps
|
||||
|
||||
|
@ -42,6 +43,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
|
|||
1, 1, 1)
|
||||
}
|
||||
|
||||
|
||||
test("SPARK-12404: Datatype Helper Serializablity") {
|
||||
val ds = sparkContext.parallelize((
|
||||
new Timestamp(0),
|
||||
new Date(0),
|
||||
java.math.BigDecimal.valueOf(1),
|
||||
scala.math.BigDecimal(1)) :: Nil).toDS()
|
||||
|
||||
ds.collect()
|
||||
}
|
||||
|
||||
test("collect, first, and take should use encoders for serialization") {
|
||||
val item = NonSerializableCaseClass("abcd")
|
||||
val ds = Seq(item).toDS()
|
||||
|
|
Loading…
Reference in a new issue