[SPARK-16674][SQL] Avoid per-record type dispatch in JDBC when reading
## What changes were proposed in this pull request? Currently, `JDBCRDD.compute` is doing type dispatch for each row to read appropriate values. It might not have to be done like this because the schema is already kept in `JDBCRDD`. So, appropriate converters can be created first according to the schema, and then applied to each row. ## How was this patch tested? Existing tests should cover this. Author: hyukjinkwon <gurwls223@gmail.com> Closes #14313 from HyukjinKwon/SPARK-16674.
This commit is contained in:
parent
68b4020d0c
commit
7ffd99ec5f
|
@ -28,7 +28,7 @@ import org.apache.spark.{Partition, SparkContext, TaskContext}
|
|||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
|
||||
import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
|
||||
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
|
||||
import org.apache.spark.sql.jdbc.JdbcDialects
|
||||
import org.apache.spark.sql.sources._
|
||||
|
@ -322,43 +322,134 @@ private[sql] class JDBCRDD(
|
|||
}
|
||||
}
|
||||
|
||||
// Each JDBC-to-Catalyst conversion corresponds to a tag defined here so that
// we don't have to potentially poke around in the Metadata once for every
// row.
// `sealed` so the compiler checks exhaustiveness of every match over
// JDBCConversion (all variants live in this file); this also answers the
// old "is there a better way" note — the type now contains only these tags.
sealed abstract class JDBCConversion
case object BooleanConversion extends JDBCConversion
case object DateConversion extends JDBCConversion
// Carries the JDBC-reported precision/scale so decoded Decimals keep the
// schema's type even when the driver returns a BigDecimal with another scale.
final case class DecimalConversion(precision: Int, scale: Int) extends JDBCConversion
case object DoubleConversion extends JDBCConversion
case object FloatConversion extends JDBCConversion
case object IntegerConversion extends JDBCConversion
case object LongConversion extends JDBCConversion
// A long encoded as a byte array (column metadata contains "binarylong").
case object BinaryLongConversion extends JDBCConversion
case object StringConversion extends JDBCConversion
case object TimestampConversion extends JDBCConversion
case object BinaryConversion extends JDBCConversion
// Arrays decode each element with the nested conversion tag.
final case class ArrayConversion(elementConversion: JDBCConversion) extends JDBCConversion
|
||||
// A `JDBCValueSetter` is responsible for converting and setting a value from `ResultSet`
// into a field for `MutableRow`. The last argument `Int` means the index for the
// value to be set in the row and also used for the value to retrieve from `ResultSet`.
private type JDBCValueSetter = (ResultSet, MutableRow, Int) => Unit

/** Maps a StructType to a type tag list, one conversion tag per field. */
def getConversions(schema: StructType): Array[JDBCConversion] =
  schema.fields.map { field => getConversions(field.dataType, field.metadata) }

/**
 * Creates `JDBCValueSetter`s according to [[StructType]], which can set
 * each value from `ResultSet` to each field of [[MutableRow]] correctly.
 */
def makeSetters(schema: StructType): Array[JDBCValueSetter] =
  schema.fields.map { field => makeSetter(field.dataType, field.metadata) }
|
||||
|
||||
/**
 * Creates a `JDBCValueSetter` for `dt`: a function that reads the value at
 * column `pos + 1` of a `ResultSet` (JDBC columns are 1-based), converts it to
 * its Catalyst representation, and stores it into field `pos` of a
 * [[MutableRow]]. Setters are built once per partition so `compute` performs
 * no per-record type dispatch.
 *
 * @param dt       Catalyst type of the field being read
 * @param metadata column metadata; the "binarylong" key marks longs encoded as bytes
 * @throws IllegalArgumentException if `dt` has no JDBC decoding
 */
private def makeSetter(dt: DataType, metadata: Metadata): JDBCValueSetter = dt match {
  case BooleanType =>
    (rs: ResultSet, row: MutableRow, pos: Int) =>
      row.setBoolean(pos, rs.getBoolean(pos + 1))

  case DateType =>
    (rs: ResultSet, row: MutableRow, pos: Int) =>
      // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
      val dateVal = rs.getDate(pos + 1)
      if (dateVal != null) {
        row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal))
      } else {
        row.update(pos, null)
      }

  // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal
  // object returned by ResultSet.getBigDecimal is not correctly matched to the table
  // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale.
  // If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through
  // a BigDecimal object with scale as 0. But the dataframe schema has correct type as
  // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then
  // retrieve it, you will get wrong result 199.99.
  // So it is needed to set precision and scale for Decimal based on JDBC metadata.
  case DecimalType.Fixed(p, s) =>
    (rs: ResultSet, row: MutableRow, pos: Int) =>
      val decimal =
        nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d, p, s))
      row.update(pos, decimal)

  case DoubleType =>
    (rs: ResultSet, row: MutableRow, pos: Int) =>
      row.setDouble(pos, rs.getDouble(pos + 1))

  case FloatType =>
    (rs: ResultSet, row: MutableRow, pos: Int) =>
      row.setFloat(pos, rs.getFloat(pos + 1))

  case IntegerType =>
    (rs: ResultSet, row: MutableRow, pos: Int) =>
      row.setInt(pos, rs.getInt(pos + 1))

  case LongType if metadata.contains("binarylong") =>
    (rs: ResultSet, row: MutableRow, pos: Int) =>
      // Decode a big-endian byte array into a long, one byte at a time.
      val bytes = rs.getBytes(pos + 1)
      var ans = 0L
      var j = 0
      while (j < bytes.size) {
        ans = 256 * ans + (255 & bytes(j))
        j = j + 1
      }
      row.setLong(pos, ans)

  case LongType =>
    (rs: ResultSet, row: MutableRow, pos: Int) =>
      row.setLong(pos, rs.getLong(pos + 1))

  case StringType =>
    (rs: ResultSet, row: MutableRow, pos: Int) =>
      // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
      row.update(pos, UTF8String.fromString(rs.getString(pos + 1)))

  case TimestampType =>
    (rs: ResultSet, row: MutableRow, pos: Int) =>
      // DateTimeUtils.fromJavaTimestamp does not handle null, so check first.
      val t = rs.getTimestamp(pos + 1)
      if (t != null) {
        row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t))
      } else {
        row.update(pos, null)
      }

  case BinaryType =>
    (rs: ResultSet, row: MutableRow, pos: Int) =>
      row.update(pos, rs.getBytes(pos + 1))

  case ArrayType(et, _) =>
    // The element conversion is resolved once here, not per row.
    val elementConversion = et match {
      case TimestampType =>
        (array: Object) =>
          array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
            nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
          }

      case StringType =>
        (array: Object) =>
          array.asInstanceOf[Array[java.lang.String]]
            .map(UTF8String.fromString)

      case DateType =>
        (array: Object) =>
          array.asInstanceOf[Array[java.sql.Date]].map { date =>
            nullSafeConvert(date, DateTimeUtils.fromJavaDate)
          }

      case dt: DecimalType =>
        (array: Object) =>
          array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
            nullSafeConvert[java.math.BigDecimal](
              decimal, d => Decimal(d, dt.precision, dt.scale))
          }

      case LongType if metadata.contains("binarylong") =>
        throw new IllegalArgumentException(s"Unsupported array element " +
          s"type ${dt.simpleString} based on binary")

      case ArrayType(_, _) =>
        throw new IllegalArgumentException("Nested arrays unsupported")

      case _ => (array: Object) => array.asInstanceOf[Array[Any]]
    }

    (rs: ResultSet, row: MutableRow, pos: Int) =>
      val array = nullSafeConvert[Object](
        rs.getArray(pos + 1).getArray,
        array => new GenericArrayData(elementConversion.apply(array)))
      row.update(pos, array)

  // Fall back with a clear error for any type with no JDBC decoding,
  // mirroring getConversions; this also closes the match expression,
  // which was left unterminated in the source as shown.
  case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
}
||||
|
||||
/**
 * Maps a single Catalyst type (plus its column metadata) to the conversion tag
 * used when decoding that column from a `ResultSet`.
 * Throws `IllegalArgumentException` for types with no JDBC decoding.
 */
private def getConversions(dt: DataType, metadata: Metadata): JDBCConversion = dt match {
  case BooleanType => BooleanConversion
  case DateType => DateConversion
  case DecimalType.Fixed(precision, scale) => DecimalConversion(precision, scale)
  case DoubleType => DoubleConversion
  case FloatType => FloatConversion
  case IntegerType => IntegerConversion
  // Longs stored as byte arrays are tagged separately ("binarylong" metadata).
  case LongType if metadata.contains("binarylong") => BinaryLongConversion
  case LongType => LongConversion
  case StringType => StringConversion
  case TimestampType => TimestampConversion
  case BinaryType => BinaryConversion
  case ArrayType(elementType, _) => ArrayConversion(getConversions(elementType, metadata))
  case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
}
|
||||
|
||||
|
@ -398,93 +489,15 @@ private[sql] class JDBCRDD(
|
|||
stmt.setFetchSize(fetchSize)
|
||||
val rs = stmt.executeQuery()
|
||||
|
||||
val conversions = getConversions(schema)
|
||||
val setters: Array[JDBCValueSetter] = makeSetters(schema)
|
||||
val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))
|
||||
|
||||
def getNext(): InternalRow = {
|
||||
if (rs.next()) {
|
||||
inputMetrics.incRecordsRead(1)
|
||||
var i = 0
|
||||
while (i < conversions.length) {
|
||||
val pos = i + 1
|
||||
conversions(i) match {
|
||||
case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos))
|
||||
case DateConversion =>
|
||||
// DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
|
||||
val dateVal = rs.getDate(pos)
|
||||
if (dateVal != null) {
|
||||
mutableRow.setInt(i, DateTimeUtils.fromJavaDate(dateVal))
|
||||
} else {
|
||||
mutableRow.update(i, null)
|
||||
}
|
||||
// When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal
|
||||
// object returned by ResultSet.getBigDecimal is not correctly matched to the table
|
||||
// schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale.
|
||||
// If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through
|
||||
// a BigDecimal object with scale as 0. But the dataframe schema has correct type as
|
||||
// DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then
|
||||
// retrieve it, you will get wrong result 199.99.
|
||||
// So it is needed to set precision and scale for Decimal based on JDBC metadata.
|
||||
case DecimalConversion(p, s) =>
|
||||
val decimalVal = rs.getBigDecimal(pos)
|
||||
if (decimalVal == null) {
|
||||
mutableRow.update(i, null)
|
||||
} else {
|
||||
mutableRow.update(i, Decimal(decimalVal, p, s))
|
||||
}
|
||||
case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos))
|
||||
case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos))
|
||||
case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos))
|
||||
case LongConversion => mutableRow.setLong(i, rs.getLong(pos))
|
||||
// TODO(davies): use getBytes for better performance, if the encoding is UTF-8
|
||||
case StringConversion => mutableRow.update(i, UTF8String.fromString(rs.getString(pos)))
|
||||
case TimestampConversion =>
|
||||
val t = rs.getTimestamp(pos)
|
||||
if (t != null) {
|
||||
mutableRow.setLong(i, DateTimeUtils.fromJavaTimestamp(t))
|
||||
} else {
|
||||
mutableRow.update(i, null)
|
||||
}
|
||||
case BinaryConversion => mutableRow.update(i, rs.getBytes(pos))
|
||||
case BinaryLongConversion =>
|
||||
val bytes = rs.getBytes(pos)
|
||||
var ans = 0L
|
||||
var j = 0
|
||||
while (j < bytes.size) {
|
||||
ans = 256 * ans + (255 & bytes(j))
|
||||
j = j + 1
|
||||
}
|
||||
mutableRow.setLong(i, ans)
|
||||
case ArrayConversion(elementConversion) =>
|
||||
val array = rs.getArray(pos).getArray
|
||||
if (array != null) {
|
||||
val data = elementConversion match {
|
||||
case TimestampConversion =>
|
||||
array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
|
||||
nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
|
||||
}
|
||||
case StringConversion =>
|
||||
array.asInstanceOf[Array[java.lang.String]]
|
||||
.map(UTF8String.fromString)
|
||||
case DateConversion =>
|
||||
array.asInstanceOf[Array[java.sql.Date]].map { date =>
|
||||
nullSafeConvert(date, DateTimeUtils.fromJavaDate)
|
||||
}
|
||||
case DecimalConversion(p, s) =>
|
||||
array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
|
||||
nullSafeConvert[java.math.BigDecimal](decimal, d => Decimal(d, p, s))
|
||||
}
|
||||
case BinaryLongConversion =>
|
||||
throw new IllegalArgumentException(s"Unsupported array element conversion $i")
|
||||
case _: ArrayConversion =>
|
||||
throw new IllegalArgumentException("Nested arrays unsupported")
|
||||
case _ => array.asInstanceOf[Array[Any]]
|
||||
}
|
||||
mutableRow.update(i, new GenericArrayData(data))
|
||||
} else {
|
||||
mutableRow.update(i, null)
|
||||
}
|
||||
}
|
||||
while (i < setters.length) {
|
||||
setters(i).apply(rs, mutableRow, i)
|
||||
if (rs.wasNull) mutableRow.setNullAt(i)
|
||||
i = i + 1
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue