[SPARK-15985][SQL] Eliminate redundant cast from an array without null or a map without null

## What changes were proposed in this pull request?

This PR eliminates redundant cast from an `ArrayType` with `containsNull = false` or a `MapType` with `containsNull = false`.

For example, in `ArrayType` case, current implementation leaves a cast `cast(value#63 as array<double>).toDoubleArray`. However, we can eliminate `cast(value#63 as array<double>)` if we know `value#63` does not include `null`. This PR apply this elimination for `ArrayType` and `MapType` in `SimplifyCasts` at a plan optimization phase.

In summary, we got 1.2-1.3x performance improvements over the code before applying this PR.
Here are performance results of benchmark programs:
```
  test("Read array in Dataset") {
    import sparkSession.implicits._

    val iters = 5
    val n = 1024 * 1024
    val rows = 15

    val benchmark = new Benchmark("Read primnitive array", n)

    val rand = new Random(511)
    val intDS = sparkSession.sparkContext.parallelize(0 until rows, 1)
      .map(i => Array.tabulate(n)(i => i)).toDS()
    intDS.count() // force to create ds
    val lastElement = n - 1
    val randElement = rand.nextInt(lastElement)

    benchmark.addCase(s"Read int array in Dataset", numIters = iters)(iter => {
      val idx0 = randElement
      val idx1 = lastElement
      intDS.map(a => a(0) + a(idx0) + a(idx1)).collect
    })

    val doubleDS = sparkSession.sparkContext.parallelize(0 until rows, 1)
      .map(i => Array.tabulate(n)(i => i.toDouble)).toDS()
    doubleDS.count() // force to create ds

    benchmark.addCase(s"Read double array in Dataset", numIters = iters)(iter => {
      val idx0 = randElement
      val idx1 = lastElement
      doubleDS.map(a => a(0) + a(idx0) + a(idx1)).collect
    })

    benchmark.run()
  }

Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.4
Intel(R) Core(TM) i5-5257U CPU  2.70GHz

without this PR
Read primnitive array:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Read int array in Dataset                      525 /  690          2.0         500.9       1.0X
Read double array in Dataset                   947 / 1209          1.1         902.7       0.6X

with this PR
Read primnitive array:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Read int array in Dataset                      400 /  492          2.6         381.5       1.0X
Read double array in Dataset                   788 /  870          1.3         751.4       0.5X
```

An example program that originally caused this performance issue.
```
val ds = Seq(Array(1.0, 2.0, 3.0), Array(4.0, 5.0, 6.0)).toDS()
val ds2 = ds.map(p => {
     var s = 0.0
     for (i <- 0 to 2) { s += p(i) }
     s
   })
ds2.show
ds2.explain(true)
```

Plans before this PR
```
== Parsed Logical Plan ==
'SerializeFromObject [input[0, double, true] AS value#68]
+- 'MapElements <function1>, obj#67: double
   +- 'DeserializeToObject unresolveddeserializer(upcast(getcolumnbyordinal(0, ArrayType(DoubleType,false)), ArrayType(DoubleType,false), - root class: "scala.Array").toDoubleArray), obj#66: [D
      +- LocalRelation [value#63]

== Analyzed Logical Plan ==
value: double
SerializeFromObject [input[0, double, true] AS value#68]
+- MapElements <function1>, obj#67: double
   +- DeserializeToObject cast(value#63 as array<double>).toDoubleArray, obj#66: [D
      +- LocalRelation [value#63]

== Optimized Logical Plan ==
SerializeFromObject [input[0, double, true] AS value#68]
+- MapElements <function1>, obj#67: double
   +- DeserializeToObject cast(value#63 as array<double>).toDoubleArray, obj#66: [D
      +- LocalRelation [value#63]

== Physical Plan ==
*SerializeFromObject [input[0, double, true] AS value#68]
+- *MapElements <function1>, obj#67: double
   +- *DeserializeToObject cast(value#63 as array<double>).toDoubleArray, obj#66: [D
      +- LocalTableScan [value#63]
```

Plans after this PR
```
== Parsed Logical Plan ==
'SerializeFromObject [input[0, double, true] AS value#6]
+- 'MapElements <function1>, obj#5: double
   +- 'DeserializeToObject unresolveddeserializer(upcast(getcolumnbyordinal(0, ArrayType(DoubleType,false)), ArrayType(DoubleType,false), - root class: "scala.Array").toDoubleArray), obj#4: [D
      +- LocalRelation [value#1]

== Analyzed Logical Plan ==
value: double
SerializeFromObject [input[0, double, true] AS value#6]
+- MapElements <function1>, obj#5: double
   +- DeserializeToObject cast(value#1 as array<double>).toDoubleArray, obj#4: [D
      +- LocalRelation [value#1]

== Optimized Logical Plan ==
SerializeFromObject [input[0, double, true] AS value#6]
+- MapElements <function1>, obj#5: double
   +- DeserializeToObject value#1.toDoubleArray, obj#4: [D
      +- LocalRelation [value#1]

== Physical Plan ==
*SerializeFromObject [input[0, double, true] AS value#6]
+- *MapElements <function1>, obj#5: double
   +- *DeserializeToObject value#1.toDoubleArray, obj#4: [D
      +- LocalTableScan [value#1]
```

## How was this patch tested?

Tested by new test cases in `SimplifyCastsSuite`

Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>

Closes #13704 from kiszk/SPARK-15985.
This commit is contained in:
Kazuaki Ishizaki 2016-08-31 12:40:53 +08:00 committed by Wenchen Fan
parent 231f973295
commit d92cd227cf
3 changed files with 76 additions and 0 deletions

View file

@ -242,6 +242,9 @@ package object dsl {
def array(dataType: DataType): AttributeReference =
AttributeReference(s, ArrayType(dataType), nullable = true)()
def array(arrayType: ArrayType): AttributeReference =
AttributeReference(s, arrayType)()
/** Creates a new AttributeReference of type map */
def map(keyType: DataType, valueType: DataType): AttributeReference =
map(MapType(keyType, valueType))

View file

@ -475,6 +475,12 @@ case class OptimizeCodegen(conf: CatalystConf) extends Rule[LogicalPlan] {
object SimplifyCasts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case Cast(e, dataType) if e.dataType == dataType => e
case c @ Cast(e, dataType) => (e.dataType, dataType) match {
case (ArrayType(from, false), ArrayType(to, true)) if from == to => e
case (MapType(fromKey, fromValue, false), MapType(toKey, toValue, true))
if fromKey == toKey && fromValue == toValue => e
case _ => c
}
}
}

View file

@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types._
class SimplifyCastsSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("SimplifyCasts", FixedPoint(50), SimplifyCasts) :: Nil
}
test("non-nullable element array to nullable element array cast") {
val input = LocalRelation('a.array(ArrayType(IntegerType, false)))
val plan = input.select('a.cast(ArrayType(IntegerType, true)).as("casted")).analyze
val optimized = Optimize.execute(plan)
val expected = input.select('a.as("casted")).analyze
comparePlans(optimized, expected)
}
test("nullable element to non-nullable element array cast") {
val input = LocalRelation('a.array(ArrayType(IntegerType, true)))
val plan = input.select('a.cast(ArrayType(IntegerType, false)).as("casted")).analyze
val optimized = Optimize.execute(plan)
comparePlans(optimized, plan)
}
test("non-nullable value map to nullable value map cast") {
val input = LocalRelation('m.map(MapType(StringType, StringType, false)))
val plan = input.select('m.cast(MapType(StringType, StringType, true))
.as("casted")).analyze
val optimized = Optimize.execute(plan)
val expected = input.select('m.as("casted")).analyze
comparePlans(optimized, expected)
}
test("nullable value map to non-nullable value map cast") {
val input = LocalRelation('m.map(MapType(StringType, StringType, true)))
val plan = input.select('m.cast(MapType(StringType, StringType, false))
.as("casted")).analyze
val optimized = Optimize.execute(plan)
comparePlans(optimized, plan)
}
}