[SPARK-8610] [SQL] Separate Row and InternalRow (part 2)

Currently, we use GenericRow for both Row and InternalRow, which is confusing because it can hold either Scala types or Catalyst types.

This PR changes InternalRow to be backed by GenericInternalRow (containing Catalyst types) and Row by GenericRow (containing Scala types).

It also fixes some incorrect uses of InternalRow and Row.
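
A minimal sketch of the resulting split, assuming the spark-catalyst classes touched in this patch are on the classpath (the values are made up):

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.unsafe.types.UTF8String

// Row holds external Scala/Java types and is backed by GenericRow.
val external: Row = Row("cats", 3)

// InternalRow holds Catalyst types (e.g. UTF8String instead of String)
// and is now backed by GenericInternalRow via InternalRow.apply/fromSeq.
val internal: InternalRow = InternalRow(UTF8String.fromString("cats"), 3)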

Author: Davies Liu <davies@databricks.com>

Closes #7003 from davies/internalrow and squashes the following commits:

d05866c [Davies Liu] fix test: rollback changes for pyspark
72878dd [Davies Liu] Merge branch 'master' of github.com:apache/spark into internalrow
efd0b25 [Davies Liu] fix copy of MutableRow
87b13cf [Davies Liu] fix test
d2ebd72 [Davies Liu] fix style
eb4b473 [Davies Liu] mark expensive API as final
bd4e99c [Davies Liu] Merge branch 'master' of github.com:apache/spark into internalrow
bdfb78f [Davies Liu] remove BaseMutableRow
6f99a97 [Davies Liu] fix catalyst test
defe931 [Davies Liu] remove BaseRow
288b31f [Davies Liu] Merge branch 'master' of github.com:apache/spark into internalrow
9d24350 [Davies Liu] separate Row and InternalRow (part 2)
Davies Liu 2015-06-28 08:03:58 -07:00
parent 52d1281801
commit 77da5be6f1
39 changed files with 304 additions and 580 deletions

View file

@ -1,68 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql;
import org.apache.spark.sql.catalyst.expressions.MutableRow;
public abstract class BaseMutableRow extends BaseRow implements MutableRow {
@Override
public void update(int ordinal, Object value) {
throw new UnsupportedOperationException();
}
@Override
public void setInt(int ordinal, int value) {
throw new UnsupportedOperationException();
}
@Override
public void setLong(int ordinal, long value) {
throw new UnsupportedOperationException();
}
@Override
public void setDouble(int ordinal, double value) {
throw new UnsupportedOperationException();
}
@Override
public void setBoolean(int ordinal, boolean value) {
throw new UnsupportedOperationException();
}
@Override
public void setShort(int ordinal, short value) {
throw new UnsupportedOperationException();
}
@Override
public void setByte(int ordinal, byte value) {
throw new UnsupportedOperationException();
}
@Override
public void setFloat(int ordinal, float value) {
throw new UnsupportedOperationException();
}
@Override
public void setString(int ordinal, String value) {
throw new UnsupportedOperationException();
}
}

View file

@ -1,197 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql;
import java.math.BigDecimal;
import java.sql.Date;
import java.sql.Timestamp;
import java.util.List;
import scala.collection.Seq;
import scala.collection.mutable.ArraySeq;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
import org.apache.spark.sql.types.StructType;
public abstract class BaseRow extends InternalRow {
@Override
final public int length() {
return size();
}
@Override
public boolean anyNull() {
final int n = size();
for (int i=0; i < n; i++) {
if (isNullAt(i)) {
return true;
}
}
return false;
}
@Override
public StructType schema() { throw new UnsupportedOperationException(); }
@Override
final public Object apply(int i) {
return get(i);
}
@Override
public int getInt(int i) {
throw new UnsupportedOperationException();
}
@Override
public long getLong(int i) {
throw new UnsupportedOperationException();
}
@Override
public float getFloat(int i) {
throw new UnsupportedOperationException();
}
@Override
public double getDouble(int i) {
throw new UnsupportedOperationException();
}
@Override
public byte getByte(int i) {
throw new UnsupportedOperationException();
}
@Override
public short getShort(int i) {
throw new UnsupportedOperationException();
}
@Override
public boolean getBoolean(int i) {
throw new UnsupportedOperationException();
}
@Override
public String getString(int i) {
throw new UnsupportedOperationException();
}
@Override
public BigDecimal getDecimal(int i) {
throw new UnsupportedOperationException();
}
@Override
public Date getDate(int i) {
throw new UnsupportedOperationException();
}
@Override
public Timestamp getTimestamp(int i) {
throw new UnsupportedOperationException();
}
@Override
public <T> Seq<T> getSeq(int i) {
throw new UnsupportedOperationException();
}
@Override
public <T> List<T> getList(int i) {
throw new UnsupportedOperationException();
}
@Override
public <K, V> scala.collection.Map<K, V> getMap(int i) {
throw new UnsupportedOperationException();
}
@Override
public <T> scala.collection.immutable.Map<String, T> getValuesMap(Seq<String> fieldNames) {
throw new UnsupportedOperationException();
}
@Override
public <K, V> java.util.Map<K, V> getJavaMap(int i) {
throw new UnsupportedOperationException();
}
@Override
public Row getStruct(int i) {
throw new UnsupportedOperationException();
}
@Override
public <T> T getAs(int i) {
throw new UnsupportedOperationException();
}
@Override
public <T> T getAs(String fieldName) {
throw new UnsupportedOperationException();
}
@Override
public int fieldIndex(String name) {
throw new UnsupportedOperationException();
}
@Override
public InternalRow copy() {
final int n = size();
Object[] arr = new Object[n];
for (int i = 0; i < n; i++) {
arr[i] = get(i);
}
return new GenericRow(arr);
}
@Override
public Seq<Object> toSeq() {
final int n = size();
final ArraySeq<Object> values = new ArraySeq<Object>(n);
for (int i = 0; i < n; i++) {
values.update(i, get(i));
}
return values;
}
@Override
public String toString() {
return mkString("[", ",", "]");
}
@Override
public String mkString() {
return toSeq().mkString();
}
@Override
public String mkString(String sep) {
return toSeq().mkString(sep);
}
@Override
public String mkString(String start, String sep, String end) {
return toSeq().mkString(start, sep, end);
}
}

View file

@ -23,16 +23,12 @@ import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import scala.collection.Seq;
import scala.collection.mutable.ArraySeq;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.BaseMutableRow;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.types.UTF8String;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.bitset.BitSetMethods;
import org.apache.spark.unsafe.types.UTF8String;
import static org.apache.spark.sql.types.DataTypes.*;
@ -52,7 +48,7 @@ import static org.apache.spark.sql.types.DataTypes.*;
*
* Instances of `UnsafeRow` act as pointers to row data stored in this format.
*/
public final class UnsafeRow extends BaseMutableRow {
public final class UnsafeRow extends MutableRow {
private Object baseObject;
private long baseOffset;
@ -63,6 +59,8 @@ public final class UnsafeRow extends BaseMutableRow {
/** The number of fields in this row, used for calculating the bitset width (and in assertions) */
private int numFields;
public int length() { return numFields; }
/** The width of the null tracking bit set, in bytes */
private int bitSetWidthInBytes;
/**
@ -344,13 +342,4 @@ public final class UnsafeRow extends BaseMutableRow {
public boolean anyNull() {
return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes);
}
@Override
public Seq<Object> toSeq() {
final ArraySeq<Object> values = new ArraySeq<Object>(numFields);
for (int fieldNumber = 0; fieldNumber < numFields; fieldNumber++) {
values.update(fieldNumber, get(fieldNumber));
}
return values;
}
}

View file

@ -179,7 +179,7 @@ trait Row extends Serializable {
def get(i: Int): Any = apply(i)
/** Checks whether the value at position i is null. */
def isNullAt(i: Int): Boolean
def isNullAt(i: Int): Boolean = apply(i) == null
/**
* Returns the value at position i as a primitive boolean.
@ -187,7 +187,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
def getBoolean(i: Int): Boolean
def getBoolean(i: Int): Boolean = getAs[Boolean](i)
/**
* Returns the value at position i as a primitive byte.
@ -195,7 +195,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
def getByte(i: Int): Byte
def getByte(i: Int): Byte = getAs[Byte](i)
/**
* Returns the value at position i as a primitive short.
@ -203,7 +203,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
def getShort(i: Int): Short
def getShort(i: Int): Short = getAs[Short](i)
/**
* Returns the value at position i as a primitive int.
@ -211,7 +211,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
def getInt(i: Int): Int
def getInt(i: Int): Int = getAs[Int](i)
/**
* Returns the value at position i as a primitive long.
@ -219,7 +219,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
def getLong(i: Int): Long
def getLong(i: Int): Long = getAs[Long](i)
/**
* Returns the value at position i as a primitive float.
@ -228,7 +228,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
def getFloat(i: Int): Float
def getFloat(i: Int): Float = getAs[Float](i)
/**
* Returns the value at position i as a primitive double.
@ -236,7 +236,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
def getDouble(i: Int): Double
def getDouble(i: Int): Double = getAs[Double](i)
/**
* Returns the value at position i as a String object.
@ -244,35 +244,35 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
def getString(i: Int): String
def getString(i: Int): String = getAs[String](i)
/**
* Returns the value at position i of decimal type as java.math.BigDecimal.
*
* @throws ClassCastException when data type does not match.
*/
def getDecimal(i: Int): java.math.BigDecimal = apply(i).asInstanceOf[java.math.BigDecimal]
def getDecimal(i: Int): java.math.BigDecimal = getAs[java.math.BigDecimal](i)
/**
* Returns the value at position i of date type as java.sql.Date.
*
* @throws ClassCastException when data type does not match.
*/
def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date]
def getDate(i: Int): java.sql.Date = getAs[java.sql.Date](i)
/**
* Returns the value at position i of date type as java.sql.Timestamp.
*
* @throws ClassCastException when data type does not match.
*/
def getTimestamp(i: Int): java.sql.Timestamp = apply(i).asInstanceOf[java.sql.Timestamp]
def getTimestamp(i: Int): java.sql.Timestamp = getAs[java.sql.Timestamp](i)
/**
* Returns the value at position i of array type as a Scala Seq.
*
* @throws ClassCastException when data type does not match.
*/
def getSeq[T](i: Int): Seq[T] = apply(i).asInstanceOf[Seq[T]]
def getSeq[T](i: Int): Seq[T] = getAs[Seq[T]](i)
/**
* Returns the value at position i of array type as [[java.util.List]].
@ -288,7 +288,7 @@ trait Row extends Serializable {
*
* @throws ClassCastException when data type does not match.
*/
def getMap[K, V](i: Int): scala.collection.Map[K, V] = apply(i).asInstanceOf[Map[K, V]]
def getMap[K, V](i: Int): scala.collection.Map[K, V] = getAs[Map[K, V]](i)
/**
* Returns the value at position i of array type as a [[java.util.Map]].
@ -366,9 +366,18 @@ trait Row extends Serializable {
/* ---------------------- utility methods for Scala ---------------------- */
/**
* Return a Scala Seq representing the row. ELements are placed in the same order in the Seq.
* Return a Scala Seq representing the row. Elements are placed in the same order in the Seq.
*/
def toSeq: Seq[Any]
def toSeq: Seq[Any] = {
val n = length
val values = new Array[Any](n)
var i = 0
while (i < n) {
values.update(i, get(i))
i += 1
}
values.toSeq
}
/** Displays all elements of this sequence in a string (without a separator). */
def mkString: String = toSeq.mkString

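A small usage sketch of the new Row defaults above, assuming a plain row built via the Row companion (values are made up):

import org.apache.spark.sql.Row

// The typed getters now fall back to getAs/apply, and isNullAt to a null check.
val r = Row("spark", 8610)
assert(r.getString(0) == "spark")
assert(r.getInt(1) == 8610)
assert(!r.isNullAt(0))
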
View file

@ -242,7 +242,7 @@ object CatalystTypeConverters {
ar(idx) = converters(idx).toCatalyst(row(idx))
idx += 1
}
new GenericRowWithSchema(ar, structType)
new GenericInternalRow(ar)
case p: Product =>
val ar = new Array[Any](structType.size)
@ -252,7 +252,7 @@ object CatalystTypeConverters {
ar(idx) = converters(idx).toCatalyst(iter.next())
idx += 1
}
new GenericRowWithSchema(ar, structType)
new GenericInternalRow(ar)
}
override def toScala(row: InternalRow): Row = {

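A sketch of the converter change above: converting an external Row to Catalyst now yields a GenericInternalRow rather than a GenericRowWithSchema. CatalystTypeConverters is an internal API, so this assumes code living inside the org.apache.spark.sql package; the schema and values are made up.

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val schema = StructType(StructField("name", StringType) :: StructField("age", IntegerType) :: Nil)
// The struct converter handles both Row and Product inputs (see the two cases above).
val toCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)
val internal = toCatalyst(Row("cats", 3)).asInstanceOf[InternalRow]
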
View file

@ -19,14 +19,38 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.unsafe.types.UTF8String
/**
* An abstract class for row used internal in Spark SQL, which only contain the columns as
* internal types.
*/
abstract class InternalRow extends Row {
// This is only use for test
override def getString(i: Int): String = getAs[UTF8String](i).toString
// These expensive API should not be used internally.
final override def getDecimal(i: Int): java.math.BigDecimal =
throw new UnsupportedOperationException
final override def getDate(i: Int): java.sql.Date =
throw new UnsupportedOperationException
final override def getTimestamp(i: Int): java.sql.Timestamp =
throw new UnsupportedOperationException
final override def getSeq[T](i: Int): Seq[T] = throw new UnsupportedOperationException
final override def getList[T](i: Int): java.util.List[T] = throw new UnsupportedOperationException
final override def getMap[K, V](i: Int): scala.collection.Map[K, V] =
throw new UnsupportedOperationException
final override def getJavaMap[K, V](i: Int): java.util.Map[K, V] =
throw new UnsupportedOperationException
final override def getStruct(i: Int): Row = throw new UnsupportedOperationException
final override def getAs[T](fieldName: String): T = throw new UnsupportedOperationException
final override def getValuesMap[T](fieldNames: Seq[String]): Map[String, T] =
throw new UnsupportedOperationException
// A default implementation to change the return type
override def copy(): InternalRow = this
override def apply(i: Int): Any = get(i)
override def equals(o: Any): Boolean = {
if (!o.isInstanceOf[Row]) {
@ -93,27 +117,15 @@ abstract class InternalRow extends Row {
}
object InternalRow {
def unapplySeq(row: InternalRow): Some[Seq[Any]] = Some(row.toSeq)
/**
* This method can be used to construct a [[Row]] with the given values.
*/
def apply(values: Any*): InternalRow = new GenericRow(values.toArray)
def apply(values: Any*): InternalRow = new GenericInternalRow(values.toArray)
/**
* This method can be used to construct a [[Row]] from a [[Seq]] of values.
*/
def fromSeq(values: Seq[Any]): InternalRow = new GenericRow(values.toArray)
def fromTuple(tuple: Product): InternalRow = fromSeq(tuple.productIterator.toSeq)
/**
* Merge multiple rows into a single row, one after another.
*/
def merge(rows: InternalRow*): InternalRow = {
// TODO: Improve the performance of this if used in performance critical part.
new GenericRow(rows.flatMap(_.toSeq).toArray)
}
def fromSeq(values: Seq[Any]): InternalRow = new GenericInternalRow(values.toArray)
/** Returns an empty row. */
val empty = apply()

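To illustrate the InternalRow changes above with a minimal sketch (values are made up): construction goes through GenericInternalRow, getString remains as a test helper that converts the stored UTF8String, and the expensive external accessors are now final and throw.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.unsafe.types.UTF8String

val row = InternalRow(UTF8String.fromString("cats"), 3)  // backed by GenericInternalRow
assert(row.getString(0) == "cats")  // converts the stored UTF8String back to String (test-only)
assert(row.getInt(1) == 3)
// row.getDecimal(0) would throw UnsupportedOperationException, as marked final above.
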
View file

@ -36,7 +36,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
outputArray(i) = exprArray(i).eval(input)
i += 1
}
new GenericRow(outputArray)
new GenericInternalRow(outputArray)
}
override def toString: String = s"Row => [${exprArray.mkString(",")}]"
@ -135,12 +135,6 @@ class JoinedRow extends InternalRow {
override def getFloat(i: Int): Float =
if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
override def getString(i: Int): String =
if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
override def copy(): InternalRow = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
@ -149,7 +143,7 @@ class JoinedRow extends InternalRow {
copiedValues(i) = apply(i)
i += 1
}
new GenericRow(copiedValues)
new GenericInternalRow(copiedValues)
}
override def toString: String = {
@ -235,12 +229,6 @@ class JoinedRow2 extends InternalRow {
override def getFloat(i: Int): Float =
if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
override def getString(i: Int): String =
if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
override def copy(): InternalRow = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
@ -249,7 +237,7 @@ class JoinedRow2 extends InternalRow {
copiedValues(i) = apply(i)
i += 1
}
new GenericRow(copiedValues)
new GenericInternalRow(copiedValues)
}
override def toString: String = {
@ -329,12 +317,6 @@ class JoinedRow3 extends InternalRow {
override def getFloat(i: Int): Float =
if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
override def getString(i: Int): String =
if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
override def copy(): InternalRow = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
@ -343,7 +325,7 @@ class JoinedRow3 extends InternalRow {
copiedValues(i) = apply(i)
i += 1
}
new GenericRow(copiedValues)
new GenericInternalRow(copiedValues)
}
override def toString: String = {
@ -423,12 +405,6 @@ class JoinedRow4 extends InternalRow {
override def getFloat(i: Int): Float =
if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
override def getString(i: Int): String =
if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
override def copy(): InternalRow = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
@ -437,7 +413,7 @@ class JoinedRow4 extends InternalRow {
copiedValues(i) = apply(i)
i += 1
}
new GenericRow(copiedValues)
new GenericInternalRow(copiedValues)
}
override def toString: String = {
@ -517,12 +493,6 @@ class JoinedRow5 extends InternalRow {
override def getFloat(i: Int): Float =
if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
override def getString(i: Int): String =
if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
override def copy(): InternalRow = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
@ -531,7 +501,7 @@ class JoinedRow5 extends InternalRow {
copiedValues(i) = apply(i)
i += 1
}
new GenericRow(copiedValues)
new GenericInternalRow(copiedValues)
}
override def toString: String = {
@ -611,12 +581,6 @@ class JoinedRow6 extends InternalRow {
override def getFloat(i: Int): Float =
if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
override def getString(i: Int): String =
if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
override def copy(): InternalRow = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
@ -625,7 +589,7 @@ class JoinedRow6 extends InternalRow {
copiedValues(i) = apply(i)
i += 1
}
new GenericRow(copiedValues)
new GenericInternalRow(copiedValues)
}
override def toString: String = {

View file

@ -230,7 +230,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
i += 1
}
new GenericRow(newValues)
new GenericInternalRow(newValues)
}
override def update(ordinal: Int, value: Any) {

View file

@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions._
// MutableProjection is not accessible in Java
abstract class BaseMutableProjection extends MutableProjection {}
abstract class BaseMutableProjection extends MutableProjection
/**
* Generates byte code that produces a [[MutableRow]] object that can update itself based on a new

View file

@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.BaseMutableRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
@ -149,6 +148,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
"""
}.mkString("\n")
val copyColumns = expressions.zipWithIndex.map { case (e, i) =>
s"""arr[$i] = c$i;"""
}.mkString("\n ")
val code = s"""
public SpecificProjection generate($exprType[] expr) {
return new SpecificProjection(expr);
@ -167,7 +170,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}
}
final class SpecificRow extends ${typeOf[BaseMutableRow]} {
final class SpecificRow extends ${typeOf[MutableRow]} {
$columns
@ -175,7 +178,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
$initColumns
}
public int size() { return ${expressions.length};}
public int length() { return ${expressions.length};}
protected boolean[] nullBits = new boolean[${expressions.length}];
public void setNullAt(int i) { nullBits[i] = true; }
public boolean isNullAt(int i) { return nullBits[i]; }
@ -216,6 +219,13 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}
return super.equals(other);
}
@Override
public InternalRow copy() {
Object[] arr = new Object[${expressions.length}];
${copyColumns}
return new ${typeOf[GenericInternalRow]}(arr);
}
}
"""

View file

@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.Map
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees}
import org.apache.spark.sql.types._
@ -68,19 +69,19 @@ abstract class Generator extends Expression {
*/
case class UserDefinedGenerator(
elementTypes: Seq[(DataType, Boolean)],
function: InternalRow => TraversableOnce[InternalRow],
function: Row => TraversableOnce[InternalRow],
children: Seq[Expression])
extends Generator {
@transient private[this] var inputRow: InterpretedProjection = _
@transient private[this] var convertToScala: (InternalRow) => InternalRow = _
@transient private[this] var convertToScala: (InternalRow) => Row = _
private def initializeConverters(): Unit = {
inputRow = new InterpretedProjection(children)
convertToScala = {
val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true)))
CatalystTypeConverters.createToScalaConverter(inputSchema)
}.asInstanceOf[(InternalRow => InternalRow)]
}.asInstanceOf[InternalRow => Row]
}
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
@ -118,10 +119,11 @@ case class Explode(child: Expression)
child.dataType match {
case ArrayType(_, _) =>
val inputArray = child.eval(input).asInstanceOf[Seq[Any]]
if (inputArray == null) Nil else inputArray.map(v => new GenericRow(Array(v)))
if (inputArray == null) Nil else inputArray.map(v => InternalRow(v))
case MapType(_, _, _) =>
val inputMap = child.eval(input).asInstanceOf[Map[Any, Any]]
if (inputMap == null) Nil else inputMap.map { case (k, v) => new GenericRow(Array(k, v)) }
if (inputMap == null) Nil
else inputMap.map { case (k, v) => InternalRow(k, v) }
}
}

View file

@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DataType, StructType, AtomicType}
import org.apache.spark.unsafe.types.UTF8String
@ -24,19 +25,32 @@ import org.apache.spark.unsafe.types.UTF8String
* An extended interface to [[InternalRow]] that allows the values for each column to be updated.
* Setting a value through a primitive function implicitly marks that column as not null.
*/
trait MutableRow extends InternalRow {
abstract class MutableRow extends InternalRow {
def setNullAt(i: Int): Unit
def update(ordinal: Int, value: Any)
def update(i: Int, value: Any)
def setInt(ordinal: Int, value: Int)
def setLong(ordinal: Int, value: Long)
def setDouble(ordinal: Int, value: Double)
def setBoolean(ordinal: Int, value: Boolean)
def setShort(ordinal: Int, value: Short)
def setByte(ordinal: Int, value: Byte)
def setFloat(ordinal: Int, value: Float)
def setString(ordinal: Int, value: String)
// default implementation (slow)
def setInt(i: Int, value: Int): Unit = { update(i, value) }
def setLong(i: Int, value: Long): Unit = { update(i, value) }
def setDouble(i: Int, value: Double): Unit = { update(i, value) }
def setBoolean(i: Int, value: Boolean): Unit = { update(i, value) }
def setShort(i: Int, value: Short): Unit = { update(i, value) }
def setByte(i: Int, value: Byte): Unit = { update(i, value) }
def setFloat(i: Int, value: Float): Unit = { update(i, value) }
def setString(i: Int, value: String): Unit = {
update(i, UTF8String.fromString(value))
}
override def copy(): InternalRow = {
val arr = new Array[Any](length)
var i = 0
while (i < length) {
arr(i) = get(i)
i += 1
}
new GenericInternalRow(arr)
}
}
/**
@ -59,69 +73,58 @@ object EmptyRow extends InternalRow {
override def copy(): InternalRow = this
}
/**
* A row implementation that uses an array of objects as the underlying storage.
*/
trait ArrayBackedRow {
self: Row =>
protected val values: Array[Any]
override def toSeq: Seq[Any] = values.toSeq
def length: Int = values.length
override def apply(i: Int): Any = values(i)
def setNullAt(i: Int): Unit = { values(i) = null}
def update(i: Int, value: Any): Unit = { values(i) = value }
}
/**
* A row implementation that uses an array of objects as the underlying storage. Note that, while
* the array is not copied, and thus could technically be mutated after creation, this is not
* allowed.
*/
class GenericRow(protected[sql] val values: Array[Any]) extends InternalRow {
class GenericRow(protected[sql] val values: Array[Any]) extends Row with ArrayBackedRow {
/** No-arg constructor for serialization. */
protected def this() = this(null)
def this(size: Int) = this(new Array[Any](size))
override def toSeq: Seq[Any] = values.toSeq
override def length: Int = values.length
override def apply(i: Int): Any = values(i)
override def isNullAt(i: Int): Boolean = values(i) == null
override def getInt(i: Int): Int = {
if (values(i) == null) sys.error("Failed to check null bit for primitive int value.")
values(i).asInstanceOf[Int]
// This is used by test or outside
override def equals(o: Any): Boolean = o match {
case other: Row if other.length == length =>
var i = 0
while (i < length) {
if (isNullAt(i) != other.isNullAt(i)) {
return false
}
val equal = (apply(i), other.apply(i)) match {
case (a: Array[Byte], b: Array[Byte]) => java.util.Arrays.equals(a, b)
case (a, b) => a == b
}
if (!equal) {
return false
}
i += 1
}
true
case _ => false
}
override def getLong(i: Int): Long = {
if (values(i) == null) sys.error("Failed to check null bit for primitive long value.")
values(i).asInstanceOf[Long]
}
override def getDouble(i: Int): Double = {
if (values(i) == null) sys.error("Failed to check null bit for primitive double value.")
values(i).asInstanceOf[Double]
}
override def getFloat(i: Int): Float = {
if (values(i) == null) sys.error("Failed to check null bit for primitive float value.")
values(i).asInstanceOf[Float]
}
override def getBoolean(i: Int): Boolean = {
if (values(i) == null) sys.error("Failed to check null bit for primitive boolean value.")
values(i).asInstanceOf[Boolean]
}
override def getShort(i: Int): Short = {
if (values(i) == null) sys.error("Failed to check null bit for primitive short value.")
values(i).asInstanceOf[Short]
}
override def getByte(i: Int): Byte = {
if (values(i) == null) sys.error("Failed to check null bit for primitive byte value.")
values(i).asInstanceOf[Byte]
}
override def getString(i: Int): String = {
values(i) match {
case null => null
case s: String => s
case utf8: UTF8String => utf8.toString
}
}
override def copy(): InternalRow = this
override def copy(): Row = this
}
class GenericRowWithSchema(values: Array[Any], override val schema: StructType)
@ -133,31 +136,29 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType)
override def fieldIndex(name: String): Int = schema.fieldIndex(name)
}
class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow {
/**
* A internal row implementation that uses an array of objects as the underlying storage.
* Note that, while the array is not copied, and thus could technically be mutated after creation,
* this is not allowed.
*/
class GenericInternalRow(protected[sql] val values: Array[Any])
extends InternalRow with ArrayBackedRow {
/** No-arg constructor for serialization. */
protected def this() = this(null)
def this(size: Int) = this(new Array[Any](size))
override def setBoolean(ordinal: Int, value: Boolean): Unit = { values(ordinal) = value }
override def setByte(ordinal: Int, value: Byte): Unit = { values(ordinal) = value }
override def setDouble(ordinal: Int, value: Double): Unit = { values(ordinal) = value }
override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value }
override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value }
override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value }
override def setString(ordinal: Int, value: String): Unit = {
values(ordinal) = UTF8String.fromString(value)
}
override def setNullAt(i: Int): Unit = { values(i) = null }
override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value }
override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value }
override def copy(): InternalRow = new GenericRow(values.clone())
override def copy(): InternalRow = this
}
class GenericMutableRow(val values: Array[Any]) extends MutableRow with ArrayBackedRow {
/** No-arg constructor for serialization. */
protected def this() = this(null)
def this(size: Int) = this(new Array[Any](size))
override def copy(): InternalRow = new GenericInternalRow(values.clone())
}
class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =

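A minimal sketch of the new MutableRow defaults shown above (values are made up): the primitive setters delegate to update(), and setString wraps the value in a UTF8String before storing it.

import org.apache.spark.sql.catalyst.expressions.GenericMutableRow

val m = new GenericMutableRow(2)
m.setInt(0, 42)          // default implementation delegates to update(0, 42)
m.setString(1, "spark")  // stores UTF8String.fromString("spark")
assert(m.getInt(0) == 42 && m.getString(1) == "spark")
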
View file

@ -33,7 +33,7 @@ trait ExpressionEvalHelper {
self: SparkFunSuite =>
protected def create_row(values: Any*): InternalRow = {
new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray)
InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst))
}
protected def checkEvaluation(
@ -122,7 +122,7 @@ trait ExpressionEvalHelper {
}
val actual = plan(inputRow)
val expectedRow = new GenericRow(Array[Any](expected))
val expectedRow = InternalRow(expected)
if (actual.hashCode() != expectedRow.hashCode()) {
fail(
s"""

View file

@ -37,7 +37,7 @@ class UnsafeFixedWidthAggregationMapSuite
private val groupKeySchema = StructType(StructField("product", StringType) :: Nil)
private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil)
private def emptyAggregationBuffer: InternalRow = new GenericRow(Array[Any](0))
private def emptyAggregationBuffer: InternalRow = InternalRow(0)
private var memoryManager: TaskMemoryManager = null
@ -84,7 +84,7 @@ class UnsafeFixedWidthAggregationMapSuite
1024, // initial capacity
false // disable perf metrics
)
val groupKey = new GenericRow(Array[Any](UTF8String.fromString("cats")))
val groupKey = InternalRow(UTF8String.fromString("cats"))
// Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts)
map.getAggregationBuffer(groupKey)
@ -113,7 +113,7 @@ class UnsafeFixedWidthAggregationMapSuite
val rand = new Random(42)
val groupKeys: Set[String] = Seq.fill(512)(rand.nextString(1024)).toSet
groupKeys.foreach { keyString =>
map.getAggregationBuffer(new GenericRow(Array[Any](UTF8String.fromString(keyString))))
map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString)))
}
val seenKeys: Set[String] = map.iterator().asScala.map { entry =>
entry.key.getString(0)

View file

@ -42,6 +42,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _}
import org.apache.spark.sql.execution.{Filter, _}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
/**
@ -377,10 +378,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
val row = new SpecificMutableRow(dataType :: Nil)
iter.map { v =>
row.setInt(0, v)
row: Row
row: InternalRow
}
}
DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
DataFrameHolder(
self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
}
/**
@ -393,10 +395,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
val row = new SpecificMutableRow(dataType :: Nil)
iter.map { v =>
row.setLong(0, v)
row: Row
row: InternalRow
}
}
DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
DataFrameHolder(
self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
}
/**
@ -408,11 +411,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
val rows = data.mapPartitions { iter =>
val row = new SpecificMutableRow(dataType :: Nil)
iter.map { v =>
row.setString(0, v)
row: Row
row.update(0, UTF8String.fromString(v))
row: InternalRow
}
}
DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
DataFrameHolder(
self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
}
}
@ -559,9 +563,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
(e, CatalystTypeConverters.createToCatalystConverter(attr.dataType))
}
iter.map { row =>
new GenericRow(
new GenericInternalRow(
methodsToConverts.map { case (e, convert) => convert(e.invoke(row)) }.toArray[Any]
) : InternalRow
): InternalRow
}
}
DataFrame(this, LogicalRDD(attributeSeq, rowRdd)(this))
@ -1065,7 +1069,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
}
val rowRdd = convertedRdd.mapPartitions { iter =>
iter.map { m => new GenericRow(m): InternalRow}
iter.map { m => new GenericInternalRow(m): InternalRow}
}
DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self))

View file

@ -21,7 +21,7 @@ import java.nio.ByteBuffer
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.MutableRow
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.types._
@ -63,7 +63,7 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType](
* Appends `row(ordinal)` of type T into the given ByteBuffer. Subclasses should override this
* method to avoid boxing/unboxing costs whenever possible.
*/
def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = {
def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
append(getField(row, ordinal), buffer)
}
@ -71,13 +71,13 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType](
* Returns the size of the value `row(ordinal)`. This is used to calculate the size of variable
* length types such as byte arrays and strings.
*/
def actualSize(row: Row, ordinal: Int): Int = defaultSize
def actualSize(row: InternalRow, ordinal: Int): Int = defaultSize
/**
* Returns `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs
* whenever possible.
*/
def getField(row: Row, ordinal: Int): JvmType
def getField(row: InternalRow, ordinal: Int): JvmType
/**
* Sets `row(ordinal)` to `field`. Subclasses should override this method to avoid boxing/unboxing
@ -89,7 +89,7 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType](
* Copies `from(fromOrdinal)` to `to(toOrdinal)`. Subclasses should override this method to avoid
* boxing/unboxing costs whenever possible.
*/
def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
to(toOrdinal) = from(fromOrdinal)
}
@ -118,7 +118,7 @@ private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) {
buffer.putInt(v)
}
override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = {
override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
buffer.putInt(row.getInt(ordinal))
}
@ -134,9 +134,9 @@ private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) {
row.setInt(ordinal, value)
}
override def getField(row: Row, ordinal: Int): Int = row.getInt(ordinal)
override def getField(row: InternalRow, ordinal: Int): Int = row.getInt(ordinal)
override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
to.setInt(toOrdinal, from.getInt(fromOrdinal))
}
}
@ -146,7 +146,7 @@ private[sql] object LONG extends NativeColumnType(LongType, 1, 8) {
buffer.putLong(v)
}
override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = {
override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
buffer.putLong(row.getLong(ordinal))
}
@ -162,9 +162,9 @@ private[sql] object LONG extends NativeColumnType(LongType, 1, 8) {
row.setLong(ordinal, value)
}
override def getField(row: Row, ordinal: Int): Long = row.getLong(ordinal)
override def getField(row: InternalRow, ordinal: Int): Long = row.getLong(ordinal)
override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
to.setLong(toOrdinal, from.getLong(fromOrdinal))
}
}
@ -174,7 +174,7 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) {
buffer.putFloat(v)
}
override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = {
override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
buffer.putFloat(row.getFloat(ordinal))
}
@ -190,9 +190,9 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) {
row.setFloat(ordinal, value)
}
override def getField(row: Row, ordinal: Int): Float = row.getFloat(ordinal)
override def getField(row: InternalRow, ordinal: Int): Float = row.getFloat(ordinal)
override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
to.setFloat(toOrdinal, from.getFloat(fromOrdinal))
}
}
@ -202,7 +202,7 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) {
buffer.putDouble(v)
}
override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = {
override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
buffer.putDouble(row.getDouble(ordinal))
}
@ -218,9 +218,9 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) {
row.setDouble(ordinal, value)
}
override def getField(row: Row, ordinal: Int): Double = row.getDouble(ordinal)
override def getField(row: InternalRow, ordinal: Int): Double = row.getDouble(ordinal)
override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
to.setDouble(toOrdinal, from.getDouble(fromOrdinal))
}
}
@ -230,7 +230,7 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) {
buffer.put(if (v) 1: Byte else 0: Byte)
}
override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = {
override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
buffer.put(if (row.getBoolean(ordinal)) 1: Byte else 0: Byte)
}
@ -244,9 +244,9 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) {
row.setBoolean(ordinal, value)
}
override def getField(row: Row, ordinal: Int): Boolean = row.getBoolean(ordinal)
override def getField(row: InternalRow, ordinal: Int): Boolean = row.getBoolean(ordinal)
override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
to.setBoolean(toOrdinal, from.getBoolean(fromOrdinal))
}
}
@ -256,7 +256,7 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) {
buffer.put(v)
}
override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = {
override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
buffer.put(row.getByte(ordinal))
}
@ -272,9 +272,9 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) {
row.setByte(ordinal, value)
}
override def getField(row: Row, ordinal: Int): Byte = row.getByte(ordinal)
override def getField(row: InternalRow, ordinal: Int): Byte = row.getByte(ordinal)
override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
to.setByte(toOrdinal, from.getByte(fromOrdinal))
}
}
@ -284,7 +284,7 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) {
buffer.putShort(v)
}
override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = {
override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
buffer.putShort(row.getShort(ordinal))
}
@ -300,15 +300,15 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) {
row.setShort(ordinal, value)
}
override def getField(row: Row, ordinal: Int): Short = row.getShort(ordinal)
override def getField(row: InternalRow, ordinal: Int): Short = row.getShort(ordinal)
override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
to.setShort(toOrdinal, from.getShort(fromOrdinal))
}
}
private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
override def actualSize(row: Row, ordinal: Int): Int = {
override def actualSize(row: InternalRow, ordinal: Int): Int = {
row.getString(ordinal).getBytes("utf-8").length + 4
}
@ -328,11 +328,11 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
row.update(ordinal, value)
}
override def getField(row: Row, ordinal: Int): UTF8String = {
override def getField(row: InternalRow, ordinal: Int): UTF8String = {
row(ordinal).asInstanceOf[UTF8String]
}
override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
to.update(toOrdinal, from(fromOrdinal))
}
}
@ -346,7 +346,7 @@ private[sql] object DATE extends NativeColumnType(DateType, 8, 4) {
buffer.putInt(v)
}
override def getField(row: Row, ordinal: Int): Int = {
override def getField(row: InternalRow, ordinal: Int): Int = {
row(ordinal).asInstanceOf[Int]
}
@ -364,7 +364,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 8) {
buffer.putLong(v)
}
override def getField(row: Row, ordinal: Int): Long = {
override def getField(row: InternalRow, ordinal: Int): Long = {
row(ordinal).asInstanceOf[Long]
}
@ -387,7 +387,7 @@ private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int)
buffer.putLong(v.toUnscaledLong)
}
override def getField(row: Row, ordinal: Int): Decimal = {
override def getField(row: InternalRow, ordinal: Int): Decimal = {
row(ordinal).asInstanceOf[Decimal]
}
@ -405,7 +405,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
defaultSize: Int)
extends ColumnType[T, Array[Byte]](typeId, defaultSize) {
override def actualSize(row: Row, ordinal: Int): Int = {
override def actualSize(row: InternalRow, ordinal: Int): Int = {
getField(row, ordinal).length + 4
}
@ -426,7 +426,7 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](11, 16)
row(ordinal) = value
}
override def getField(row: Row, ordinal: Int): Array[Byte] = {
override def getField(row: InternalRow, ordinal: Int): Array[Byte] = {
row(ordinal).asInstanceOf[Array[Byte]]
}
}
@ -439,7 +439,7 @@ private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) {
row(ordinal) = SparkSqlSerializer.deserialize[Any](value)
}
override def getField(row: Row, ordinal: Int): Array[Byte] = {
override def getField(row: InternalRow, ordinal: Int): Array[Byte] = {
SparkSqlSerializer.serialize(row(ordinal))
}
}

View file

@ -146,7 +146,8 @@ private[sql] case class InMemoryRelation(
rowCount += 1
}
val stats = InternalRow.merge(columnBuilders.map(_.columnStats.collectedStatistics) : _*)
val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics)
.flatMap(_.toSeq))
batchStats += stats
CachedBatch(columnBuilders.map(_.build().array()), stats)

View file

@ -20,22 +20,20 @@ package org.apache.spark.sql.execution
import java.nio.ByteBuffer
import java.util.{HashMap => JavaHashMap}
import org.apache.spark.sql.types.Decimal
import scala.reflect.ClassTag
import com.clearspring.analytics.stream.cardinality.HyperLogLog
import com.esotericsoftware.kryo.io.{Input, Output}
import com.esotericsoftware.kryo.{Serializer, Kryo}
import com.esotericsoftware.kryo.{Kryo, Serializer}
import com.twitter.chill.ResourcePool
import org.apache.spark.{SparkEnv, SparkConf}
import org.apache.spark.serializer.{SerializerInstance, KryoSerializer}
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.util.collection.OpenHashSet
import org.apache.spark.util.MutablePair
import org.apache.spark.serializer.{KryoSerializer, SerializerInstance}
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHashSet}
import org.apache.spark.sql.types.Decimal
import org.apache.spark.util.MutablePair
import org.apache.spark.util.collection.OpenHashSet
import org.apache.spark.{SparkConf, SparkEnv}
private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
override def newKryo(): Kryo = {
@ -43,6 +41,7 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
kryo.setRegistrationRequired(false)
kryo.register(classOf[MutablePair[_, _]])
kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericInternalRow])
kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog],
new HyperLogLogSerializer)
@ -139,7 +138,7 @@ private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] {
val iterator = hs.iterator
while(iterator.hasNext) {
val row = iterator.next()
rowSerializer.write(kryo, output, row.asInstanceOf[GenericRow].values)
rowSerializer.write(kryo, output, row.asInstanceOf[GenericInternalRow].values)
}
}
@ -150,7 +149,7 @@ private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] {
var i = 0
while (i < numItems) {
val row =
new GenericRow(rowSerializer.read(
new GenericInternalRow(rowSerializer.read(
kryo,
input,
classOf[Array[Any]].asInstanceOf[Class[Any]]).asInstanceOf[Array[Any]])

View file

@ -26,7 +26,8 @@ import scala.reflect.ClassTag
import org.apache.spark.Logging
import org.apache.spark.serializer._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, MutableRow, SpecificMutableRow}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@ -329,7 +330,7 @@ private[sql] object SparkSqlSerializer2 {
*/
def createDeserializationFunction(
schema: Array[DataType],
in: DataInputStream): (MutableRow) => Row = {
in: DataInputStream): (MutableRow) => InternalRow = {
if (schema == null) {
(mutableRow: MutableRow) => null
} else {

View file

@ -210,8 +210,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}
protected lazy val singleRowRdd =
sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): InternalRow), 1)
protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1)
object TakeOrderedAndProject extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {

View file

@ -71,8 +71,8 @@ case class HashOuterJoin(
@transient private[this] lazy val DUMMY_LIST = Seq[InternalRow](null)
@transient private[this] lazy val EMPTY_LIST = Seq.empty[InternalRow]
@transient private[this] lazy val leftNullRow = new GenericRow(left.output.length)
@transient private[this] lazy val rightNullRow = new GenericRow(right.output.length)
@transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
@transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
@transient private[this] lazy val boundCondition =
condition.map(
newPredicate(_, left.output ++ right.output)).getOrElse((row: InternalRow) => true)

View file

@ -183,9 +183,9 @@ object EvaluatePython {
}.toMap
case (c, StructType(fields)) if c.getClass.isArray =>
new GenericRow(c.asInstanceOf[Array[_]].zip(fields).map {
new GenericInternalRow(c.asInstanceOf[Array[_]].zip(fields).map {
case (e, f) => fromJava(e, f.dataType)
}): Row
})
case (c: java.util.Calendar, DateType) =>
DateTimeUtils.fromJavaDate(new java.sql.Date(c.getTimeInMillis))

View file

@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
private[sql] object StatFunctions extends Logging {
@ -123,7 +124,7 @@ private[sql] object StatFunctions extends Logging {
countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2))
}
// the value of col1 is the first value, the rest are the counts
countsRow.setString(0, col1Item.toString)
countsRow.update(0, UTF8String.fromString(col1Item.toString))
countsRow
}.toSeq
val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq

View file

@ -417,7 +417,7 @@ private[sql] class JDBCRDD(
case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos))
case LongConversion => mutableRow.setLong(i, rs.getLong(pos))
// TODO(davies): use getBytes for better performance, if the encoding is UTF-8
case StringConversion => mutableRow.setString(i, rs.getString(pos))
case StringConversion => mutableRow.update(i, UTF8String.fromString(rs.getString(pos)))
case TimestampConversion =>
val t = rs.getTimestamp(pos)
if (t != null) {

View file

@ -318,7 +318,7 @@ private[parquet] class CatalystGroupConverter(
// Note: this will ever only be called in the root converter when the record has been
// fully processed. Therefore it will be difficult to use mutable rows instead, since
// any non-root converter never would be sure when it would be safe to re-use the buffer.
new GenericRow(current.toArray)
new GenericInternalRow(current.toArray)
}
override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex)
@ -342,8 +342,8 @@ private[parquet] class CatalystGroupConverter(
override def end(): Unit = {
if (!isRootConverter) {
assert(current != null) // there should be no empty groups
buffer.append(new GenericRow(current.toArray))
parent.updateField(index, new GenericRow(buffer.toArray.asInstanceOf[Array[Any]]))
buffer.append(new GenericInternalRow(current.toArray))
parent.updateField(index, new GenericInternalRow(buffer.toArray.asInstanceOf[Array[Any]]))
}
}
}
@ -788,7 +788,7 @@ private[parquet] class CatalystStructConverter(
// here we need to make sure to use StructScalaType
// Note: we need to actually make a copy of the array since we
// may be in a nested field
parent.updateField(index, new GenericRow(current.toArray))
parent.updateField(index, new GenericInternalRow(current.toArray))
}
}

View file

@ -44,7 +44,7 @@ private[sql] case class InsertIntoDataSource(
overwrite: Boolean)
extends RunnableCommand {
override def run(sqlContext: SQLContext): Seq[InternalRow] = {
override def run(sqlContext: SQLContext): Seq[Row] = {
val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
val data = DataFrame(sqlContext, query)
// Apply the schema of the existing table to the new data.
@ -54,7 +54,7 @@ private[sql] case class InsertIntoDataSource(
// Invalidate the cache.
sqlContext.cacheManager.invalidateCache(logicalRelation)
Seq.empty[InternalRow]
Seq.empty[Row]
}
}
@ -86,7 +86,7 @@ private[sql] case class InsertIntoHadoopFsRelation(
mode: SaveMode)
extends RunnableCommand {
override def run(sqlContext: SQLContext): Seq[InternalRow] = {
override def run(sqlContext: SQLContext): Seq[Row] = {
require(
relation.paths.length == 1,
s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}")

View file

@ -20,7 +20,6 @@ package org.apache.spark.sql
import java.sql.{Date, Timestamp}
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
case class ReflectData(
stringField: String,
@ -128,16 +127,16 @@ class ScalaReflectionRelationSuite extends SparkFunSuite {
Seq(data).toDF().registerTempTable("reflectComplexData")
assert(ctx.sql("SELECT * FROM reflectComplexData").collect().head ===
new GenericRow(Array[Any](
Row(
Seq(1, 2, 3),
Seq(1, 2, null),
Map(1 -> 10L, 2 -> 20L),
Map(1 -> 10L, 2 -> 20L, 3 -> null),
new GenericRow(Array[Any](
Row(
Seq(10, 20, 30),
Seq(10, 20, null),
Map(10 -> 100L, 20 -> 200L),
Map(10 -> 100L, 20 -> 200L, 30 -> null),
new GenericRow(Array[Any](null, "abc")))))))
Row(null, "abc"))))
}
}

View file

@ -62,7 +62,7 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo
override def buildScan(): RDD[Row] = {
sqlContext.sparkContext.parallelize(from to to).map { e =>
InternalRow(UTF8String.fromString(s"people$e"), e * 2)
InternalRow(UTF8String.fromString(s"people$e"), e * 2): Row
}
}
}

View file

@ -90,8 +90,8 @@ case class AllDataTypesScan(
Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))),
Map(i -> UTF8String.fromString(i.toString)),
Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)),
Row(i, UTF8String.fromString(i.toString)),
Row(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")),
InternalRow(i, UTF8String.fromString(i.toString)),
InternalRow(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")),
InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1))))))
}
}


@ -336,9 +336,8 @@ private[hive] trait HiveInspectors {
// currently, hive doesn't provide the ConstantStructObjectInspector
case si: StructObjectInspector =>
val allRefs = si.getAllStructFieldRefs
new GenericRow(
allRefs.map(r =>
unwrap(si.getStructFieldData(data, r), r.getFieldObjectInspector)).toArray)
InternalRow.fromSeq(
allRefs.map(r => unwrap(si.getStructFieldData(data, r), r.getFieldObjectInspector)))
}


@ -34,6 +34,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{SerializableConfiguration, Utils}
/**
@ -356,7 +357,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging {
(value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value))
case oi: HiveVarcharObjectInspector =>
(value: Any, row: MutableRow, ordinal: Int) =>
row.setString(ordinal, oi.getPrimitiveJavaObject(value).getValue)
row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue))
case oi: HiveDecimalObjectInspector =>
(value: Any, row: MutableRow, ordinal: Int) =>
row.update(ordinal, HiveShim.toCatalystDecimal(oi, value))
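The varchar branch now stores a UTF8String via update rather than calling setString with a java.lang.String, keeping the mutable row in Catalyst's internal representation. A small helper in the same spirit (setVarcharField is an invented name; it mirrors the (value, row, ordinal) shape of the unwrappers above):

import org.apache.spark.sql.catalyst.expressions.MutableRow
import org.apache.spark.unsafe.types.UTF8String

object HiveStringUnwrapSketch {
  // Hypothetical helper: place a Hive-provided Java string into a reused MutableRow.
  def setVarcharField(row: MutableRow, ordinal: Int, javaString: String): Unit = {
    if (javaString == null) {
      row.setNullAt(ordinal)
    } else {
      // Internal rows hold UTF8String, not java.lang.String.
      row.update(ordinal, UTF8String.fromString(javaString))
    }
  }
}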


@ -17,13 +17,11 @@
package org.apache.spark.sql.hive.execution
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.catalyst.expressions.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan}
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.hive.client.{HiveTable, HiveColumn}
import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation, HiveMetastoreTypes}
import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable}
import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes, MetastoreRelation}
import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
/**
* Create table and insert the query result into it.
@ -42,11 +40,11 @@ case class CreateTableAsSelect(
def database: String = tableDesc.database
def tableName: String = tableDesc.name
override def run(sqlContext: SQLContext): Seq[InternalRow] = {
override def run(sqlContext: SQLContext): Seq[Row] = {
val hiveContext = sqlContext.asInstanceOf[HiveContext]
lazy val metastoreRelation: MetastoreRelation = {
import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat
import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
import org.apache.hadoop.io.Text
import org.apache.hadoop.mapred.TextInputFormat
@ -89,7 +87,7 @@ case class CreateTableAsSelect(
hiveContext.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true, false)).toRdd
}
Seq.empty[InternalRow]
Seq.empty[Row]
}
override def argString: String = {


@ -21,10 +21,10 @@ import scala.collection.JavaConversions._
import org.apache.hadoop.hive.metastore.api.FieldSchema
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.{Attribute, InternalRow}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.hive.MetastoreRelation
import org.apache.spark.sql.{Row, SQLContext}
/**
* Implementation for "describe [extended] table".
@ -35,7 +35,7 @@ case class DescribeHiveTableCommand(
override val output: Seq[Attribute],
isExtended: Boolean) extends RunnableCommand {
override def run(sqlContext: SQLContext): Seq[InternalRow] = {
override def run(sqlContext: SQLContext): Seq[Row] = {
// Trying to mimic the format of Hive's output. But not exactly the same.
var results: Seq[(String, String, String)] = Nil
@ -57,7 +57,7 @@ case class DescribeHiveTableCommand(
}
results.map { case (name, dataType, comment) =>
InternalRow(name, dataType, comment)
Row(name, dataType, comment)
}
}
}


@ -17,11 +17,11 @@
package org.apache.spark.sql.hive.execution
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, InternalRow}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.{Row, SQLContext}
private[hive]
case class HiveNativeCommand(sql: String) extends RunnableCommand {
@ -29,6 +29,6 @@ case class HiveNativeCommand(sql: String) extends RunnableCommand {
override def output: Seq[AttributeReference] =
Seq(AttributeReference("result", StringType, nullable = false)())
override def run(sqlContext: SQLContext): Seq[InternalRow] =
sqlContext.asInstanceOf[HiveContext].runSqlHive(sql).map(InternalRow(_))
override def run(sqlContext: SQLContext): Seq[Row] =
sqlContext.asInstanceOf[HiveContext].runSqlHive(sql).map(Row(_))
}


@ -123,7 +123,7 @@ case class HiveTableScan(
// Only partitioned values are needed here, since the predicate has already been bound to
// partition key attribute references.
val row = new GenericRow(castedValues.toArray)
val row = InternalRow.fromSeq(castedValues)
shouldKeep.eval(row).asInstanceOf[Boolean]
}
}
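Partition pruning builds an InternalRow from the casted partition values and evaluates the already-bound predicate against it. A stripped-down sketch of that evaluation step, with a made-up predicate over field 0 (catalyst expression classes as of this commit):

import org.apache.spark.sql.catalyst.expressions.{BoundReference, EqualTo, InternalRow, Literal}
import org.apache.spark.sql.types.IntegerType

object PartitionPruningSketch {
  def main(args: Array[String]): Unit = {
    // Predicate already bound to ordinal 0: `partitionCol = 2015`.
    val shouldKeep = EqualTo(BoundReference(0, IntegerType, nullable = true), Literal(2015))

    val row = InternalRow.fromSeq(Seq(2015))
    println(shouldKeep.eval(row).asInstanceOf[Boolean]) // expected: true
  }
}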


@ -129,11 +129,11 @@ case class ScriptTransformation(
val prevLine = curLine
curLine = reader.readLine()
if (!ioschema.schemaLess) {
new GenericRow(CatalystTypeConverters.convertToCatalyst(
new GenericInternalRow(CatalystTypeConverters.convertToCatalyst(
prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")))
.asInstanceOf[Array[Any]])
} else {
new GenericRow(CatalystTypeConverters.convertToCatalyst(
new GenericInternalRow(CatalystTypeConverters.convertToCatalyst(
prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2))
.asInstanceOf[Array[Any]])
}
@ -167,7 +167,8 @@ case class ScriptTransformation(
outputStream.write(data)
} else {
val writable = inputSerde.serialize(row.asInstanceOf[GenericRow].values, inputSoi)
val writable = inputSerde.serialize(
row.asInstanceOf[GenericInternalRow].values, inputSoi)
prepareWritable(writable).write(dataOutputStream)
}
}
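On the output side of the script transformation, the split fields are first converted to Catalyst values and then wrapped in a GenericInternalRow. A minimal sketch of that conversion, mirroring the one-argument convertToCatalyst call used above (sample input invented; the package declaration is only there because the converters are sql-internal):

package org.apache.spark.sql.sketch

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow

object ScriptOutputRowSketch {
  def main(args: Array[String]): Unit = {
    val fields = "a\tb\t1".split("\t")

    // Convert the Scala values to their Catalyst counterparts, then wrap them as an internal row.
    val converted = CatalystTypeConverters.convertToCatalyst(fields).asInstanceOf[Array[Any]]
    val row = new GenericInternalRow(converted)

    println(row)
  }
}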


@ -17,15 +17,14 @@
package org.apache.spark.sql.hive.execution
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.expressions.{Attribute, InternalRow}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@ -39,9 +38,9 @@ import org.apache.spark.util.Utils
private[hive]
case class AnalyzeTable(tableName: String) extends RunnableCommand {
override def run(sqlContext: SQLContext): Seq[InternalRow] = {
override def run(sqlContext: SQLContext): Seq[Row] = {
sqlContext.asInstanceOf[HiveContext].analyze(tableName)
Seq.empty[InternalRow]
Seq.empty[Row]
}
}
@ -53,7 +52,7 @@ case class DropTable(
tableName: String,
ifExists: Boolean) extends RunnableCommand {
override def run(sqlContext: SQLContext): Seq[InternalRow] = {
override def run(sqlContext: SQLContext): Seq[Row] = {
val hiveContext = sqlContext.asInstanceOf[HiveContext]
val ifExistsClause = if (ifExists) "IF EXISTS " else ""
try {
@ -70,7 +69,7 @@ case class DropTable(
hiveContext.invalidateTable(tableName)
hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName")
hiveContext.catalog.unregisterTable(Seq(tableName))
Seq.empty[InternalRow]
Seq.empty[Row]
}
}
@ -83,7 +82,7 @@ case class AddJar(path: String) extends RunnableCommand {
schema.toAttributes
}
override def run(sqlContext: SQLContext): Seq[InternalRow] = {
override def run(sqlContext: SQLContext): Seq[Row] = {
val hiveContext = sqlContext.asInstanceOf[HiveContext]
val currentClassLoader = Utils.getContextOrSparkClassLoader
@ -105,18 +104,18 @@ case class AddJar(path: String) extends RunnableCommand {
// Add jar to executors
hiveContext.sparkContext.addJar(path)
Seq(InternalRow(0))
Seq(Row(0))
}
}
private[hive]
case class AddFile(path: String) extends RunnableCommand {
override def run(sqlContext: SQLContext): Seq[InternalRow] = {
override def run(sqlContext: SQLContext): Seq[Row] = {
val hiveContext = sqlContext.asInstanceOf[HiveContext]
hiveContext.runSqlHive(s"ADD FILE $path")
hiveContext.sparkContext.addFile(path)
Seq.empty[InternalRow]
Seq.empty[Row]
}
}
@ -129,12 +128,12 @@ case class CreateMetastoreDataSource(
allowExisting: Boolean,
managedIfNoPath: Boolean) extends RunnableCommand {
override def run(sqlContext: SQLContext): Seq[InternalRow] = {
override def run(sqlContext: SQLContext): Seq[Row] = {
val hiveContext = sqlContext.asInstanceOf[HiveContext]
if (hiveContext.catalog.tableExists(tableName :: Nil)) {
if (allowExisting) {
return Seq.empty[InternalRow]
return Seq.empty[Row]
} else {
throw new AnalysisException(s"Table $tableName already exists.")
}
@ -157,7 +156,7 @@ case class CreateMetastoreDataSource(
optionsWithPath,
isExternal)
Seq.empty[InternalRow]
Seq.empty[Row]
}
}
@ -170,7 +169,7 @@ case class CreateMetastoreDataSourceAsSelect(
options: Map[String, String],
query: LogicalPlan) extends RunnableCommand {
override def run(sqlContext: SQLContext): Seq[InternalRow] = {
override def run(sqlContext: SQLContext): Seq[Row] = {
val hiveContext = sqlContext.asInstanceOf[HiveContext]
var createMetastoreTable = false
var isExternal = true
@ -194,7 +193,7 @@ case class CreateMetastoreDataSourceAsSelect(
s"Or, if you are using SQL CREATE TABLE, you need to drop $tableName first.")
case SaveMode.Ignore =>
// Since the table already exists and the save mode is Ignore, we will just return.
return Seq.empty[InternalRow]
return Seq.empty[Row]
case SaveMode.Append =>
// Check if the specified data source match the data source of the existing table.
val resolved = ResolvedDataSource(
@ -259,6 +258,6 @@ case class CreateMetastoreDataSourceAsSelect(
// Refresh the cache of the table in the catalog.
hiveContext.refreshTable(tableName)
Seq.empty[InternalRow]
Seq.empty[Row]
}
}


@ -190,7 +190,7 @@ private[sql] class OrcRelation(
filters: Array[Filter],
inputPaths: Array[FileStatus]): RDD[Row] = {
val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes
OrcTableScan(output, this, filters, inputPaths).execute()
OrcTableScan(output, this, filters, inputPaths).execute().map(_.asInstanceOf[Row])
}
override def prepareJobForWrite(job: Job): OutputWriterFactory = {
@ -234,13 +234,13 @@ private[orc] case class OrcTableScan(
HiveShim.appendReadColumns(conf, sortedIds, sortedNames)
}
// Transform all given raw `Writable`s into `Row`s.
// Transform all given raw `Writable`s into `InternalRow`s.
private def fillObject(
path: String,
conf: Configuration,
iterator: Iterator[Writable],
nonPartitionKeyAttrs: Seq[(Attribute, Int)],
mutableRow: MutableRow): Iterator[Row] = {
mutableRow: MutableRow): Iterator[InternalRow] = {
val deserializer = new OrcSerde
val soi = OrcFileOperator.getObjectInspector(path, Some(conf))
val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map {
@ -261,11 +261,11 @@ private[orc] case class OrcTableScan(
}
i += 1
}
mutableRow: Row
mutableRow: InternalRow
}
}
def execute(): RDD[Row] = {
def execute(): RDD[InternalRow] = {
val job = new Job(sqlContext.sparkContext.hadoopConfiguration)
val conf = job.getConfiguration


@ -202,9 +202,9 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors {
val dt = StructType(dataTypes.zipWithIndex.map {
case (t, idx) => StructField(s"c_$idx", t)
})
val inspector = toInspector(dt)
checkValues(row,
unwrap(wrap(Row.fromSeq(row), toInspector(dt)), toInspector(dt)).asInstanceOf[InternalRow])
unwrap(wrap(InternalRow.fromSeq(row), inspector), inspector).asInstanceOf[InternalRow])
checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt)))
}