[SPARK-16674][SQL] Avoid per-record type dispatch in JDBC when reading

## What changes were proposed in this pull request?

Currently, `JDBCRDD.compute` does type dispatch for every row in order to read each value.
This per-record dispatch is unnecessary, because the schema is already kept in `JDBCRDD`.

So, the appropriate converters can be created once, according to the schema, and then applied to each row.
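
For illustration, here is a minimal, self-contained sketch of the pattern, using hypothetical `FieldType`/`Converter` names rather than the actual Spark types: the `match` on the data type runs once per column while the converter array is built, and the per-record loop is reduced to plain array indexing.

```scala
object ConverterSketch {
  sealed trait FieldType
  case object IntField extends FieldType
  case object StringField extends FieldType

  type Converter = String => Any

  // The type dispatch happens here, once per column, not once per record.
  def makeConverter(ft: FieldType): Converter = ft match {
    case IntField    => raw => raw.toInt
    case StringField => raw => raw.trim
  }

  def main(args: Array[String]): Unit = {
    val schema = Array[FieldType](IntField, StringField)
    val converters = schema.map(makeConverter) // built once from the schema

    val rows = Seq(Array("1", " a "), Array("2", " b "))
    for (row <- rows) {
      // Hot loop: plain array indexing, no pattern match per record.
      println(row.indices.map(i => converters(i)(row(i))).mkString(", "))
    }
  }
}
```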

## How was this patch tested?

Existing tests should cover this.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #14313 from HyukjinKwon/SPARK-16674.
hyukjinkwon 2016-07-25 19:57:47 +08:00 committed by Wenchen Fan
parent 68b4020d0c
commit 7ffd99ec5f

@@ -28,7 +28,7 @@ import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
@@ -322,43 +322,134 @@ private[sql] class JDBCRDD(
}
}
// Each JDBC-to-Catalyst conversion corresponds to a tag defined here so that
// we don't have to potentially poke around in the Metadata once for every
// row.
// Is there a better way to do this? I'd rather be using a type that
// contains only the tags I define.
abstract class JDBCConversion
case object BooleanConversion extends JDBCConversion
case object DateConversion extends JDBCConversion
case class DecimalConversion(precision: Int, scale: Int) extends JDBCConversion
case object DoubleConversion extends JDBCConversion
case object FloatConversion extends JDBCConversion
case object IntegerConversion extends JDBCConversion
case object LongConversion extends JDBCConversion
case object BinaryLongConversion extends JDBCConversion
case object StringConversion extends JDBCConversion
case object TimestampConversion extends JDBCConversion
case object BinaryConversion extends JDBCConversion
case class ArrayConversion(elementConversion: JDBCConversion) extends JDBCConversion
// A `JDBCValueSetter` is responsible for converting and setting a value from `ResultSet`
// into a field of `MutableRow`. The last `Int` argument is the 0-based index of the
// field in the row; the value is read from the `ResultSet` at index + 1, since JDBC
// columns are 1-based.
private type JDBCValueSetter = (ResultSet, MutableRow, Int) => Unit
/**
* Maps a StructType to a type tag list.
* Creates `JDBCValueSetter`s according to [[StructType]], each of which sets
* a value from `ResultSet` into the corresponding field of [[MutableRow]].
*/
def getConversions(schema: StructType): Array[JDBCConversion] =
schema.fields.map(sf => getConversions(sf.dataType, sf.metadata))
def makeSetters(schema: StructType): Array[JDBCValueSetter] =
schema.fields.map(sf => makeSetter(sf.dataType, sf.metadata))
private def makeSetter(dt: DataType, metadata: Metadata): JDBCValueSetter = dt match {
case BooleanType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setBoolean(pos, rs.getBoolean(pos + 1))
case DateType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
// DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
val dateVal = rs.getDate(pos + 1)
if (dateVal != null) {
row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal))
} else {
row.update(pos, null)
}
// When connecting to Oracle via JDBC, the precision and scale of the BigDecimal
// returned by ResultSet.getBigDecimal do not always match the table schema reported
// by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale. For example,
// inserting 19999 into a NUMBER(12, 2) column produces a BigDecimal with scale 0,
// while the DataFrame schema has the correct type DecimalType(12, 2). Saving that
// DataFrame to Parquet and reading it back then yields the wrong value 199.99
// (the unscaled value 19999 reinterpreted with scale 2).
// So the precision and scale of the Decimal must be set from the JDBC metadata.
case DecimalType.Fixed(p, s) =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
val decimal =
nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d, p, s))
row.update(pos, decimal)
case DoubleType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setDouble(pos, rs.getDouble(pos + 1))
case FloatType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setFloat(pos, rs.getFloat(pos + 1))
case IntegerType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setInt(pos, rs.getInt(pos + 1))
case LongType if metadata.contains("binarylong") =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
val bytes = rs.getBytes(pos + 1)
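// Reassemble the long from the raw bytes, most significant byte first.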
var ans = 0L
var j = 0
while (j < bytes.size) {
ans = 256 * ans + (255 & bytes(j))
j = j + 1
}
row.setLong(pos, ans)
case LongType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setLong(pos, rs.getLong(pos + 1))
case StringType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
// TODO(davies): use getBytes for better performance, if the encoding is UTF-8
row.update(pos, UTF8String.fromString(rs.getString(pos + 1)))
case TimestampType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
val t = rs.getTimestamp(pos + 1)
if (t != null) {
row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t))
} else {
row.update(pos, null)
}
case BinaryType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.update(pos, rs.getBytes(pos + 1))
case ArrayType(et, _) =>
val elementConversion = et match {
case TimestampType =>
(array: Object) =>
array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
}
case StringType =>
(array: Object) =>
array.asInstanceOf[Array[java.lang.String]]
.map(UTF8String.fromString)
case DateType =>
(array: Object) =>
array.asInstanceOf[Array[java.sql.Date]].map { date =>
nullSafeConvert(date, DateTimeUtils.fromJavaDate)
}
case dt: DecimalType =>
(array: Object) =>
array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
nullSafeConvert[java.math.BigDecimal](
decimal, d => Decimal(d, dt.precision, dt.scale))
}
case LongType if metadata.contains("binarylong") =>
throw new IllegalArgumentException(s"Unsupported array element " +
s"type ${dt.simpleString} based on binary")
case ArrayType(_, _) =>
throw new IllegalArgumentException("Nested arrays unsupported")
case _ => (array: Object) => array.asInstanceOf[Array[Any]]
}
(rs: ResultSet, row: MutableRow, pos: Int) =>
val array = nullSafeConvert[Object](
rs.getArray(pos + 1).getArray,
array => new GenericArrayData(elementConversion.apply(array)))
row.update(pos, array)
private def getConversions(dt: DataType, metadata: Metadata): JDBCConversion = dt match {
case BooleanType => BooleanConversion
case DateType => DateConversion
case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
case DoubleType => DoubleConversion
case FloatType => FloatConversion
case IntegerType => IntegerConversion
case LongType => if (metadata.contains("binarylong")) BinaryLongConversion else LongConversion
case StringType => StringConversion
case TimestampType => TimestampConversion
case BinaryType => BinaryConversion
case ArrayType(et, _) => ArrayConversion(getConversions(et, metadata))
case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
}
@@ -398,93 +489,15 @@ private[sql] class JDBCRDD(
stmt.setFetchSize(fetchSize)
val rs = stmt.executeQuery()
val conversions = getConversions(schema)
val setters: Array[JDBCValueSetter] = makeSetters(schema)
val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))
def getNext(): InternalRow = {
if (rs.next()) {
inputMetrics.incRecordsRead(1)
var i = 0
while (i < conversions.length) {
val pos = i + 1
conversions(i) match {
case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos))
case DateConversion =>
// DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
val dateVal = rs.getDate(pos)
if (dateVal != null) {
mutableRow.setInt(i, DateTimeUtils.fromJavaDate(dateVal))
} else {
mutableRow.update(i, null)
}
// When connecting to Oracle via JDBC, the precision and scale of the BigDecimal
// returned by ResultSet.getBigDecimal do not always match the table schema reported
// by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale. For example,
// inserting 19999 into a NUMBER(12, 2) column produces a BigDecimal with scale 0,
// while the DataFrame schema has the correct type DecimalType(12, 2). Saving that
// DataFrame to Parquet and reading it back then yields the wrong value 199.99
// (the unscaled value 19999 reinterpreted with scale 2).
// So the precision and scale of the Decimal must be set from the JDBC metadata.
case DecimalConversion(p, s) =>
val decimalVal = rs.getBigDecimal(pos)
if (decimalVal == null) {
mutableRow.update(i, null)
} else {
mutableRow.update(i, Decimal(decimalVal, p, s))
}
case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos))
case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos))
case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos))
case LongConversion => mutableRow.setLong(i, rs.getLong(pos))
// TODO(davies): use getBytes for better performance, if the encoding is UTF-8
case StringConversion => mutableRow.update(i, UTF8String.fromString(rs.getString(pos)))
case TimestampConversion =>
val t = rs.getTimestamp(pos)
if (t != null) {
mutableRow.setLong(i, DateTimeUtils.fromJavaTimestamp(t))
} else {
mutableRow.update(i, null)
}
case BinaryConversion => mutableRow.update(i, rs.getBytes(pos))
case BinaryLongConversion =>
val bytes = rs.getBytes(pos)
var ans = 0L
var j = 0
while (j < bytes.size) {
ans = 256 * ans + (255 & bytes(j))
j = j + 1
}
mutableRow.setLong(i, ans)
case ArrayConversion(elementConversion) =>
val array = rs.getArray(pos).getArray
if (array != null) {
val data = elementConversion match {
case TimestampConversion =>
array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
}
case StringConversion =>
array.asInstanceOf[Array[java.lang.String]]
.map(UTF8String.fromString)
case DateConversion =>
array.asInstanceOf[Array[java.sql.Date]].map { date =>
nullSafeConvert(date, DateTimeUtils.fromJavaDate)
}
case DecimalConversion(p, s) =>
array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
nullSafeConvert[java.math.BigDecimal](decimal, d => Decimal(d, p, s))
}
case BinaryLongConversion =>
throw new IllegalArgumentException(s"Unsupported array element conversion $i")
case _: ArrayConversion =>
throw new IllegalArgumentException("Nested arrays unsupported")
case _ => array.asInstanceOf[Array[Any]]
}
mutableRow.update(i, new GenericArrayData(data))
} else {
mutableRow.update(i, null)
}
}
while (i < setters.length) {
setters(i).apply(rs, mutableRow, i)
if (rs.wasNull) mutableRow.setNullAt(i)
i = i + 1
}
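
A closing note on the `rs.wasNull` check above: primitive JDBC getters such as `getInt` and `getLong` return 0 for SQL NULL, so the only way to distinguish a genuine 0 from NULL is to ask the `ResultSet` after the read. A minimal sketch of that contract (`readRow` is a hypothetical helper; `ResultSet.wasNull` is the real `java.sql` API):

```scala
import java.sql.ResultSet

import org.apache.spark.sql.catalyst.expressions.MutableRow

object WasNullSketch {
  type JDBCValueSetter = (ResultSet, MutableRow, Int) => Unit

  def readRow(rs: ResultSet, setters: Array[JDBCValueSetter], row: MutableRow): Unit = {
    var i = 0
    while (i < setters.length) {
      // `i` is the 0-based field index in the row; the setter reads JDBC column `i + 1`.
      setters(i).apply(rs, row, i)
      // getInt/getLong/getBoolean return 0/false for SQL NULL; wasNull reports
      // whether the last column read was actually NULL.
      if (rs.wasNull) row.setNullAt(i)
      i = i + 1
    }
  }
}
```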