[SPARK-1371][WIP] Compression support for Spark SQL in-memory columnar storage

JIRA issue: [SPARK-1373](https://issues.apache.org/jira/browse/SPARK-1373)

(Although tagged as WIP, this PR is structurally complete. The only pieces left unimplemented are three more compression algorithms, `BooleanBitSet`, `IntDelta`, and `LongDelta`, which can easily be added later in this PR or a separate one.)

This PR adds compression support for Spark SQL's in-memory columnar storage. The main interfaces are:

*   `CompressionScheme`

    Each `CompressionScheme` represents a concrete compression algorithm and consists of an `Encoder` for compression plus a `Decoder` for decompression (see the interface sketch at the end of this item). Algorithms implemented so far:

    * `RunLengthEncoding`
    * `DictionaryEncoding`

    Algorithms still to be implemented:

    * `BooleanBitSet`
    * `IntDelta`
    * `LongDelta`
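
    For reference, the core interfaces look roughly as follows (abridged from the new `compression` package added in this PR; only the signatures relevant to the description above are shown):

    ```scala
    import java.nio.ByteBuffer

    import org.apache.spark.sql.catalyst.types.NativeType
    import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType}

    private[sql] trait Encoder {
      // Called once per appended value to collect size statistics.
      def gatherCompressibilityStats[T <: NativeType](
          value: T#JvmType,
          columnType: ColumnType[T, T#JvmType]) {}

      def compressedSize: Int

      def uncompressedSize: Int

      // Estimated ratio: 1.0 means "no gain expected".
      def compressionRatio: Double = {
        if (uncompressedSize > 0) compressedSize.toDouble / uncompressedSize else 1.0
      }

      def compress[T <: NativeType](
          from: ByteBuffer,
          to: ByteBuffer,
          columnType: ColumnType[T, T#JvmType]): ByteBuffer
    }

    // The decompression side is just an iterator over decoded values.
    private[sql] trait Decoder[T <: NativeType] extends Iterator[T#JvmType]

    private[sql] trait CompressionScheme {
      def typeId: Int

      def supports(columnType: ColumnType[_, _]): Boolean

      def encoder: Encoder

      def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T]
    }
    ```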

*   `CompressibleColumnBuilder`

    A stackable `ColumnBuilder` trait used to build byte buffers for compressible columns. The `CompressionScheme` with the lowest estimated compression ratio is chosen for each column, based on statistics gathered while elements are appended to the `ColumnBuilder`. However, if no `CompressionScheme` is expected to shrink the column below 80% of its raw size, the column is left uncompressed to save CPU time. The selection rule is sketched after the layout diagram below.

    The memory layout of the final byte buffer is shown below:

    ```
     .--------------------------- Column type ID (4 bytes)
     |   .----------------------- Null count N (4 bytes)
     |   |   .------------------- Null positions (4 x N bytes, empty if null count is zero)
     |   |   |     .------------- Compression scheme ID (4 bytes)
     |   |   |     |   .--------- Compressed non-null elements
     V   V   V     V   V
    +---+---+-----+---+---------+
    |   |   | ... |   | ... ... |
    +---+---+-----+---+---------+
     \-----------/ \-----------/
        header         body
    ```
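
    The selection rule itself is tiny. Below is a minimal, self-contained sketch of the idea; `CandidateEncoder` is a hypothetical stand-in for the real `Encoder`s, used only to keep the snippet runnable outside the code base:

    ```scala
    object SchemeSelectionSketch {
      // Stand-in for an Encoder: just the two sizes gathered while values are appended.
      case class CandidateEncoder(name: String, compressedSize: Int, uncompressedSize: Int) {
        def compressionRatio: Double =
          if (uncompressedSize > 0) compressedSize.toDouble / uncompressedSize else 1.0
      }

      // Compression only pays off when the estimated output is below 80% of the raw size.
      def isWorthCompressing(encoder: CandidateEncoder): Boolean = encoder.compressionRatio < 0.8

      // Pick the candidate with the lowest ratio, falling back to "PassThrough" (no compression).
      def choose(candidates: Seq[CandidateEncoder]): String = {
        val best = candidates.minBy(_.compressionRatio)
        if (isWorthCompressing(best)) best.name else "PassThrough"
      }

      def main(args: Array[String]): Unit = {
        val candidates = Seq(
          CandidateEncoder("RunLengthEncoding", compressedSize = 40, uncompressedSize = 400),
          CandidateEncoder("DictionaryEncoding", compressedSize = 350, uncompressedSize = 400))
        println(choose(candidates)) // RunLengthEncoding (ratio 0.1 beats the 0.8 threshold)
      }
    }
    ```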

*   `CompressibleColumnAccessor`

    A stackable `ColumnAccessor` trait used to iterate over a (possibly) compressed data column, as sketched below.
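
    The whole trait is small enough to quote; a lightly commented copy of the version added in this PR:

    ```scala
    import java.nio.ByteBuffer

    import org.apache.spark.sql.catalyst.types.NativeType
    import org.apache.spark.sql.columnar.{ColumnAccessor, NativeColumnAccessor}

    private[sql] trait CompressibleColumnAccessor[T <: NativeType] extends ColumnAccessor {
      this: NativeColumnAccessor[T] =>

      private var decoder: Decoder[T] = _

      abstract override protected def initialize() = {
        super.initialize()
        // The int following the null section identifies the compression scheme; the scheme
        // supplies a Decoder that knows how to read the remaining (compressed) bytes.
        decoder = CompressionScheme(underlyingBuffer.getInt()).decoder(buffer, columnType)
      }

      // Each value is simply pulled from the decoder.
      abstract override def extractSingle(buffer: ByteBuffer): T#JvmType = decoder.next()
    }
    ```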

*   `ColumnStats`

    Used to collect statistical information while loading data into an in-memory columnar table. Optimizations like partition pruning rely on this information.

    Strictly speaking, the `ColumnStats` related code is not part of the compression support. It's included in this PR to validate the row-based API design, which is used to avoid boxing/unboxing costs whenever possible; a minimal sketch of the pattern follows.
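
    Here is a minimal, self-contained sketch of that row-based pattern; `IntRow` and `IntStatsSketch` are hypothetical stand-ins for `Row` and `IntColumnStats`, used only to keep the snippet dependency-free:

    ```scala
    object ColumnStatsSketch {
      // Stand-in for a row of primitive values; getInt returns a primitive, so no boxing occurs.
      final class IntRow(values: Array[Int]) {
        def getInt(ordinal: Int): Int = values(ordinal)
      }

      final class IntStatsSketch {
        private var lower = Int.MaxValue
        private var upper = Int.MinValue

        // Mirrors ColumnStats.gatherStats(row, ordinal): read through the typed getter.
        def gatherStats(row: IntRow, ordinal: Int): Unit = {
          val field = row.getInt(ordinal)
          if (field > upper) upper = field
          if (field < lower) lower = field
        }

        def bounds: (Int, Int) = (lower, upper)
      }

      def main(args: Array[String]): Unit = {
        val stats = new IntStatsSketch
        Seq(3, 7, -2, 5).foreach(v => stats.gatherStats(new IntRow(Array(v)), 0))
        println(stats.bounds) // (-2,7)
      }
    }
    ```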

A major refactoring change since PR #205 is:

* Refactored the getter/setter methods for primitive types, previously duplicated in various places, into the `ColumnType` classes to remove the duplicated code (illustrated below).
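
  Concretely, each `ColumnType` now carries the typed accessors directly. For example, abridged from the updated `INT` column type in this PR (`append`/`extract` omitted):

  ```scala
  private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) {
    // The typed setter/getter live on the ColumnType itself, so the per-type
    // doExtractTo/doAppendFrom overrides in accessors and builders go away.
    override def setField(row: MutableRow, ordinal: Int, value: Int) {
      row.setInt(ordinal, value)
    }

    override def getField(row: Row, ordinal: Int) = row.getInt(ordinal)
  }
  ```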

Author: Cheng Lian <lian.cs.zju@gmail.com>

Closes #285 from liancheng/memColumnarCompression and squashes the following commits:

ed71bbd [Cheng Lian] Addressed all PR comments by @marmbrus
d3a4fa9 [Cheng Lian] Removed Ordering[T] in ColumnStats for better performance
5034453 [Cheng Lian] Bug fix, more tests, and more refactoring
c298b76 [Cheng Lian] Test suites refactored
2780d6a [Cheng Lian] [WIP] in-memory columnar compression support
211331c [Cheng Lian] WIP: in-memory columnar compression support
85cc59b [Cheng Lian] Refactored ColumnAccessors & ColumnBuilders to remove duplicate code
Cheng Lian 2014-04-02 12:47:22 -07:00 committed by Patrick Wendell
parent 78236334e4
commit 1faa579711
21 changed files with 1643 additions and 407 deletions


@@ -21,7 +21,7 @@ import java.nio.{ByteOrder, ByteBuffer}
import org.apache.spark.sql.catalyst.types.{BinaryType, NativeType, DataType}
import org.apache.spark.sql.catalyst.expressions.MutableRow
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor
/**
* An `Iterator` like trait used to extract values from columnar byte buffer. When a value is
@@ -41,121 +41,66 @@ private[sql] trait ColumnAccessor {
protected def underlyingBuffer: ByteBuffer
}
private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType](buffer: ByteBuffer)
private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType](
protected val buffer: ByteBuffer,
protected val columnType: ColumnType[T, JvmType])
extends ColumnAccessor {
protected def initialize() {}
def columnType: ColumnType[T, JvmType]
def hasNext = buffer.hasRemaining
def extractTo(row: MutableRow, ordinal: Int) {
doExtractTo(row, ordinal)
columnType.setField(row, ordinal, extractSingle(buffer))
}
protected def doExtractTo(row: MutableRow, ordinal: Int)
def extractSingle(buffer: ByteBuffer): JvmType = columnType.extract(buffer)
protected def underlyingBuffer = buffer
}
private[sql] abstract class NativeColumnAccessor[T <: NativeType](
buffer: ByteBuffer,
val columnType: NativeColumnType[T])
extends BasicColumnAccessor[T, T#JvmType](buffer)
override protected val buffer: ByteBuffer,
override protected val columnType: NativeColumnType[T])
extends BasicColumnAccessor(buffer, columnType)
with NullableColumnAccessor
with CompressibleColumnAccessor[T]
private[sql] class BooleanColumnAccessor(buffer: ByteBuffer)
extends NativeColumnAccessor(buffer, BOOLEAN) {
override protected def doExtractTo(row: MutableRow, ordinal: Int) {
row.setBoolean(ordinal, columnType.extract(buffer))
}
}
extends NativeColumnAccessor(buffer, BOOLEAN)
private[sql] class IntColumnAccessor(buffer: ByteBuffer)
extends NativeColumnAccessor(buffer, INT) {
override protected def doExtractTo(row: MutableRow, ordinal: Int) {
row.setInt(ordinal, columnType.extract(buffer))
}
}
extends NativeColumnAccessor(buffer, INT)
private[sql] class ShortColumnAccessor(buffer: ByteBuffer)
extends NativeColumnAccessor(buffer, SHORT) {
override protected def doExtractTo(row: MutableRow, ordinal: Int) {
row.setShort(ordinal, columnType.extract(buffer))
}
}
extends NativeColumnAccessor(buffer, SHORT)
private[sql] class LongColumnAccessor(buffer: ByteBuffer)
extends NativeColumnAccessor(buffer, LONG) {
override protected def doExtractTo(row: MutableRow, ordinal: Int) {
row.setLong(ordinal, columnType.extract(buffer))
}
}
extends NativeColumnAccessor(buffer, LONG)
private[sql] class ByteColumnAccessor(buffer: ByteBuffer)
extends NativeColumnAccessor(buffer, BYTE) {
override protected def doExtractTo(row: MutableRow, ordinal: Int) {
row.setByte(ordinal, columnType.extract(buffer))
}
}
extends NativeColumnAccessor(buffer, BYTE)
private[sql] class DoubleColumnAccessor(buffer: ByteBuffer)
extends NativeColumnAccessor(buffer, DOUBLE) {
override protected def doExtractTo(row: MutableRow, ordinal: Int) {
row.setDouble(ordinal, columnType.extract(buffer))
}
}
extends NativeColumnAccessor(buffer, DOUBLE)
private[sql] class FloatColumnAccessor(buffer: ByteBuffer)
extends NativeColumnAccessor(buffer, FLOAT) {
override protected def doExtractTo(row: MutableRow, ordinal: Int) {
row.setFloat(ordinal, columnType.extract(buffer))
}
}
extends NativeColumnAccessor(buffer, FLOAT)
private[sql] class StringColumnAccessor(buffer: ByteBuffer)
extends NativeColumnAccessor(buffer, STRING) {
override protected def doExtractTo(row: MutableRow, ordinal: Int) {
row.setString(ordinal, columnType.extract(buffer))
}
}
extends NativeColumnAccessor(buffer, STRING)
private[sql] class BinaryColumnAccessor(buffer: ByteBuffer)
extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer)
with NullableColumnAccessor {
def columnType = BINARY
override protected def doExtractTo(row: MutableRow, ordinal: Int) {
row(ordinal) = columnType.extract(buffer)
}
}
extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer, BINARY)
with NullableColumnAccessor
private[sql] class GenericColumnAccessor(buffer: ByteBuffer)
extends BasicColumnAccessor[DataType, Array[Byte]](buffer)
with NullableColumnAccessor {
def columnType = GENERIC
override protected def doExtractTo(row: MutableRow, ordinal: Int) {
val serialized = columnType.extract(buffer)
row(ordinal) = SparkSqlSerializer.deserialize[Any](serialized)
}
}
extends BasicColumnAccessor[DataType, Array[Byte]](buffer, GENERIC)
with NullableColumnAccessor
private[sql] object ColumnAccessor {
def apply(b: ByteBuffer): ColumnAccessor = {
// The first 4 bytes in the buffer indicates the column type.
val buffer = b.duplicate().order(ByteOrder.nativeOrder())
def apply(buffer: ByteBuffer): ColumnAccessor = {
// The first 4 bytes in the buffer indicate the column type.
val columnTypeId = buffer.getInt()
columnTypeId match {


@@ -22,7 +22,7 @@ import java.nio.{ByteBuffer, ByteOrder}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.columnar.ColumnBuilder._
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.columnar.compression.{AllCompressionSchemes, CompressibleColumnBuilder}
private[sql] trait ColumnBuilder {
/**
@@ -30,37 +30,44 @@ private[sql] trait ColumnBuilder {
*/
def initialize(initialSize: Int, columnName: String = "")
/**
* Appends `row(ordinal)` to the column builder.
*/
def appendFrom(row: Row, ordinal: Int)
/**
* Column statistics information
*/
def columnStats: ColumnStats[_, _]
/**
* Returns the final columnar byte buffer.
*/
def build(): ByteBuffer
}
private[sql] abstract class BasicColumnBuilder[T <: DataType, JvmType] extends ColumnBuilder {
private[sql] class BasicColumnBuilder[T <: DataType, JvmType](
val columnStats: ColumnStats[T, JvmType],
val columnType: ColumnType[T, JvmType])
extends ColumnBuilder {
protected var columnName: String = _
private var columnName: String = _
protected var buffer: ByteBuffer = _
def columnType: ColumnType[T, JvmType]
override def initialize(initialSize: Int, columnName: String = "") = {
val size = if (initialSize == 0) DEFAULT_INITIAL_BUFFER_SIZE else initialSize
this.columnName = columnName
buffer = ByteBuffer.allocate(4 + 4 + size * columnType.defaultSize)
// Reserves 4 bytes for column type ID
buffer = ByteBuffer.allocate(4 + size * columnType.defaultSize)
buffer.order(ByteOrder.nativeOrder()).putInt(columnType.typeId)
}
// Have to give a concrete implementation to make mixin possible
override def appendFrom(row: Row, ordinal: Int) {
doAppendFrom(row, ordinal)
}
// Concrete `ColumnBuilder`s can override this method to append values
protected def doAppendFrom(row: Row, ordinal: Int)
// Helper method to append primitive values (to avoid boxing cost)
protected def appendValue(v: JvmType) {
buffer = ensureFreeSpace(buffer, columnType.actualSize(v))
columnType.append(v, buffer)
val field = columnType.getField(row, ordinal)
buffer = ensureFreeSpace(buffer, columnType.actualSize(field))
columnType.append(field, buffer)
}
override def build() = {
@@ -69,83 +76,39 @@ private[sql] abstract class BasicColumnBuilder[T <: DataType, JvmType] extends C
}
}
private[sql] abstract class NativeColumnBuilder[T <: NativeType](
val columnType: NativeColumnType[T])
extends BasicColumnBuilder[T, T#JvmType]
private[sql] abstract class ComplexColumnBuilder[T <: DataType, JvmType](
columnType: ColumnType[T, JvmType])
extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType)
with NullableColumnBuilder
private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(BOOLEAN) {
override def doAppendFrom(row: Row, ordinal: Int) {
appendValue(row.getBoolean(ordinal))
}
}
private[sql] abstract class NativeColumnBuilder[T <: NativeType](
override val columnStats: NativeColumnStats[T],
override val columnType: NativeColumnType[T])
extends BasicColumnBuilder[T, T#JvmType](columnStats, columnType)
with NullableColumnBuilder
with AllCompressionSchemes
with CompressibleColumnBuilder[T]
private[sql] class IntColumnBuilder extends NativeColumnBuilder(INT) {
override def doAppendFrom(row: Row, ordinal: Int) {
appendValue(row.getInt(ordinal))
}
}
private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN)
private[sql] class ShortColumnBuilder extends NativeColumnBuilder(SHORT) {
override def doAppendFrom(row: Row, ordinal: Int) {
appendValue(row.getShort(ordinal))
}
}
private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT)
private[sql] class LongColumnBuilder extends NativeColumnBuilder(LONG) {
override def doAppendFrom(row: Row, ordinal: Int) {
appendValue(row.getLong(ordinal))
}
}
private[sql] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT)
private[sql] class ByteColumnBuilder extends NativeColumnBuilder(BYTE) {
override def doAppendFrom(row: Row, ordinal: Int) {
appendValue(row.getByte(ordinal))
}
}
private[sql] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG)
private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(DOUBLE) {
override def doAppendFrom(row: Row, ordinal: Int) {
appendValue(row.getDouble(ordinal))
}
}
private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE)
private[sql] class FloatColumnBuilder extends NativeColumnBuilder(FLOAT) {
override def doAppendFrom(row: Row, ordinal: Int) {
appendValue(row.getFloat(ordinal))
}
}
private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE)
private[sql] class StringColumnBuilder extends NativeColumnBuilder(STRING) {
override def doAppendFrom(row: Row, ordinal: Int) {
appendValue(row.getString(ordinal))
}
}
private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT)
private[sql] class BinaryColumnBuilder
extends BasicColumnBuilder[BinaryType.type, Array[Byte]]
with NullableColumnBuilder {
private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING)
def columnType = BINARY
override def doAppendFrom(row: Row, ordinal: Int) {
appendValue(row(ordinal).asInstanceOf[Array[Byte]])
}
}
private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(BINARY)
// TODO (lian) Add support for array, struct and map
private[sql] class GenericColumnBuilder
extends BasicColumnBuilder[DataType, Array[Byte]]
with NullableColumnBuilder {
def columnType = GENERIC
override def doAppendFrom(row: Row, ordinal: Int) {
val serialized = SparkSqlSerializer.serialize(row(ordinal))
buffer = ColumnBuilder.ensureFreeSpace(buffer, columnType.actualSize(serialized))
columnType.append(serialized, buffer)
}
}
private[sql] class GenericColumnBuilder extends ComplexColumnBuilder(GENERIC)
private[sql] object ColumnBuilder {
val DEFAULT_INITIAL_BUFFER_SIZE = 10 * 1024 * 104


@@ -0,0 +1,360 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.columnar
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.types._
private[sql] sealed abstract class ColumnStats[T <: DataType, JvmType] extends Serializable {
/**
* Closed lower bound of this column.
*/
def lowerBound: JvmType
/**
* Closed upper bound of this column.
*/
def upperBound: JvmType
/**
* Gathers statistics information from `row(ordinal)`.
*/
def gatherStats(row: Row, ordinal: Int)
/**
* Returns `true` if `lower <= row(ordinal) <= upper`.
*/
def contains(row: Row, ordinal: Int): Boolean
/**
* Returns `true` if `row(ordinal) < upper` holds.
*/
def isAbove(row: Row, ordinal: Int): Boolean
/**
* Returns `true` if `lower < row(ordinal)` holds.
*/
def isBelow(row: Row, ordinal: Int): Boolean
/**
* Returns `true` if `row(ordinal) <= upper` holds.
*/
def isAtOrAbove(row: Row, ordinal: Int): Boolean
/**
* Returns `true` if `lower <= row(ordinal)` holds.
*/
def isAtOrBelow(row: Row, ordinal: Int): Boolean
}
private[sql] sealed abstract class NativeColumnStats[T <: NativeType]
extends ColumnStats[T, T#JvmType] {
type JvmType = T#JvmType
protected var (_lower, _upper) = initialBounds
def initialBounds: (JvmType, JvmType)
protected def columnType: NativeColumnType[T]
override def lowerBound: T#JvmType = _lower
override def upperBound: T#JvmType = _upper
override def isAtOrAbove(row: Row, ordinal: Int) = {
contains(row, ordinal) || isAbove(row, ordinal)
}
override def isAtOrBelow(row: Row, ordinal: Int) = {
contains(row, ordinal) || isBelow(row, ordinal)
}
}
private[sql] class NoopColumnStats[T <: DataType, JvmType] extends ColumnStats[T, JvmType] {
override def isAtOrBelow(row: Row, ordinal: Int) = true
override def isAtOrAbove(row: Row, ordinal: Int) = true
override def isBelow(row: Row, ordinal: Int) = true
override def isAbove(row: Row, ordinal: Int) = true
override def contains(row: Row, ordinal: Int) = true
override def gatherStats(row: Row, ordinal: Int) {}
override def upperBound = null.asInstanceOf[JvmType]
override def lowerBound = null.asInstanceOf[JvmType]
}
private[sql] abstract class BasicColumnStats[T <: NativeType](
protected val columnType: NativeColumnType[T])
extends NativeColumnStats[T]
private[sql] class BooleanColumnStats extends BasicColumnStats(BOOLEAN) {
override def initialBounds = (true, false)
override def isBelow(row: Row, ordinal: Int) = {
lowerBound < columnType.getField(row, ordinal)
}
override def isAbove(row: Row, ordinal: Int) = {
columnType.getField(row, ordinal) < upperBound
}
override def contains(row: Row, ordinal: Int) = {
val field = columnType.getField(row, ordinal)
lowerBound <= field && field <= upperBound
}
override def gatherStats(row: Row, ordinal: Int) {
val field = columnType.getField(row, ordinal)
if (field > upperBound) _upper = field
if (field < lowerBound) _lower = field
}
}
private[sql] class ByteColumnStats extends BasicColumnStats(BYTE) {
override def initialBounds = (Byte.MaxValue, Byte.MinValue)
override def isBelow(row: Row, ordinal: Int) = {
lowerBound < columnType.getField(row, ordinal)
}
override def isAbove(row: Row, ordinal: Int) = {
columnType.getField(row, ordinal) < upperBound
}
override def contains(row: Row, ordinal: Int) = {
val field = columnType.getField(row, ordinal)
lowerBound <= field && field <= upperBound
}
override def gatherStats(row: Row, ordinal: Int) {
val field = columnType.getField(row, ordinal)
if (field > upperBound) _upper = field
if (field < lowerBound) _lower = field
}
}
private[sql] class ShortColumnStats extends BasicColumnStats(SHORT) {
override def initialBounds = (Short.MaxValue, Short.MinValue)
override def isBelow(row: Row, ordinal: Int) = {
lowerBound < columnType.getField(row, ordinal)
}
override def isAbove(row: Row, ordinal: Int) = {
columnType.getField(row, ordinal) < upperBound
}
override def contains(row: Row, ordinal: Int) = {
val field = columnType.getField(row, ordinal)
lowerBound <= field && field <= upperBound
}
override def gatherStats(row: Row, ordinal: Int) {
val field = columnType.getField(row, ordinal)
if (field > upperBound) _upper = field
if (field < lowerBound) _lower = field
}
}
private[sql] class LongColumnStats extends BasicColumnStats(LONG) {
override def initialBounds = (Long.MaxValue, Long.MinValue)
override def isBelow(row: Row, ordinal: Int) = {
lowerBound < columnType.getField(row, ordinal)
}
override def isAbove(row: Row, ordinal: Int) = {
columnType.getField(row, ordinal) < upperBound
}
override def contains(row: Row, ordinal: Int) = {
val field = columnType.getField(row, ordinal)
lowerBound <= field && field <= upperBound
}
override def gatherStats(row: Row, ordinal: Int) {
val field = columnType.getField(row, ordinal)
if (field > upperBound) _upper = field
if (field < lowerBound) _lower = field
}
}
private[sql] class DoubleColumnStats extends BasicColumnStats(DOUBLE) {
override def initialBounds = (Double.MaxValue, Double.MinValue)
override def isBelow(row: Row, ordinal: Int) = {
lowerBound < columnType.getField(row, ordinal)
}
override def isAbove(row: Row, ordinal: Int) = {
columnType.getField(row, ordinal) < upperBound
}
override def contains(row: Row, ordinal: Int) = {
val field = columnType.getField(row, ordinal)
lowerBound <= field && field <= upperBound
}
override def gatherStats(row: Row, ordinal: Int) {
val field = columnType.getField(row, ordinal)
if (field > upperBound) _upper = field
if (field < lowerBound) _lower = field
}
}
private[sql] class FloatColumnStats extends BasicColumnStats(FLOAT) {
override def initialBounds = (Float.MaxValue, Float.MinValue)
override def isBelow(row: Row, ordinal: Int) = {
lowerBound < columnType.getField(row, ordinal)
}
override def isAbove(row: Row, ordinal: Int) = {
columnType.getField(row, ordinal) < upperBound
}
override def contains(row: Row, ordinal: Int) = {
val field = columnType.getField(row, ordinal)
lowerBound <= field && field <= upperBound
}
override def gatherStats(row: Row, ordinal: Int) {
val field = columnType.getField(row, ordinal)
if (field > upperBound) _upper = field
if (field < lowerBound) _lower = field
}
}
private[sql] object IntColumnStats {
val UNINITIALIZED = 0
val INITIALIZED = 1
val ASCENDING = 2
val DESCENDING = 3
val UNORDERED = 4
}
/**
* Statistical information for `Int` columns. More information is collected since `Int` is
* frequently used. Extra information include:
*
* - Ordering state (ascending/descending/unordered), may be used to decide whether binary search
* is applicable when searching elements.
* - Maximum delta between adjacent elements, may be used to guide the `IntDelta` compression
* scheme.
*
* (This two kinds of information are not used anywhere yet and might be removed later.)
*/
private[sql] class IntColumnStats extends BasicColumnStats(INT) {
import IntColumnStats._
private var orderedState = UNINITIALIZED
private var lastValue: Int = _
private var _maxDelta: Int = _
def isAscending = orderedState != DESCENDING && orderedState != UNORDERED
def isDescending = orderedState != ASCENDING && orderedState != UNORDERED
def isOrdered = isAscending || isDescending
def maxDelta = _maxDelta
override def initialBounds = (Int.MaxValue, Int.MinValue)
override def isBelow(row: Row, ordinal: Int) = {
lowerBound < columnType.getField(row, ordinal)
}
override def isAbove(row: Row, ordinal: Int) = {
columnType.getField(row, ordinal) < upperBound
}
override def contains(row: Row, ordinal: Int) = {
val field = columnType.getField(row, ordinal)
lowerBound <= field && field <= upperBound
}
override def gatherStats(row: Row, ordinal: Int) {
val field = columnType.getField(row, ordinal)
if (field > upperBound) _upper = field
if (field < lowerBound) _lower = field
orderedState = orderedState match {
case UNINITIALIZED =>
lastValue = field
INITIALIZED
case INITIALIZED =>
// If all the integers in the column are the same, ordered state is set to Ascending.
// TODO (lian) Confirm whether this is the standard behaviour.
val nextState = if (field >= lastValue) ASCENDING else DESCENDING
_maxDelta = math.abs(field - lastValue)
lastValue = field
nextState
case ASCENDING if field < lastValue =>
UNORDERED
case DESCENDING if field > lastValue =>
UNORDERED
case state @ (ASCENDING | DESCENDING) =>
_maxDelta = _maxDelta.max(field - lastValue)
lastValue = field
state
case _ =>
orderedState
}
}
}
private[sql] class StringColumnStats extends BasicColumnStats(STRING) {
override def initialBounds = (null, null)
override def gatherStats(row: Row, ordinal: Int) {
val field = columnType.getField(row, ordinal)
if ((upperBound eq null) || field.compareTo(upperBound) > 0) _upper = field
if ((lowerBound eq null) || field.compareTo(lowerBound) < 0) _lower = field
}
override def contains(row: Row, ordinal: Int) = {
!(upperBound eq null) && {
val field = columnType.getField(row, ordinal)
lowerBound.compareTo(field) <= 0 && field.compareTo(upperBound) <= 0
}
}
override def isAbove(row: Row, ordinal: Int) = {
!(upperBound eq null) && {
val field = columnType.getField(row, ordinal)
field.compareTo(upperBound) < 0
}
}
override def isBelow(row: Row, ordinal: Int) = {
!(lowerBound eq null) && {
val field = columnType.getField(row, ordinal)
lowerBound.compareTo(field) < 0
}
}
}


@@ -19,7 +19,12 @@ package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.MutableRow
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.execution.SparkSqlSerializer
/**
* An abstract class that represents type of a column. Used to append/extract Java objects into/from
@@ -50,10 +55,24 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType](
*/
def actualSize(v: JvmType): Int = defaultSize
/**
* Returns `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs
* whenever possible.
*/
def getField(row: Row, ordinal: Int): JvmType
/**
* Sets `row(ordinal)` to `field`. Subclasses should override this method to avoid boxing/unboxing
* costs whenever possible.
*/
def setField(row: MutableRow, ordinal: Int, value: JvmType)
/**
* Creates a duplicated copy of the value.
*/
def clone(v: JvmType): JvmType = v
override def toString = getClass.getSimpleName.stripSuffix("$")
}
private[sql] abstract class NativeColumnType[T <: NativeType](
@@ -65,7 +84,7 @@ private[sql] abstract class NativeColumnType[T <: NativeType](
/**
* Scala TypeTag. Can be used to create primitive arrays and hash tables.
*/
def scalaTag = dataType.tag
def scalaTag: TypeTag[dataType.JvmType] = dataType.tag
}
private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) {
@@ -76,6 +95,12 @@ private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) {
def extract(buffer: ByteBuffer) = {
buffer.getInt()
}
override def setField(row: MutableRow, ordinal: Int, value: Int) {
row.setInt(ordinal, value)
}
override def getField(row: Row, ordinal: Int) = row.getInt(ordinal)
}
private[sql] object LONG extends NativeColumnType(LongType, 1, 8) {
@@ -86,6 +111,12 @@ private[sql] object LONG extends NativeColumnType(LongType, 1, 8) {
override def extract(buffer: ByteBuffer) = {
buffer.getLong()
}
override def setField(row: MutableRow, ordinal: Int, value: Long) {
row.setLong(ordinal, value)
}
override def getField(row: Row, ordinal: Int) = row.getLong(ordinal)
}
private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) {
@@ -96,6 +127,12 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) {
override def extract(buffer: ByteBuffer) = {
buffer.getFloat()
}
override def setField(row: MutableRow, ordinal: Int, value: Float) {
row.setFloat(ordinal, value)
}
override def getField(row: Row, ordinal: Int) = row.getFloat(ordinal)
}
private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) {
@@ -106,6 +143,12 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) {
override def extract(buffer: ByteBuffer) = {
buffer.getDouble()
}
override def setField(row: MutableRow, ordinal: Int, value: Double) {
row.setDouble(ordinal, value)
}
override def getField(row: Row, ordinal: Int) = row.getDouble(ordinal)
}
private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) {
@@ -116,6 +159,12 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) {
override def extract(buffer: ByteBuffer) = {
if (buffer.get() == 1) true else false
}
override def setField(row: MutableRow, ordinal: Int, value: Boolean) {
row.setBoolean(ordinal, value)
}
override def getField(row: Row, ordinal: Int) = row.getBoolean(ordinal)
}
private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) {
@@ -126,6 +175,12 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) {
override def extract(buffer: ByteBuffer) = {
buffer.get()
}
override def setField(row: MutableRow, ordinal: Int, value: Byte) {
row.setByte(ordinal, value)
}
override def getField(row: Row, ordinal: Int) = row.getByte(ordinal)
}
private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) {
@@ -136,6 +191,12 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) {
override def extract(buffer: ByteBuffer) = {
buffer.getShort()
}
override def setField(row: MutableRow, ordinal: Int, value: Short) {
row.setShort(ordinal, value)
}
override def getField(row: Row, ordinal: Int) = row.getShort(ordinal)
}
private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
@@ -152,6 +213,12 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
buffer.get(stringBytes, 0, length)
new String(stringBytes)
}
override def setField(row: MutableRow, ordinal: Int, value: String) {
row.setString(ordinal, value)
}
override def getField(row: Row, ordinal: Int) = row.getString(ordinal)
}
private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
@@ -173,15 +240,27 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
}
}
private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](8, 16)
private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](8, 16) {
override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) {
row(ordinal) = value
}
override def getField(row: Row, ordinal: Int) = row(ordinal).asInstanceOf[Array[Byte]]
}
// Used to process generic objects (all types other than those listed above). Objects should be
// serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized
// byte array.
private[sql] object GENERIC extends ByteArrayColumnType[DataType](9, 16)
private[sql] object GENERIC extends ByteArrayColumnType[DataType](9, 16) {
override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) {
row(ordinal) = SparkSqlSerializer.deserialize[Any](value)
}
override def getField(row: Row, ordinal: Int) = SparkSqlSerializer.serialize(row(ordinal))
}
private[sql] object ColumnType {
implicit def dataTypeToColumnType(dataType: DataType): ColumnType[_, _] = {
def apply(dataType: DataType): ColumnType[_, _] = {
dataType match {
case IntegerType => INT
case LongType => LONG


@@ -21,9 +21,6 @@ import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Attribute}
import org.apache.spark.sql.execution.{SparkPlan, LeafNode}
import org.apache.spark.sql.Row
/* Implicit conversions */
import org.apache.spark.sql.columnar.ColumnType._
private[sql] case class InMemoryColumnarTableScan(attributes: Seq[Attribute], child: SparkPlan)
extends LeafNode {
@@ -32,8 +29,8 @@ private[sql] case class InMemoryColumnarTableScan(attributes: Seq[Attribute], ch
lazy val cachedColumnBuffers = {
val output = child.output
val cached = child.execute().mapPartitions { iterator =>
val columnBuilders = output.map { a =>
ColumnBuilder(a.dataType.typeId, 0, a.name)
val columnBuilders = output.map { attribute =>
ColumnBuilder(ColumnType(attribute.dataType).typeId, 0, attribute.name)
}.toArray
var row: Row = null


@@ -29,7 +29,7 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor {
private var nextNullIndex: Int = _
private var pos: Int = 0
abstract override def initialize() {
abstract override protected def initialize() {
nullsBuffer = underlyingBuffer.duplicate().order(ByteOrder.nativeOrder())
nullCount = nullsBuffer.getInt()
nextNullIndex = if (nullCount > 0) nullsBuffer.getInt() else -1


@@ -22,10 +22,18 @@ import java.nio.{ByteBuffer, ByteOrder}
import org.apache.spark.sql.Row
/**
* Builds a nullable column. The byte buffer of a nullable column contains:
* - 4 bytes for the null count (number of nulls)
* - positions for each null, in ascending order
* - the non-null data (column data type, compression type, data...)
* A stackable trait used for building byte buffer for a column containing null values. Memory
* layout of the final byte buffer is:
* {{{
* .----------------------- Column type ID (4 bytes)
* | .------------------- Null count N (4 bytes)
* | | .--------------- Null positions (4 x N bytes, empty if null count is zero)
* | | | .--------- Non-null elements
* V V V V
* +---+---+-----+---------+
* | | | ... | ... ... |
* +---+---+-----+---------+
* }}}
*/
private[sql] trait NullableColumnBuilder extends ColumnBuilder {
private var nulls: ByteBuffer = _
@@ -59,19 +67,8 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder {
nulls.limit(nullDataLen)
nulls.rewind()
// Column type ID is moved to the front, follows the null count, then non-null data
//
// +---------+
// | 4 bytes | Column type ID
// +---------+
// | 4 bytes | Null count
// +---------+
// | ... | Null positions (if null count is not zero)
// +---------+
// | ... | Non-null part (without column type ID)
// +---------+
val buffer = ByteBuffer
.allocate(4 + nullDataLen + nonNulls.limit)
.allocate(4 + 4 + nullDataLen + nonNulls.remaining())
.order(ByteOrder.nativeOrder())
.putInt(typeId)
.putInt(nullCount)


@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.columnar.compression
import java.nio.ByteBuffer
import org.apache.spark.sql.catalyst.types.NativeType
import org.apache.spark.sql.columnar.{ColumnAccessor, NativeColumnAccessor}
private[sql] trait CompressibleColumnAccessor[T <: NativeType] extends ColumnAccessor {
this: NativeColumnAccessor[T] =>
private var decoder: Decoder[T] = _
abstract override protected def initialize() = {
super.initialize()
decoder = CompressionScheme(underlyingBuffer.getInt()).decoder(buffer, columnType)
}
abstract override def extractSingle(buffer: ByteBuffer): T#JvmType = decoder.next()
}


@@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.columnar.compression
import java.nio.{ByteBuffer, ByteOrder}
import org.apache.spark.sql.{Logging, Row}
import org.apache.spark.sql.catalyst.types.NativeType
import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder}
/**
* A stackable trait that builds optionally compressed byte buffer for a column. Memory layout of
* the final byte buffer is:
* {{{
* .--------------------------- Column type ID (4 bytes)
* | .----------------------- Null count N (4 bytes)
* | | .------------------- Null positions (4 x N bytes, empty if null count is zero)
* | | | .------------- Compression scheme ID (4 bytes)
* | | | | .--------- Compressed non-null elements
* V V V V V
* +---+---+-----+---+---------+
* | | | ... | | ... ... |
* +---+---+-----+---+---------+
* \-----------/ \-----------/
* header body
* }}}
*/
private[sql] trait CompressibleColumnBuilder[T <: NativeType]
extends ColumnBuilder with Logging {
this: NativeColumnBuilder[T] with WithCompressionSchemes =>
import CompressionScheme._
val compressionEncoders = schemes.filter(_.supports(columnType)).map(_.encoder)
protected def isWorthCompressing(encoder: Encoder) = {
encoder.compressionRatio < 0.8
}
private def gatherCompressibilityStats(row: Row, ordinal: Int) {
val field = columnType.getField(row, ordinal)
var i = 0
while (i < compressionEncoders.length) {
compressionEncoders(i).gatherCompressibilityStats(field, columnType)
i += 1
}
}
abstract override def appendFrom(row: Row, ordinal: Int) {
super.appendFrom(row, ordinal)
gatherCompressibilityStats(row, ordinal)
}
abstract override def build() = {
val rawBuffer = super.build()
val encoder = {
val candidate = compressionEncoders.minBy(_.compressionRatio)
if (isWorthCompressing(candidate)) candidate else PassThrough.encoder
}
val headerSize = columnHeaderSize(rawBuffer)
val compressedSize = if (encoder.compressedSize == 0) {
rawBuffer.limit - headerSize
} else {
encoder.compressedSize
}
// Reserves 4 bytes for compression scheme ID
val compressedBuffer = ByteBuffer
.allocate(headerSize + 4 + compressedSize)
.order(ByteOrder.nativeOrder)
copyColumnHeader(rawBuffer, compressedBuffer)
logger.info(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}")
encoder.compress(rawBuffer, compressedBuffer, columnType)
}
}


@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.columnar.compression
import java.nio.ByteBuffer
import org.apache.spark.sql.catalyst.types.NativeType
import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType}
private[sql] trait Encoder {
def gatherCompressibilityStats[T <: NativeType](
value: T#JvmType,
columnType: ColumnType[T, T#JvmType]) {}
def compressedSize: Int
def uncompressedSize: Int
def compressionRatio: Double = {
if (uncompressedSize > 0) compressedSize.toDouble / uncompressedSize else 1.0
}
def compress[T <: NativeType](
from: ByteBuffer,
to: ByteBuffer,
columnType: ColumnType[T, T#JvmType]): ByteBuffer
}
private[sql] trait Decoder[T <: NativeType] extends Iterator[T#JvmType]
private[sql] trait CompressionScheme {
def typeId: Int
def supports(columnType: ColumnType[_, _]): Boolean
def encoder: Encoder
def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T]
}
private[sql] trait WithCompressionSchemes {
def schemes: Seq[CompressionScheme]
}
private[sql] trait AllCompressionSchemes extends WithCompressionSchemes {
override val schemes: Seq[CompressionScheme] = {
Seq(PassThrough, RunLengthEncoding, DictionaryEncoding)
}
}
private[sql] object CompressionScheme {
def apply(typeId: Int): CompressionScheme = typeId match {
case PassThrough.typeId => PassThrough
case _ => throw new UnsupportedOperationException()
}
def copyColumnHeader(from: ByteBuffer, to: ByteBuffer) {
// Writes column type ID
to.putInt(from.getInt())
// Writes null count
val nullCount = from.getInt()
to.putInt(nullCount)
// Writes null positions
var i = 0
while (i < nullCount) {
to.putInt(from.getInt())
i += 1
}
}
def columnHeaderSize(columnBuffer: ByteBuffer): Int = {
val header = columnBuffer.duplicate()
val nullCount = header.getInt(4)
// Column type ID + null count + null positions
4 + 4 + 4 * nullCount
}
}


@@ -0,0 +1,288 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.columnar.compression
import java.nio.ByteBuffer
import scala.collection.mutable
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.runtimeMirror
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.types.NativeType
import org.apache.spark.sql.columnar._
private[sql] case object PassThrough extends CompressionScheme {
override val typeId = 0
override def supports(columnType: ColumnType[_, _]) = true
override def encoder = new this.Encoder
override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = {
new this.Decoder(buffer, columnType)
}
class Encoder extends compression.Encoder {
override def uncompressedSize = 0
override def compressedSize = 0
override def compress[T <: NativeType](
from: ByteBuffer,
to: ByteBuffer,
columnType: ColumnType[T, T#JvmType]) = {
// Writes compression type ID and copies raw contents
to.putInt(PassThrough.typeId).put(from).rewind()
to
}
}
class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T])
extends compression.Decoder[T] {
override def next() = columnType.extract(buffer)
override def hasNext = buffer.hasRemaining
}
}
private[sql] case object RunLengthEncoding extends CompressionScheme {
override def typeId = 1
override def encoder = new this.Encoder
override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = {
new this.Decoder(buffer, columnType)
}
override def supports(columnType: ColumnType[_, _]) = columnType match {
case INT | LONG | SHORT | BYTE | STRING | BOOLEAN => true
case _ => false
}
class Encoder extends compression.Encoder {
private var _uncompressedSize = 0
private var _compressedSize = 0
// Using `MutableRow` to store the last value to avoid boxing/unboxing cost.
private val lastValue = new GenericMutableRow(1)
private var lastRun = 0
override def uncompressedSize = _uncompressedSize
override def compressedSize = _compressedSize
override def gatherCompressibilityStats[T <: NativeType](
value: T#JvmType,
columnType: ColumnType[T, T#JvmType]) {
val actualSize = columnType.actualSize(value)
_uncompressedSize += actualSize
if (lastValue.isNullAt(0)) {
columnType.setField(lastValue, 0, value)
lastRun = 1
_compressedSize += actualSize + 4
} else {
if (columnType.getField(lastValue, 0) == value) {
lastRun += 1
} else {
_compressedSize += actualSize + 4
columnType.setField(lastValue, 0, value)
lastRun = 1
}
}
}
override def compress[T <: NativeType](
from: ByteBuffer,
to: ByteBuffer,
columnType: ColumnType[T, T#JvmType]) = {
to.putInt(RunLengthEncoding.typeId)
if (from.hasRemaining) {
var currentValue = columnType.extract(from)
var currentRun = 1
while (from.hasRemaining) {
val value = columnType.extract(from)
if (value == currentValue) {
currentRun += 1
} else {
// Writes current run
columnType.append(currentValue, to)
to.putInt(currentRun)
// Resets current run
currentValue = value
currentRun = 1
}
}
// Writes the last run
columnType.append(currentValue, to)
to.putInt(currentRun)
}
to.rewind()
to
}
}
class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T])
extends compression.Decoder[T] {
private var run = 0
private var valueCount = 0
private var currentValue: T#JvmType = _
override def next() = {
if (valueCount == run) {
currentValue = columnType.extract(buffer)
run = buffer.getInt()
valueCount = 1
} else {
valueCount += 1
}
currentValue
}
override def hasNext = buffer.hasRemaining
}
}
private[sql] case object DictionaryEncoding extends CompressionScheme {
override def typeId: Int = 2
// 32K unique values allowed
private val MAX_DICT_SIZE = Short.MaxValue - 1
override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = {
new this.Decoder[T](buffer, columnType)
}
override def encoder = new this.Encoder
override def supports(columnType: ColumnType[_, _]) = columnType match {
case INT | LONG | STRING => true
case _ => false
}
class Encoder extends compression.Encoder{
// Size of the input, uncompressed, in bytes. Note that we only count until the dictionary
// overflows.
private var _uncompressedSize = 0
// If the number of distinct elements is too large, we discard the use of dictionary encoding
// and set the overflow flag to true.
private var overflow = false
// Total number of elements.
private var count = 0
// The reverse mapping of _dictionary, i.e. mapping encoded integer to the value itself.
private var values = new mutable.ArrayBuffer[Any](1024)
// The dictionary that maps a value to the encoded short integer.
private val dictionary = mutable.HashMap.empty[Any, Short]
// Size of the serialized dictionary in bytes. Initialized to 4 since we need at least an `Int`
// to store dictionary element count.
private var dictionarySize = 4
override def gatherCompressibilityStats[T <: NativeType](
value: T#JvmType,
columnType: ColumnType[T, T#JvmType]) {
if (!overflow) {
val actualSize = columnType.actualSize(value)
count += 1
_uncompressedSize += actualSize
if (!dictionary.contains(value)) {
if (dictionary.size < MAX_DICT_SIZE) {
val clone = columnType.clone(value)
values += clone
dictionarySize += actualSize
dictionary(clone) = dictionary.size.toShort
} else {
overflow = true
values.clear()
dictionary.clear()
}
}
}
}
override def compress[T <: NativeType](
from: ByteBuffer,
to: ByteBuffer,
columnType: ColumnType[T, T#JvmType]) = {
if (overflow) {
throw new IllegalStateException(
"Dictionary encoding should not be used because of dictionary overflow.")
}
to.putInt(DictionaryEncoding.typeId)
.putInt(dictionary.size)
var i = 0
while (i < values.length) {
columnType.append(values(i).asInstanceOf[T#JvmType], to)
i += 1
}
while (from.hasRemaining) {
to.putShort(dictionary(columnType.extract(from)))
}
to.rewind()
to
}
override def uncompressedSize = _uncompressedSize
override def compressedSize = if (overflow) Int.MaxValue else dictionarySize + count * 2
}
class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T])
extends compression.Decoder[T] {
private val dictionary = {
// TODO Can we clean up this mess? Maybe move this to `DataType`?
implicit val classTag = {
val mirror = runtimeMirror(getClass.getClassLoader)
ClassTag[T#JvmType](mirror.runtimeClass(columnType.scalaTag.tpe))
}
Array.fill(buffer.getInt()) {
columnType.extract(buffer)
}
}
override def next() = dictionary(buffer.getShort())
override def hasNext = buffer.hasRemaining
}
}


@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.columnar
import org.scalatest.FunSuite
import org.apache.spark.sql.catalyst.types._
class ColumnStatsSuite extends FunSuite {
testColumnStats(classOf[BooleanColumnStats], BOOLEAN)
testColumnStats(classOf[ByteColumnStats], BYTE)
testColumnStats(classOf[ShortColumnStats], SHORT)
testColumnStats(classOf[IntColumnStats], INT)
testColumnStats(classOf[LongColumnStats], LONG)
testColumnStats(classOf[FloatColumnStats], FLOAT)
testColumnStats(classOf[DoubleColumnStats], DOUBLE)
testColumnStats(classOf[StringColumnStats], STRING)
def testColumnStats[T <: NativeType, U <: NativeColumnStats[T]](
columnStatsClass: Class[U],
columnType: NativeColumnType[T]) {
val columnStatsName = columnStatsClass.getSimpleName
test(s"$columnStatsName: empty") {
val columnStats = columnStatsClass.newInstance()
expectResult(columnStats.initialBounds, "Wrong initial bounds") {
(columnStats.lowerBound, columnStats.upperBound)
}
}
test(s"$columnStatsName: non-empty") {
import ColumnarTestUtils._
val columnStats = columnStatsClass.newInstance()
val rows = Seq.fill(10)(makeRandomRow(columnType))
rows.foreach(columnStats.gatherStats(_, 0))
val values = rows.map(_.head.asInstanceOf[T#JvmType])
val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]]
expectResult(values.min(ordering), "Wrong lower bound")(columnStats.lowerBound)
expectResult(values.max(ordering), "Wrong upper bound")(columnStats.upperBound)
}
}
}


@@ -19,46 +19,56 @@ package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
import scala.util.Random
import org.scalatest.FunSuite
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.columnar.ColumnarTestUtils._
import org.apache.spark.sql.execution.SparkSqlSerializer
class ColumnTypeSuite extends FunSuite {
val columnTypes = Seq(INT, SHORT, LONG, BYTE, DOUBLE, FLOAT, STRING, BINARY, GENERIC)
val DEFAULT_BUFFER_SIZE = 512
test("defaultSize") {
val defaultSize = Seq(4, 2, 8, 1, 8, 4, 8, 16, 16)
val checks = Map(
INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4,
BOOLEAN -> 1, STRING -> 8, BINARY -> 16, GENERIC -> 16)
columnTypes.zip(defaultSize).foreach { case (columnType, size) =>
assert(columnType.defaultSize === size)
checks.foreach { case (columnType, expectedSize) =>
expectResult(expectedSize, s"Wrong defaultSize for $columnType") {
columnType.defaultSize
}
}
}
test("actualSize") {
val expectedSizes = Seq(4, 2, 8, 1, 8, 4, 4 + 5, 4 + 4, 4 + 11)
val actualSizes = Seq(
INT.actualSize(Int.MaxValue),
SHORT.actualSize(Short.MaxValue),
LONG.actualSize(Long.MaxValue),
BYTE.actualSize(Byte.MaxValue),
DOUBLE.actualSize(Double.MaxValue),
FLOAT.actualSize(Float.MaxValue),
STRING.actualSize("hello"),
BINARY.actualSize(new Array[Byte](4)),
GENERIC.actualSize(SparkSqlSerializer.serialize(Map(1 -> "a"))))
def checkActualSize[T <: DataType, JvmType](
columnType: ColumnType[T, JvmType],
value: JvmType,
expected: Int) {
expectedSizes.zip(actualSizes).foreach { case (expected, actual) =>
assert(expected === actual)
expectResult(expected, s"Wrong actualSize for $columnType") {
columnType.actualSize(value)
}
}
checkActualSize(INT, Int.MaxValue, 4)
checkActualSize(SHORT, Short.MaxValue, 2)
checkActualSize(LONG, Long.MaxValue, 8)
checkActualSize(BYTE, Byte.MaxValue, 1)
checkActualSize(DOUBLE, Double.MaxValue, 8)
checkActualSize(FLOAT, Float.MaxValue, 4)
checkActualSize(BOOLEAN, true, 1)
checkActualSize(STRING, "hello", 4 + 5)
val binary = Array.fill[Byte](4)(0: Byte)
checkActualSize(BINARY, binary, 4 + 4)
val generic = Map(1 -> "a")
checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 11)
}
testNumericColumnType[BooleanType.type, Boolean](
testNativeColumnType[BooleanType.type](
BOOLEAN,
Array.fill(4)(Random.nextBoolean()),
ByteBuffer.allocate(32),
(buffer: ByteBuffer, v: Boolean) => {
buffer.put((if (v) 1 else 0).toByte)
},
@@ -66,105 +76,42 @@ class ColumnTypeSuite extends FunSuite {
buffer.get() == 1
})
testNumericColumnType[IntegerType.type, Int](
INT,
Array.fill(4)(Random.nextInt()),
ByteBuffer.allocate(32),
(_: ByteBuffer).putInt(_),
(_: ByteBuffer).getInt)
testNativeColumnType[IntegerType.type](INT, _.putInt(_), _.getInt)
testNumericColumnType[ShortType.type, Short](
SHORT,
Array.fill(4)(Random.nextInt(Short.MaxValue).asInstanceOf[Short]),
ByteBuffer.allocate(32),
(_: ByteBuffer).putShort(_),
(_: ByteBuffer).getShort)
testNativeColumnType[ShortType.type](SHORT, _.putShort(_), _.getShort)
testNumericColumnType[LongType.type, Long](
LONG,
Array.fill(4)(Random.nextLong()),
ByteBuffer.allocate(64),
(_: ByteBuffer).putLong(_),
(_: ByteBuffer).getLong)
testNativeColumnType[LongType.type](LONG, _.putLong(_), _.getLong)
testNumericColumnType[ByteType.type, Byte](
BYTE,
Array.fill(4)(Random.nextInt(Byte.MaxValue).asInstanceOf[Byte]),
ByteBuffer.allocate(64),
(_: ByteBuffer).put(_),
(_: ByteBuffer).get)
testNativeColumnType[ByteType.type](BYTE, _.put(_), _.get)
testNumericColumnType[DoubleType.type, Double](
DOUBLE,
Array.fill(4)(Random.nextDouble()),
ByteBuffer.allocate(64),
(_: ByteBuffer).putDouble(_),
(_: ByteBuffer).getDouble)
testNativeColumnType[DoubleType.type](DOUBLE, _.putDouble(_), _.getDouble)
testNumericColumnType[FloatType.type, Float](
FLOAT,
Array.fill(4)(Random.nextFloat()),
ByteBuffer.allocate(64),
(_: ByteBuffer).putFloat(_),
(_: ByteBuffer).getFloat)
test("STRING") {
val buffer = ByteBuffer.allocate(128)
val seq = Array("hello", "world", "spark", "sql")
seq.map(_.getBytes).foreach { bytes: Array[Byte] =>
buffer.putInt(bytes.length).put(bytes)
}
buffer.rewind()
seq.foreach { s =>
assert(s === STRING.extract(buffer))
}
buffer.rewind()
seq.foreach(STRING.append(_, buffer))
buffer.rewind()
seq.foreach { s =>
val length = buffer.getInt
assert(length === s.getBytes.length)
testNativeColumnType[FloatType.type](FLOAT, _.putFloat(_), _.getFloat)
testNativeColumnType[StringType.type](
STRING,
(buffer: ByteBuffer, string: String) => {
val bytes = string.getBytes()
buffer.putInt(bytes.length).put(string.getBytes)
},
(buffer: ByteBuffer) => {
val length = buffer.getInt()
val bytes = new Array[Byte](length)
buffer.get(bytes, 0, length)
assert(s === new String(bytes))
}
}
new String(bytes)
})
test("BINARY") {
val buffer = ByteBuffer.allocate(128)
val seq = Array.fill(4) {
val bytes = new Array[Byte](4)
Random.nextBytes(bytes)
testColumnType[BinaryType.type, Array[Byte]](
BINARY,
(buffer: ByteBuffer, bytes: Array[Byte]) => {
buffer.putInt(bytes.length).put(bytes)
},
(buffer: ByteBuffer) => {
val length = buffer.getInt()
val bytes = new Array[Byte](length)
buffer.get(bytes, 0, length)
bytes
}
seq.foreach { bytes =>
buffer.putInt(bytes.length).put(bytes)
}
buffer.rewind()
seq.foreach { b =>
assert(b === BINARY.extract(buffer))
}
buffer.rewind()
seq.foreach(BINARY.append(_, buffer))
buffer.rewind()
seq.foreach { b =>
val length = buffer.getInt
assert(length === b.length)
val bytes = new Array[Byte](length)
buffer.get(bytes, 0, length)
assert(b === bytes)
}
}
})
test("GENERIC") {
val buffer = ByteBuffer.allocate(512)
@@ -177,43 +124,58 @@ class ColumnTypeSuite extends FunSuite {
val length = buffer.getInt()
assert(length === serializedObj.length)
val bytes = new Array[Byte](length)
buffer.get(bytes, 0, length)
assert(obj === SparkSqlSerializer.deserialize(bytes))
expectResult(obj, "Deserialized object didn't equal to the original object") {
val bytes = new Array[Byte](length)
buffer.get(bytes, 0, length)
SparkSqlSerializer.deserialize(bytes)
}
buffer.rewind()
buffer.putInt(serializedObj.length).put(serializedObj)
buffer.rewind()
assert(obj === SparkSqlSerializer.deserialize(GENERIC.extract(buffer)))
expectResult(obj, "Deserialized object didn't equal to the original object") {
buffer.rewind()
SparkSqlSerializer.deserialize(GENERIC.extract(buffer))
}
}
  def testNativeColumnType[T <: NativeType](
      columnType: NativeColumnType[T],
      putter: (ByteBuffer, T#JvmType) => Unit,
      getter: (ByteBuffer) => T#JvmType) {

    testColumnType[T, T#JvmType](columnType, putter, getter)
  }

  def testColumnType[T <: DataType, JvmType](
      columnType: ColumnType[T, JvmType],
      putter: (ByteBuffer, JvmType) => Unit,
      getter: (ByteBuffer) => JvmType) {

    val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE)
    val seq = (0 until 4).map(_ => makeRandomValue(columnType))

    test(s"$columnType.extract") {
      buffer.rewind()
      seq.foreach(putter(buffer, _))

      buffer.rewind()
      seq.foreach { expected =>
        assert(
          expected === columnType.extract(buffer),
          "Extracted value didn't equal the original one")
      }
    }

    test(s"$columnType.append") {
      buffer.rewind()
      seq.foreach(columnType.append(_, buffer))

      buffer.rewind()
      seq.foreach { expected =>
        assert(
          expected === getter(buffer),
          "Extracted value didn't equal the original one")
      }
    }
  }


@@ -17,11 +17,11 @@
package org.apache.spark.sql.columnar

import org.apache.spark.sql.{QueryTest, TestData}
import org.apache.spark.sql.execution.SparkLogicalPlan
import org.apache.spark.sql.test.TestSQLContext

class ColumnarQuerySuite extends QueryTest {
  import TestData._
  import TestSQLContext._


@@ -1,55 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.columnar
import scala.util.Random
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
// TODO Enrich test data
object ColumnarTestData {
object GenericMutableRow {
def apply(values: Any*) = {
val row = new GenericMutableRow(values.length)
row.indices.foreach { i =>
row(i) = values(i)
}
row
}
}
def randomBytes(length: Int) = {
val bytes = new Array[Byte](length)
Random.nextBytes(bytes)
bytes
}
val nonNullRandomRow = GenericMutableRow(
Random.nextInt(),
Random.nextLong(),
Random.nextFloat(),
Random.nextDouble(),
Random.nextBoolean(),
Random.nextInt(Byte.MaxValue).asInstanceOf[Byte],
Random.nextInt(Short.MaxValue).asInstanceOf[Short],
Random.nextString(Random.nextInt(64)),
randomBytes(Random.nextInt(64)),
Map(Random.nextInt() -> Random.nextString(4)))
val nullRow = GenericMutableRow(Seq.fill(10)(null): _*)
}


@@ -0,0 +1,100 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.columnar
import scala.collection.immutable.HashSet
import scala.util.Random
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.types.{DataType, NativeType}
object ColumnarTestUtils {
def makeNullRow(length: Int) = {
val row = new GenericMutableRow(length)
(0 until length).foreach(row.setNullAt)
row
}
def makeRandomValue[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]): JvmType = {
def randomBytes(length: Int) = {
val bytes = new Array[Byte](length)
Random.nextBytes(bytes)
bytes
}
(columnType match {
case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
case INT => Random.nextInt()
case LONG => Random.nextLong()
case FLOAT => Random.nextFloat()
case DOUBLE => Random.nextDouble()
case STRING => Random.nextString(Random.nextInt(32))
case BOOLEAN => Random.nextBoolean()
case BINARY => randomBytes(Random.nextInt(32))
case _ =>
// Using a random one-element map instead of an arbitrary object
Map(Random.nextInt() -> Random.nextString(Random.nextInt(32)))
}).asInstanceOf[JvmType]
}
def makeRandomValues(
head: ColumnType[_ <: DataType, _],
tail: ColumnType[_ <: DataType, _]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail)
def makeRandomValues(columnTypes: Seq[ColumnType[_ <: DataType, _]]): Seq[Any] = {
columnTypes.map(makeRandomValue(_))
}
def makeUniqueRandomValues[T <: DataType, JvmType](
columnType: ColumnType[T, JvmType],
count: Int): Seq[JvmType] = {
Iterator.iterate(HashSet.empty[JvmType]) { set =>
set + Iterator.continually(makeRandomValue(columnType)).filterNot(set.contains).next()
}.drop(count).next().toSeq
}
def makeRandomRow(
head: ColumnType[_ <: DataType, _],
tail: ColumnType[_ <: DataType, _]*): Row = makeRandomRow(Seq(head) ++ tail)
def makeRandomRow(columnTypes: Seq[ColumnType[_ <: DataType, _]]): Row = {
val row = new GenericMutableRow(columnTypes.length)
makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) =>
row(index) = value
}
row
}
def makeUniqueValuesAndSingleValueRows[T <: NativeType](
columnType: NativeColumnType[T],
count: Int) = {
val values = makeUniqueRandomValues(columnType, count)
val rows = values.map { value =>
val row = new GenericMutableRow(1)
row(0) = value
row
}
(values, rows)
}
}
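
For orientation, here is a minimal REPL-style sketch (not part of the diff) of how these helpers are meant to be combined; it assumes the spark-sql test classpath, and the assertions only illustrate the helpers' contracts rather than add test coverage:

```scala
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.columnar.ColumnarTestUtils._

// A two-column row whose values match the given column types (INT, STRING).
val row = makeRandomRow(INT, STRING)
assert(row(0).isInstanceOf[Int] && row(1).isInstanceOf[String])

// Two distinct INT values, each wrapped in a single-column row -- the exact shape
// the nullable/compression suites feed into ColumnBuilder.appendFrom(row, 0).
val (values, rows) = makeUniqueValuesAndSingleValueRows(INT, 2)
assert(values.distinct.size == 2)
rows.zip(values).foreach { case (r, v) => assert(r(0) == v) }

// A row of the given width with every column set to null.
val nulls = makeNullRow(3)
assert((0 until 3).forall(nulls.isNullAt))
```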


@@ -17,12 +17,29 @@
package org.apache.spark.sql.columnar

import java.nio.ByteBuffer

import org.scalatest.FunSuite

import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.types.DataType

class TestNullableColumnAccessor[T <: DataType, JvmType](
    buffer: ByteBuffer,
    columnType: ColumnType[T, JvmType])
  extends BasicColumnAccessor(buffer, columnType)
  with NullableColumnAccessor

object TestNullableColumnAccessor {
  def apply[T <: DataType, JvmType](buffer: ByteBuffer, columnType: ColumnType[T, JvmType]) = {
    // Skips the column type ID
    buffer.getInt()
    new TestNullableColumnAccessor(buffer, columnType)
  }
}

class NullableColumnAccessorSuite extends FunSuite {
  import ColumnarTestUtils._

  Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC).foreach {
    testNullableColumnAccessor(_)
@@ -30,30 +47,32 @@ class NullableColumnAccessorSuite extends FunSuite {
  def testNullableColumnAccessor[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) {
    val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
    val nullRow = makeNullRow(1)

    test(s"Nullable $typeName column accessor: empty column") {
      val builder = TestNullableColumnBuilder(columnType)
      val accessor = TestNullableColumnAccessor(builder.build(), columnType)
      assert(!accessor.hasNext)
    }

    test(s"Nullable $typeName column accessor: access null values") {
      val builder = TestNullableColumnBuilder(columnType)
      val randomRow = makeRandomRow(columnType)

      (0 until 4).foreach { _ =>
        builder.appendFrom(randomRow, 0)
        builder.appendFrom(nullRow, 0)
      }

      val accessor = TestNullableColumnAccessor(builder.build(), columnType)
      val row = new GenericMutableRow(1)

      (0 until 4).foreach { _ =>
        accessor.extractTo(row, 0)
        assert(row(0) === randomRow(0))

        accessor.extractTo(row, 0)
        assert(row.isNullAt(0))
      }
    }
  }


@@ -19,63 +19,71 @@ package org.apache.spark.sql.columnar
import org.scalatest.FunSuite

import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.execution.SparkSqlSerializer

class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType])
  extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType)
  with NullableColumnBuilder

object TestNullableColumnBuilder {
  def apply[T <: DataType, JvmType](columnType: ColumnType[T, JvmType], initialSize: Int = 0) = {
    val builder = new TestNullableColumnBuilder(columnType)
    builder.initialize(initialSize)
    builder
  }
}

class NullableColumnBuilderSuite extends FunSuite {
  import ColumnarTestUtils._

  Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC).foreach {
    testNullableColumnBuilder(_)
  }

  def testNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) {
    val typeName = columnType.getClass.getSimpleName.stripSuffix("$")

    test(s"$typeName column builder: empty column") {
      val columnBuilder = TestNullableColumnBuilder(columnType)
      val buffer = columnBuilder.build()

      expectResult(columnType.typeId, "Wrong column type ID")(buffer.getInt())
      expectResult(0, "Wrong null count")(buffer.getInt())
      assert(!buffer.hasRemaining)
    }

    test(s"$typeName column builder: buffer size auto growth") {
      val columnBuilder = TestNullableColumnBuilder(columnType)
      val randomRow = makeRandomRow(columnType)

      (0 until 4).foreach { _ =>
        columnBuilder.appendFrom(randomRow, 0)
      }

      val buffer = columnBuilder.build()

      expectResult(columnType.typeId, "Wrong column type ID")(buffer.getInt())
      expectResult(0, "Wrong null count")(buffer.getInt())
    }

    test(s"$typeName column builder: null values") {
      val columnBuilder = TestNullableColumnBuilder(columnType)
      val randomRow = makeRandomRow(columnType)
      val nullRow = makeNullRow(1)

      (0 until 4).foreach { _ =>
        columnBuilder.appendFrom(randomRow, 0)
        columnBuilder.appendFrom(nullRow, 0)
      }

      val buffer = columnBuilder.build()

      expectResult(columnType.typeId, "Wrong column type ID")(buffer.getInt())
      expectResult(4, "Wrong null count")(buffer.getInt())

      // For null positions
      (1 to 7 by 2).foreach(expectResult(_, "Wrong null position")(buffer.getInt()))

      // For non-null values
      (0 until 4).foreach { _ =>
@@ -84,7 +92,8 @@ class NullableColumnBuilderSuite extends FunSuite {
} else {
columnType.extract(buffer)
}
        assert(actual === randomRow(0), "Extracted value didn't equal the original one")
}
assert(!buffer.hasRemaining)


@@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.columnar.compression
import java.nio.ByteBuffer
import org.scalatest.FunSuite
import org.apache.spark.sql.catalyst.types.NativeType
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.columnar.ColumnarTestUtils._
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
class DictionaryEncodingSuite extends FunSuite {
testDictionaryEncoding(new IntColumnStats, INT)
testDictionaryEncoding(new LongColumnStats, LONG)
testDictionaryEncoding(new StringColumnStats, STRING)
def testDictionaryEncoding[T <: NativeType](
columnStats: NativeColumnStats[T],
columnType: NativeColumnType[T]) {
val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
def buildDictionary(buffer: ByteBuffer) = {
(0 until buffer.getInt()).map(columnType.extract(buffer) -> _.toShort).toMap
}
test(s"$DictionaryEncoding with $typeName: simple case") {
// -------------
// Tests encoder
// -------------
val builder = TestCompressibleColumnBuilder(columnStats, columnType, DictionaryEncoding)
val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, 2)
builder.initialize(0)
builder.appendFrom(rows(0), 0)
builder.appendFrom(rows(1), 0)
builder.appendFrom(rows(0), 0)
builder.appendFrom(rows(1), 0)
val buffer = builder.build()
val headerSize = CompressionScheme.columnHeaderSize(buffer)
// 4 extra bytes for dictionary size
val dictionarySize = 4 + values.map(columnType.actualSize).sum
// 4 `Short`s, 2 bytes each
val compressedSize = dictionarySize + 2 * 4
// 4 extra bytes for compression scheme type ID
expectResult(headerSize + 4 + compressedSize, "Wrong buffer capacity")(buffer.capacity)
// Skips column header
buffer.position(headerSize)
expectResult(DictionaryEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt())
val dictionary = buildDictionary(buffer)
Array[Short](0, 1).foreach { i =>
expectResult(i, "Wrong dictionary entry")(dictionary(values(i)))
}
Array[Short](0, 1, 0, 1).foreach {
expectResult(_, "Wrong column element value")(buffer.getShort())
}
// -------------
// Tests decoder
// -------------
// Rewinds, skips column header and 4 more bytes for compression scheme ID
buffer.rewind().position(headerSize + 4)
val decoder = new DictionaryEncoding.Decoder[T](buffer, columnType)
Array[Short](0, 1, 0, 1).foreach { i =>
expectResult(values(i), "Wrong decoded value")(decoder.next())
}
assert(!decoder.hasNext)
}
}
test(s"$DictionaryEncoding: overflow") {
val builder = TestCompressibleColumnBuilder(new IntColumnStats, INT, DictionaryEncoding)
builder.initialize(0)
(0 to Short.MaxValue).foreach { n =>
val row = new GenericMutableRow(1)
row.setInt(0, n)
builder.appendFrom(row, 0)
}
withClue("Dictionary overflowed, encoding should fail") {
intercept[Throwable] {
builder.build()
}
}
}
}
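
The capacity check in the simple case above reduces to fixed arithmetic; as a rough REPL-style sketch for the INT variant (the `headerSize` value below is only a stand-in for what `CompressionScheme.columnHeaderSize` returns for a column with no nulls):

```scala
val headerSize = 8                  // stand-in: 4-byte column type ID + 4-byte null count
val dictionarySize = 4 + 2 * 4      // 4-byte entry count + two distinct INT values = 12
val encodedElements = 4 * 2         // four appended rows, one 2-byte Short index each = 8
val expectedCapacity = headerSize + 4 /* scheme ID */ + dictionarySize + encodedElements
// expectedCapacity == 8 + 4 + 12 + 8 == 32 for this stand-in header size
```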


@@ -0,0 +1,130 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.columnar.compression
import org.scalatest.FunSuite
import org.apache.spark.sql.catalyst.types.NativeType
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.columnar.ColumnarTestUtils._
class RunLengthEncodingSuite extends FunSuite {
testRunLengthEncoding(new BooleanColumnStats, BOOLEAN)
testRunLengthEncoding(new ByteColumnStats, BYTE)
testRunLengthEncoding(new ShortColumnStats, SHORT)
testRunLengthEncoding(new IntColumnStats, INT)
testRunLengthEncoding(new LongColumnStats, LONG)
testRunLengthEncoding(new StringColumnStats, STRING)
def testRunLengthEncoding[T <: NativeType](
columnStats: NativeColumnStats[T],
columnType: NativeColumnType[T]) {
val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
test(s"$RunLengthEncoding with $typeName: simple case") {
// -------------
// Tests encoder
// -------------
val builder = TestCompressibleColumnBuilder(columnStats, columnType, RunLengthEncoding)
val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, 2)
builder.initialize(0)
builder.appendFrom(rows(0), 0)
builder.appendFrom(rows(0), 0)
builder.appendFrom(rows(1), 0)
builder.appendFrom(rows(1), 0)
val buffer = builder.build()
val headerSize = CompressionScheme.columnHeaderSize(buffer)
// 4 extra bytes each run for run length
val compressedSize = values.map(columnType.actualSize(_) + 4).sum
// 4 extra bytes for compression scheme type ID
expectResult(headerSize + 4 + compressedSize, "Wrong buffer capacity")(buffer.capacity)
// Skips column header
buffer.position(headerSize)
expectResult(RunLengthEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt())
Array(0, 1).foreach { i =>
expectResult(values(i), "Wrong column element value")(columnType.extract(buffer))
expectResult(2, "Wrong run length")(buffer.getInt())
}
// -------------
// Tests decoder
// -------------
// Rewinds, skips column header and 4 more bytes for compression scheme ID
buffer.rewind().position(headerSize + 4)
val decoder = new RunLengthEncoding.Decoder[T](buffer, columnType)
Array(0, 0, 1, 1).foreach { i =>
expectResult(values(i), "Wrong decoded value")(decoder.next())
}
assert(!decoder.hasNext)
}
test(s"$RunLengthEncoding with $typeName: run length == 1") {
// -------------
// Tests encoder
// -------------
val builder = TestCompressibleColumnBuilder(columnStats, columnType, RunLengthEncoding)
val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, 2)
builder.initialize(0)
builder.appendFrom(rows(0), 0)
builder.appendFrom(rows(1), 0)
val buffer = builder.build()
val headerSize = CompressionScheme.columnHeaderSize(buffer)
// 4 bytes each run for run length
val compressedSize = values.map(columnType.actualSize(_) + 4).sum
// 4 bytes for compression scheme type ID
expectResult(headerSize + 4 + compressedSize, "Wrong buffer capacity")(buffer.capacity)
// Skips column header
buffer.position(headerSize)
expectResult(RunLengthEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt())
Array(0, 1).foreach { i =>
expectResult(values(i), "Wrong column element value")(columnType.extract(buffer))
expectResult(1, "Wrong run length")(buffer.getInt())
}
// -------------
// Tests decoder
// -------------
// Rewinds, skips column header and 4 more bytes for compression scheme ID
buffer.rewind().position(headerSize + 4)
val decoder = new RunLengthEncoding.Decoder[T](buffer, columnType)
Array(0, 1).foreach { i =>
expectResult(values(i), "Wrong decoded value")(decoder.next())
}
assert(!decoder.hasNext)
}
}
}
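
The run-length capacity checks follow the same pattern; a REPL-style sketch for the INT "simple case" (again with a stand-in `headerSize` for a column without nulls):

```scala
val headerSize = 8                  // stand-in: 4-byte column type ID + 4-byte null count
val runs = 2                        // two distinct values, each appended twice => two runs
val perRun = 4 /* INT value */ + 4 /* run length */
val expectedCapacity = headerSize + 4 /* scheme ID */ + runs * perRun
// expectedCapacity == 8 + 4 + 16 == 28; four input rows collapse into two (value, count) pairs
```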


@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.columnar.compression
import org.apache.spark.sql.catalyst.types.NativeType
import org.apache.spark.sql.columnar._
class TestCompressibleColumnBuilder[T <: NativeType](
override val columnStats: NativeColumnStats[T],
override val columnType: NativeColumnType[T],
override val schemes: Seq[CompressionScheme])
extends NativeColumnBuilder(columnStats, columnType)
with NullableColumnBuilder
with CompressibleColumnBuilder[T] {
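  // Always report compression as worthwhile, so the scheme under test is applied
  // unconditionally instead of letting the builder fall back to uncompressed storage.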
override protected def isWorthCompressing(encoder: Encoder) = true
}
object TestCompressibleColumnBuilder {
def apply[T <: NativeType](
columnStats: NativeColumnStats[T],
columnType: NativeColumnType[T],
scheme: CompressionScheme) = {
new TestCompressibleColumnBuilder(columnStats, columnType, Seq(scheme))
}
}