[SPARK-3572] [SQL] Internal API for User-Defined Types

This PR adds User-Defined Types (UDTs) to SQL. It is a precursor to using SchemaRDD as a Dataset for the new MLlib API. Currently, the UDT API is private since there is incomplete support (e.g., no Java or Python support yet).

Author: Joseph K. Bradley <joseph@databricks.com>
Author: Michael Armbrust <michael@databricks.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #3063 from marmbrus/udts and squashes the following commits:

7ccfc0d [Michael Armbrust] remove println
46a3aee [Michael Armbrust] Slightly easier to read test output.
6cc434d [Michael Armbrust] Recursively convert rows.
e369b91 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into udts
15c10a6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into sql-udt2
f3c72fe [Joseph K. Bradley] Fixing merge
e13cd8a [Joseph K. Bradley] Removed Vector UDTs
5817b2b [Joseph K. Bradley] style edits
30ce5b2 [Joseph K. Bradley] updates based on code review
d063380 [Joseph K. Bradley] Cleaned up Java UDT Suite, and added warning about element ordering when creating schema from Java Bean
a571bb6 [Joseph K. Bradley] Removed old UDT code (registry and Java UDTs).  Cleaned up other code.  Extended JavaUserDefinedTypeSuite
6fddc1c [Joseph K. Bradley] Made MyLabeledPoint into a Java Bean
20630bc [Joseph K. Bradley] fixed scalastyle
fa86b20 [Joseph K. Bradley] Removed Java UserDefinedType, and made UDTs private[spark] for now
8de957c [Joseph K. Bradley] Modified UserDefinedType to store Java class of user type so that registerUDT takes only the udt argument.
8b242ea [Joseph K. Bradley] Fixed merge error after last merge.  Note: Last merge commit also removed SQL UDT examples from mllib.
7f29656 [Joseph K. Bradley] Moved udt case to top of all matches.  Small cleanups
b028675 [Xiangrui Meng] allow any type in UDT
4500d8a [Xiangrui Meng] update example code
87264a5 [Xiangrui Meng] remove debug code
3143ac3 [Xiangrui Meng] remove unnecessary changes
cfbc321 [Xiangrui Meng] support UDT in parquet
db16139 [Joseph K. Bradley] Added more doc for UserDefinedType.  Removed unused code in Suite
759af7a [Joseph K. Bradley] Added more doc to UserDefineType
63626a4 [Joseph K. Bradley] Updated ScalaReflectionsSuite per @marmbrus suggestions
51e5282 [Joseph K. Bradley] fixed 1 test
f025035 [Joseph K. Bradley] Cleanups before PR.  Added new tests
85872f6 [Michael Armbrust] Allow schema calculation to be lazy, but ensure its available on executors.
dff99d6 [Joseph K. Bradley] Added UDTs for Vectors in MLlib, plus DatasetExample using the UDTs
cd60cb4 [Joseph K. Bradley] Trying to get other SQL tests to run
34a5831 [Joseph K. Bradley] Added MLlib dependency on SQL.
e1f7b9c [Joseph K. Bradley] blah
2f40c02 [Joseph K. Bradley] renamed UDT types
3579035 [Joseph K. Bradley] udt annotation now working
b226b9e [Joseph K. Bradley] Changing UDT to annotation
fea04af [Joseph K. Bradley] more cleanups
964b32e [Joseph K. Bradley] some cleanups
893ee4c [Joseph K. Bradley] udt finallly working
50f9726 [Joseph K. Bradley] udts
04303c9 [Joseph K. Bradley] udts
39f8707 [Joseph K. Bradley] removed old udt suite
273ac96 [Joseph K. Bradley] basic UDT is working, but deserialization has yet to be done
8bebf24 [Joseph K. Bradley] commented out convertRowToScala for debugging
53de70f [Joseph K. Bradley] more udts...
982c035 [Joseph K. Bradley] still working on UDTs
19b2f60 [Joseph K. Bradley] still working on UDTs
0eaeb81 [Joseph K. Bradley] Still working on UDTs
105c5a3 [Joseph K. Bradley] Adding UserDefinedType to SQL, not done yet.
This commit is contained in:
Joseph K. Bradley 2014-11-02 17:55:55 -08:00 committed by Michael Armbrust
parent 2ebd1df3f1
commit ebd6480587
24 changed files with 621 additions and 147 deletions

View file

@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst
import java.sql.{Date, Timestamp}
import org.apache.spark.util.Utils
import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.types._
@ -35,25 +37,46 @@ object ScalaReflection {
case class Schema(dataType: DataType, nullable: Boolean)
/** Converts Scala objects to catalyst rows / types */
def convertToCatalyst(a: Any): Any = a match {
case o: Option[_] => o.map(convertToCatalyst).orNull
case s: Seq[_] => s.map(convertToCatalyst)
case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
case d: BigDecimal => Decimal(d)
case other => other
/**
* Converts Scala objects to catalyst rows / types.
* Note: This is always called after schemaFor has been called.
* This ordering is important for UDT registration.
*/
def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match {
// Check UDT first since UDTs can override other types
case (obj, udt: UserDefinedType[_]) => udt.serialize(obj)
case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull
case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType))
case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType)
}
case (p: Product, structType: StructType) =>
new GenericRow(
p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) =>
convertToCatalyst(elem, field.dataType)
}.toArray)
case (d: BigDecimal, _) => Decimal(d)
case (other, _) => other
}
/** Converts Catalyst types used internally in rows to standard Scala types */
def convertToScala(a: Any): Any = a match {
case s: Seq[_] => s.map(convertToScala)
case m: Map[_, _] => m.map { case (k, v) => convertToScala(k) -> convertToScala(v) }
case d: Decimal => d.toBigDecimal
case other => other
def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match {
// Check UDT first since UDTs can override other types
case (d, udt: UserDefinedType[_]) => udt.deserialize(d)
case (s: Seq[_], arrayType: ArrayType) => s.map(convertToScala(_, arrayType.elementType))
case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType)
}
case (r: Row, s: StructType) => convertRowToScala(r, s)
case (d: Decimal, _: DecimalType) => d.toBigDecimal
case (other, _) => other
}
def convertRowToScala(r: Row): Row = new GenericRow(r.toArray.map(convertToScala))
def convertRowToScala(r: Row, schema: StructType): Row = {
new GenericRow(
r.zip(schema.fields.map(_.dataType))
.map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray)
}
/** Returns a Sequence of attributes for the given case class type. */
def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
@ -65,52 +88,64 @@ object ScalaReflection {
def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T])
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor(tpe: `Type`): Schema = tpe match {
case t if t <:< typeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
Schema(schemaFor(optType).dataType, nullable = true)
case t if t <:< typeOf[Product] =>
val formalTypeArgs = t.typeSymbol.asClass.typeParams
val TypeRef(_, _, actualTypeArgs) = t
val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
Schema(StructType(
params.head.map { p =>
val Schema(dataType, nullable) =
schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
StructField(p.name.toString, dataType, nullable)
}), nullable = true)
// Need to decide if we actually need a special type here.
case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
case t if t <:< typeOf[Array[_]] =>
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< typeOf[Map[_,_]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
Schema(MapType(schemaFor(keyType).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
def schemaFor(tpe: `Type`): Schema = {
val className: String = tpe.erasure.typeSymbol.asClass.fullName
tpe match {
case t if Utils.classIsLoadable(className) &&
Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
// Note: We check for classIsLoadable above since Utils.classForName uses Java reflection,
// whereas className is from Scala reflection. This can make it hard to find classes
// in some cases, such as when a class is enclosed in an object (in which case
// Java appends a '$' to the object name but Scala does not).
val udt = Utils.classForName(className)
.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
Schema(udt, nullable = true)
case t if t <:< typeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
Schema(schemaFor(optType).dataType, nullable = true)
case t if t <:< typeOf[Product] =>
val formalTypeArgs = t.typeSymbol.asClass.typeParams
val TypeRef(_, _, actualTypeArgs) = t
val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
Schema(StructType(
params.head.map { p =>
val Schema(dataType, nullable) =
schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
StructField(p.name.toString, dataType, nullable)
}), nullable = true)
// Need to decide if we actually need a special type here.
case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
case t if t <:< typeOf[Array[_]] =>
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< typeOf[Map[_, _]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
Schema(MapType(schemaFor(keyType).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
}
}
def typeOfObject: PartialFunction[Any, DataType] = {

View file

@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.annotation;
import java.lang.annotation.*;
import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.sql.catalyst.types.UserDefinedType;
/**
* ::DeveloperApi::
* A user-defined type which can be automatically recognized by a SQLContext and registered.
*
* WARNING: This annotation will only work if both Java and Scala reflection return the same class
* names (after erasure) for the UDT. This will NOT be the case when, e.g., the UDT class
* is enclosed in an object (a singleton).
*
* WARNING: UDTs are currently only supported from Scala.
*/
// TODO: Should I used @Documented ?
@DeveloperApi
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface SQLUserDefinedType {
/**
* Returns an instance of the UserDefinedType which can serialize and deserialize the user
* class to and from Catalyst built-in types.
*/
Class<? extends UserDefinedType<?> > udt();
}

View file

@ -21,6 +21,10 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.types.DataType
import org.apache.spark.util.ClosureCleaner
/**
* User-defined function.
* @param dataType Return type of function.
*/
case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression])
extends Expression {
@ -347,6 +351,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
}
// scalastyle:on
ScalaReflection.convertToCatalyst(result)
ScalaReflection.convertToCatalyst(result, dataType)
}
}

View file

@ -29,11 +29,12 @@ import org.json4s.JsonAST.JValue
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.ScalaReflectionLock
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Row}
import org.apache.spark.sql.catalyst.types.decimal._
import org.apache.spark.sql.catalyst.util.Metadata
import org.apache.spark.util.Utils
import org.apache.spark.sql.catalyst.types.decimal._
object DataType {
def fromJson(json: String): DataType = parseDataType(parse(json))
@ -67,6 +68,11 @@ object DataType {
("fields", JArray(fields)),
("type", JString("struct"))) =>
StructType(fields.map(parseStructField))
case JSortedObject(
("class", JString(udtClass)),
("type", JString("udt"))) =>
Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]]
}
private def parseStructField(json: JValue): StructField = json match {
@ -342,6 +348,7 @@ object FractionalType {
case _ => false
}
}
abstract class FractionalType extends NumericType {
private[sql] val fractional: Fractional[JvmType]
private[sql] val asIntegral: Integral[JvmType]
@ -565,3 +572,45 @@ case class MapType(
("valueType" -> valueType.jsonValue) ~
("valueContainsNull" -> valueContainsNull)
}
/**
* ::DeveloperApi::
* The data type for User Defined Types (UDTs).
*
* This interface allows a user to make their own classes more interoperable with SparkSQL;
* e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create
* a SchemaRDD which has class X in the schema.
*
* For SparkSQL to recognize UDTs, the UDT must be annotated with
* [[org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType]].
*
* The conversion via `serialize` occurs when instantiating a `SchemaRDD` from another RDD.
* The conversion via `deserialize` occurs when reading from a `SchemaRDD`.
*/
@DeveloperApi
abstract class UserDefinedType[UserType] extends DataType with Serializable {
/** Underlying storage type for this UDT */
def sqlType: DataType
/**
* Convert the user type to a SQL datum
*
* TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst,
* where we need to convert Any to UserType.
*/
def serialize(obj: Any): Any
/** Convert a SQL datum to the user type */
def deserialize(datum: Any): UserType
override private[sql] def jsonValue: JValue = {
("type" -> "udt") ~
("class" -> this.getClass.getName)
}
/**
* Class object for the UserType
*/
def userClass: java.lang.Class[UserType]
}

View file

@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
import org.scalatest.FunSuite
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.catalyst.types._
case class PrimitiveData(
@ -239,13 +240,17 @@ class ScalaReflectionSuite extends FunSuite {
test("convert PrimitiveData to catalyst") {
val data = PrimitiveData(1, 1, 1, 1, 1, 1, true)
val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true)
assert(convertToCatalyst(data) === convertedData)
val dataType = schemaFor[PrimitiveData].dataType
assert(convertToCatalyst(data, dataType) === convertedData)
}
test("convert Option[Product] to catalyst") {
val primitiveData = PrimitiveData(1, 1, 1, 1, 1, 1, true)
val data = OptionalData(Some(1), Some(1), Some(1), Some(1), Some(1), Some(1), Some(true), Some(primitiveData))
val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true, convertToCatalyst(primitiveData))
assert(convertToCatalyst(data) === convertedData)
val data = OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true),
Some(primitiveData))
val dataType = schemaFor[OptionalData].dataType
val convertedData = Row(2, 2.toLong, 2.toDouble, 2.toFloat, 2.toShort, 2.toByte, true,
Row(1, 1, 1, 1, 1, 1, true))
assert(convertToCatalyst(data, dataType) === convertedData)
}
}

View file

@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.api.java;
import java.io.Serializable;
import org.apache.spark.annotation.DeveloperApi;
/**
* ::DeveloperApi::
* The data type representing User-Defined Types (UDTs).
* UDTs may use any other DataType for an underlying representation.
*/
@DeveloperApi
public abstract class UserDefinedType<UserType> extends DataType implements Serializable {
protected UserDefinedType() { }
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
UserDefinedType<UserType> that = (UserDefinedType<UserType>) o;
return this.sqlType().equals(that.sqlType());
}
/** Underlying storage type for this UDT */
public abstract DataType sqlType();
/** Convert the user type to a SQL datum */
public abstract Object serialize(Object obj);
/** Convert a SQL datum to the user type */
public abstract UserType deserialize(Object datum);
/** Class object for the UserType */
public abstract Class<UserType> userClass();
}

View file

@ -107,8 +107,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = {
SparkPlan.currentContext.set(self)
new SchemaRDD(this,
LogicalRDD(ScalaReflection.attributesFor[A], RDDConversions.productToRowRdd(rdd))(self))
val attributeSeq = ScalaReflection.attributesFor[A]
val schema = StructType.fromAttributes(attributeSeq)
val rowRDD = RDDConversions.productToRowRdd(rdd, schema)
new SchemaRDD(this, LogicalRDD(attributeSeq, rowRDD)(self))
}
implicit def baseRelationToSchemaRDD(baseRelation: BaseRelation): SchemaRDD = {

View file

@ -17,26 +17,24 @@
package org.apache.spark.sql
import java.util.{Map => JMap, List => JList}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.storage.StorageLevel
import java.util.{List => JList}
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import net.razorvine.pickle.Pickler
import org.apache.spark.{Dependency, OneToOneDependency, Partition, Partitioner, TaskContext}
import org.apache.spark.annotation.{AlphaComponent, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.api.java.JavaSchemaRDD
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.storage.StorageLevel
/**
* :: AlphaComponent ::
@ -114,18 +112,22 @@ class SchemaRDD(
// =========================================================================================
override def compute(split: Partition, context: TaskContext): Iterator[Row] =
firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala)
firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala(_, this.schema))
override def getPartitions: Array[Partition] = firstParent[Row].partitions
override protected def getDependencies: Seq[Dependency[_]] =
List(new OneToOneDependency(queryExecution.toRdd))
override protected def getDependencies: Seq[Dependency[_]] = {
schema // Force reification of the schema so it is available on executors.
/** Returns the schema of this SchemaRDD (represented by a [[StructType]]).
*
* @group schema
*/
def schema: StructType = queryExecution.analyzed.schema
List(new OneToOneDependency(queryExecution.toRdd))
}
/**
* Returns the schema of this SchemaRDD (represented by a [[StructType]]).
*
* @group schema
*/
lazy val schema: StructType = queryExecution.analyzed.schema
// =======================================================================
// Query DSL

View file

@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.LogicalRDD
* Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java)
*/
private[sql] trait SchemaRDDLike {
@transient val sqlContext: SQLContext
@transient def sqlContext: SQLContext
@transient val baseLogicalPlan: LogicalPlan
private[sql] def baseSchemaRDD: SchemaRDD

View file

@ -78,7 +78,7 @@ private[sql] trait UDFRegistration {
s"""
def registerFunction[T: TypeTag](name: String, func: Function$x[$types, T]): Unit = {
def builder(e: Seq[Expression]) =
ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
"""
@ -87,112 +87,112 @@ private[sql] trait UDFRegistration {
// scalastyle:off
def registerFunction[T: TypeTag](name: String, func: Function1[_, T]): Unit = {
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function2[_, _, T]): Unit = {
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function3[_, _, _, T]): Unit = {
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function4[_, _, _, _, T]): Unit = {
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function5[_, _, _, _, _, T]): Unit = {
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function6[_, _, _, _, _, _, T]): Unit = {
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function7[_, _, _, _, _, _, _, T]): Unit = {
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function8[_, _, _, _, _, _, _, _, T]): Unit = {
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function9[_, _, _, _, _, _, _, _, _, T]): Unit = {
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function10[_, _, _, _, _, _, _, _, _, _, T]): Unit = {
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
// scalastyle:on

View file

@ -23,13 +23,14 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.sql.{SQLContext, StructType => SStructType}
import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow}
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.json.JsonRDD
import org.apache.spark.sql.parquet.ParquetRelation
import org.apache.spark.sql.sources.{LogicalRelation, BaseRelation}
import org.apache.spark.sql.types.util.DataTypeConversions
import org.apache.spark.sql.{SQLContext, StructType => SStructType}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow}
import org.apache.spark.sql.parquet.ParquetRelation
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.types.util.DataTypeConversions.asScalaDataType
import org.apache.spark.util.Utils
@ -91,9 +92,12 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
/**
* Applies a schema to an RDD of Java Beans.
*
* WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
* SELECT * queries will return the columns in an undefined order.
*/
def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): JavaSchemaRDD = {
val schema = getSchema(beanClass)
val attributeSeq = getSchema(beanClass)
val className = beanClass.getName
val rowRdd = rdd.rdd.mapPartitions { iter =>
// BeanInfo is not serializable so we must rediscover it remotely for each partition.
@ -104,11 +108,13 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
iter.map { row =>
new GenericRow(
extractors.map(e => DataTypeConversions.convertJavaToCatalyst(e.invoke(row))).toArray[Any]
extractors.zip(attributeSeq).map { case (e, attr) =>
DataTypeConversions.convertJavaToCatalyst(e.invoke(row), attr.dataType)
}.toArray[Any]
): ScalaRow
}
}
new JavaSchemaRDD(sqlContext, LogicalRDD(schema, rowRdd)(sqlContext))
new JavaSchemaRDD(sqlContext, LogicalRDD(attributeSeq, rowRdd)(sqlContext))
}
/**
@ -195,14 +201,21 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
sqlContext.registerRDDAsTable(rdd.baseSchemaRDD, tableName)
}
/** Returns a Catalyst Schema for the given java bean class. */
/**
* Returns a Catalyst Schema for the given java bean class.
*/
protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = {
// TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
val beanInfo = Introspector.getBeanInfo(beanClass)
// Note: The ordering of elements may differ from when the schema is inferred in Scala.
// This is because beanInfo.getPropertyDescriptors gives no guarantees about
// element ordering.
val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
fields.map { property =>
val (dataType, nullable) = property.getPropertyType match {
case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
(c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
case c: Class[_] if c == classOf[java.lang.String] =>
(org.apache.spark.sql.StringType, true)
case c: Class[_] if c == java.lang.Short.TYPE =>

View file

@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.api.java
import org.apache.spark.sql.catalyst.types.{UserDefinedType => ScalaUserDefinedType}
import org.apache.spark.sql.{DataType => ScalaDataType}
import org.apache.spark.sql.types.util.DataTypeConversions
/**
* Scala wrapper for a Java UserDefinedType
*/
private[sql] class JavaToScalaUDTWrapper[UserType](val javaUDT: UserDefinedType[UserType])
extends ScalaUserDefinedType[UserType] with Serializable {
/** Underlying storage type for this UDT */
val sqlType: ScalaDataType = DataTypeConversions.asScalaDataType(javaUDT.sqlType())
/** Convert the user type to a SQL datum */
def serialize(obj: Any): Any = javaUDT.serialize(obj)
/** Convert a SQL datum to the user type */
def deserialize(datum: Any): UserType = javaUDT.deserialize(datum)
val userClass: java.lang.Class[UserType] = javaUDT.userClass()
}
/**
* Java wrapper for a Scala UserDefinedType
*/
private[sql] class ScalaToJavaUDTWrapper[UserType](val scalaUDT: ScalaUserDefinedType[UserType])
extends UserDefinedType[UserType] with Serializable {
/** Underlying storage type for this UDT */
val sqlType: DataType = DataTypeConversions.asJavaDataType(scalaUDT.sqlType)
/** Convert the user type to a SQL datum */
def serialize(obj: Any): java.lang.Object = scalaUDT.serialize(obj).asInstanceOf[java.lang.Object]
/** Convert a SQL datum to the user type */
def deserialize(datum: Any): UserType = scalaUDT.deserialize(datum)
val userClass: java.lang.Class[UserType] = scalaUDT.userClass
}
private[sql] object UDTWrappers {
def wrapAsScala(udtType: UserDefinedType[_]): ScalaUserDefinedType[_] = {
udtType match {
case t: ScalaToJavaUDTWrapper[_] => t.scalaUDT
case _ => new JavaToScalaUDTWrapper(udtType)
}
}
def wrapAsJava(udtType: ScalaUserDefinedType[_]): UserDefinedType[_] = {
udtType match {
case t: JavaToScalaUDTWrapper[_] => t.javaUDT
case _ => new ScalaToJavaUDTWrapper(udtType)
}
}
}

View file

@ -19,29 +19,32 @@ package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataType, StructType, Row, SQLContext}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.ScalaReflection.Schema
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.types.UserDefinedType
/**
* :: DeveloperApi ::
*/
@DeveloperApi
object RDDConversions {
def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = {
def productToRowRdd[A <: Product](data: RDD[A], schema: StructType): RDD[Row] = {
data.mapPartitions { iterator =>
if (iterator.isEmpty) {
Iterator.empty
} else {
val bufferedIterator = iterator.buffered
val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity)
val schemaFields = schema.fields.toArray
bufferedIterator.map { r =>
var i = 0
while (i < mutableRow.length) {
mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i))
mutableRow(i) =
ScalaReflection.convertToCatalyst(r.productElement(i), schemaFields(i).dataType)
i += 1
}

View file

@ -20,8 +20,6 @@ package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.{ScalaReflection, trees}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
@ -82,7 +80,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/**
* Runs this query returning the result as an array.
*/
def executeCollect(): Array[Row] = execute().map(ScalaReflection.convertRowToScala).collect()
def executeCollect(): Array[Row] =
execute().map(ScalaReflection.convertRowToScala(_, schema)).collect()
protected def newProjection(
expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {

View file

@ -280,7 +280,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
val nPartitions = if (data.isEmpty) 1 else numPartitions
PhysicalRDD(
output,
RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions))) :: Nil
RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions),
StructType.fromAttributes(output))) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
execution.Limit(limit, planLater(child)) :: Nil
case Unions(unionChildren) =>

View file

@ -143,7 +143,7 @@ case class Limit(limit: Int, child: SparkPlan)
partsScanned += numPartsToTry
}
buf.toArray.map(ScalaReflection.convertRowToScala)
buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema))
}
override def execute() = {
@ -179,8 +179,8 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
val ord = new RowOrdering(sortOrder, child.output)
// TODO: Is this copying for no reason?
override def executeCollect() =
child.execute().map(_.copy()).takeOrdered(limit)(ord).map(ScalaReflection.convertRowToScala)
override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ord)
.map(ScalaReflection.convertRowToScala(_, this.schema))
// TODO: Terminal split should be implemented differently from non-terminal split.
// TODO: Pick num splits based on |limit|.

View file

@ -77,6 +77,9 @@ private[sql] object CatalystConverter {
parent: CatalystConverter): Converter = {
val fieldType: DataType = field.dataType
fieldType match {
case udt: UserDefinedType[_] => {
createConverter(field.copy(dataType = udt.sqlType), fieldIndex, parent)
}
// For native JVM types we use a converter with native arrays
case ArrayType(elementType: NativeType, false) => {
new CatalystNativeArrayConverter(elementType, fieldIndex, parent)
@ -255,8 +258,8 @@ private[parquet] class CatalystGroupConverter(
schema,
index,
parent,
current=null,
buffer=new ArrayBuffer[Row](
current = null,
buffer = new ArrayBuffer[Row](
CatalystArrayConverter.INITIAL_ARRAY_SIZE))
/**
@ -301,7 +304,7 @@ private[parquet] class CatalystGroupConverter(
override def end(): Unit = {
if (!isRootConverter) {
assert(current!=null) // there should be no empty groups
assert(current != null) // there should be no empty groups
buffer.append(new GenericRow(current.toArray))
parent.updateField(index, new GenericRow(buffer.toArray.asInstanceOf[Array[Any]]))
}
@ -358,7 +361,7 @@ private[parquet] class CatalystPrimitiveRowConverter(
override def end(): Unit = {}
// Overriden here to avoid auto-boxing for primitive types
// Overridden here to avoid auto-boxing for primitive types
override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit =
current.setBoolean(fieldIndex, value)
@ -533,7 +536,7 @@ private[parquet] class CatalystNativeArrayConverter(
override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit =
throw new UnsupportedOperationException
// Overriden here to avoid auto-boxing for primitive types
// Overridden here to avoid auto-boxing for primitive types
override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = {
checkGrowBuffer()
buffer(elements) = value.asInstanceOf[NativeType]

View file

@ -20,7 +20,6 @@ package org.apache.spark.sql.parquet
import java.util.{HashMap => JHashMap}
import org.apache.hadoop.conf.Configuration
import org.apache.spark.sql.catalyst.types.decimal.Decimal
import parquet.column.ParquetProperties
import parquet.hadoop.ParquetOutputFormat
import parquet.hadoop.api.ReadSupport.ReadContext
@ -31,6 +30,7 @@ import parquet.schema.MessageType
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions.{Attribute, Row}
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.types.decimal.Decimal
/**
* A `parquet.io.api.RecordMaterializer` for Rows.
@ -174,6 +174,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
private[parquet] def writeValue(schema: DataType, value: Any): Unit = {
if (value != null) {
schema match {
case t: UserDefinedType[_] => writeValue(t.sqlType, value)
case t @ ArrayType(_, _) => writeArray(
t,
value.asInstanceOf[CatalystConverter.ArrayScalaType[_]])

View file

@ -290,6 +290,9 @@ private[parquet] object ParquetTypesConverter extends Logging {
builder.named(name)
}.getOrElse {
ctype match {
case udt: UserDefinedType[_] => {
fromDataType(udt.sqlType, name, nullable, inArray)
}
case ArrayType(elementType, false) => {
val parquetElementType = fromDataType(
elementType,

View file

@ -17,12 +17,16 @@
package org.apache.spark.sql.types.util
import scala.collection.JavaConverters._
import org.apache.spark.sql._
import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField, MetadataBuilder => JMetaDataBuilder}
import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField,
MetadataBuilder => JMetaDataBuilder, UDTWrappers, JavaToScalaUDTWrapper}
import org.apache.spark.sql.api.java.{DecimalType => JDecimalType}
import org.apache.spark.sql.catalyst.types.decimal.Decimal
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.types.UserDefinedType
import scala.collection.JavaConverters._
protected[sql] object DataTypeConversions {
@ -41,6 +45,9 @@ protected[sql] object DataTypeConversions {
* Returns the equivalent DataType in Java for the given DataType in Scala.
*/
def asJavaDataType(scalaDataType: DataType): JDataType = scalaDataType match {
case udtType: UserDefinedType[_] =>
UDTWrappers.wrapAsJava(udtType)
case StringType => JDataType.StringType
case BinaryType => JDataType.BinaryType
case BooleanType => JDataType.BooleanType
@ -80,6 +87,9 @@ protected[sql] object DataTypeConversions {
* Returns the equivalent DataType in Scala for the given DataType in Java.
*/
def asScalaDataType(javaDataType: JDataType): DataType = javaDataType match {
case udtType: org.apache.spark.sql.api.java.UserDefinedType[_] =>
UDTWrappers.wrapAsScala(udtType)
case stringType: org.apache.spark.sql.api.java.StringType =>
StringType
case binaryType: org.apache.spark.sql.api.java.BinaryType =>
@ -121,9 +131,11 @@ protected[sql] object DataTypeConversions {
}
/** Converts Java objects to catalyst rows / types */
def convertJavaToCatalyst(a: Any): Any = a match {
case d: java.math.BigDecimal => Decimal(BigDecimal(d))
case other => other
def convertJavaToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match {
case (obj, udt: UserDefinedType[_]) => ScalaReflection.convertToCatalyst(obj, udt) // Scala type
case (d: java.math.BigDecimal, _) => Decimal(BigDecimal(d))
case (d: java.math.BigDecimal, _) => BigDecimal(d)
case (other, _) => other
}
/** Converts Java objects to catalyst rows / types */

View file

@ -0,0 +1,88 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.api.java;
import java.io.Serializable;
import java.util.*;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.MyDenseVector;
import org.apache.spark.sql.MyLabeledPoint;
public class JavaUserDefinedTypeSuite implements Serializable {
private transient JavaSparkContext javaCtx;
private transient JavaSQLContext javaSqlCtx;
@Before
public void setUp() {
javaCtx = new JavaSparkContext("local", "JavaUserDefinedTypeSuite");
javaSqlCtx = new JavaSQLContext(javaCtx);
}
@After
public void tearDown() {
javaCtx.stop();
javaCtx = null;
javaSqlCtx = null;
}
@Test
public void useScalaUDT() {
List<MyLabeledPoint> points = Arrays.asList(
new MyLabeledPoint(1.0, new MyDenseVector(new double[]{0.1, 1.0})),
new MyLabeledPoint(0.0, new MyDenseVector(new double[]{0.2, 2.0})));
JavaRDD<MyLabeledPoint> pointsRDD = javaCtx.parallelize(points);
JavaSchemaRDD schemaRDD = javaSqlCtx.applySchema(pointsRDD, MyLabeledPoint.class);
schemaRDD.registerTempTable("points");
List<Row> actualLabelRows = javaSqlCtx.sql("SELECT label FROM points").collect();
List<Double> actualLabels = new LinkedList<Double>();
for (Row r : actualLabelRows) {
actualLabels.add(r.getDouble(0));
}
for (MyLabeledPoint lp : points) {
Assert.assertTrue(actualLabels.contains(lp.label()));
}
List<Row> actualFeatureRows = javaSqlCtx.sql("SELECT features FROM points").collect();
List<MyDenseVector> actualFeatures = new LinkedList<MyDenseVector>();
for (Row r : actualFeatureRows) {
actualFeatures.add((MyDenseVector)r.get(0));
}
for (MyLabeledPoint lp : points) {
Assert.assertTrue(actualFeatures.contains(lp.features()));
}
List<Row> actual = javaSqlCtx.sql("SELECT label, features FROM points").collect();
List<MyLabeledPoint> actualPoints =
new LinkedList<MyLabeledPoint>();
for (Row r : actual) {
actualPoints.add(new MyLabeledPoint(r.getDouble(0), (MyDenseVector)r.get(1)));
}
for (MyLabeledPoint lp : points) {
Assert.assertTrue(actualPoints.contains(lp));
}
}
}

View file

@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import scala.beans.{BeanInfo, BeanProperty}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
import org.apache.spark.sql.catalyst.types.UserDefinedType
import org.apache.spark.sql.test.TestSQLContext._
@SQLUserDefinedType(udt = classOf[MyDenseVectorUDT])
private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable {
override def equals(other: Any): Boolean = other match {
case v: MyDenseVector =>
java.util.Arrays.equals(this.data, v.data)
case _ => false
}
}
@BeanInfo
private[sql] case class MyLabeledPoint(
@BeanProperty label: Double,
@BeanProperty features: MyDenseVector)
private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
override def serialize(obj: Any): Seq[Double] = {
obj match {
case features: MyDenseVector =>
features.data.toSeq
}
}
override def deserialize(datum: Any): MyDenseVector = {
datum match {
case data: Seq[_] =>
new MyDenseVector(data.asInstanceOf[Seq[Double]].toArray)
}
}
override def userClass = classOf[MyDenseVector]
}
class UserDefinedTypeSuite extends QueryTest {
test("register user type: MyDenseVector for MyLabeledPoint") {
val points = Seq(
MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),
MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0))))
val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points)
val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v }
val labelsArrays: Array[Double] = labels.collect()
assert(labelsArrays.size === 2)
assert(labelsArrays.contains(1.0))
assert(labelsArrays.contains(0.0))
val features: RDD[MyDenseVector] =
pointsRDD.select('features).map { case Row(v: MyDenseVector) => v }
val featuresArrays: Array[MyDenseVector] = features.collect()
assert(featuresArrays.size === 2)
assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0))))
assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0))))
}
}

View file

@ -21,8 +21,7 @@ import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.types.decimal.Decimal
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType}
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.{Row, SQLConf, QueryTest}
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
@ -233,8 +232,8 @@ class JsonSuite extends QueryTest {
StructField("field2", StringType, true) ::
StructField("field3", StringType, true) :: Nil), false), true) ::
StructField("struct", StructType(
StructField("field1", BooleanType, true) ::
StructField("field2", DecimalType.Unlimited, true) :: Nil), true) ::
StructField("field1", BooleanType, true) ::
StructField("field2", DecimalType.Unlimited, true) :: Nil), true) ::
StructField("structWithArrayFields", StructType(
StructField("field1", ArrayType(IntegerType, false), true) ::
StructField("field2", ArrayType(StringType, false), true) :: Nil), true) :: Nil)
@ -292,8 +291,8 @@ class JsonSuite extends QueryTest {
// Access a struct and fields inside of it.
checkAnswer(
sql("select struct, struct.field1, struct.field2 from jsonTable"),
(
Seq(true, BigDecimal("92233720368547758070")),
Row(
Row(true, BigDecimal("92233720368547758070")),
true,
BigDecimal("92233720368547758070")) :: Nil
)

View file

@ -374,8 +374,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
/** Extends QueryExecution with hive specific features. */
protected[sql] abstract class QueryExecution extends super.QueryExecution {
override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy())
protected val primitiveTypes =
Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType,
ShortType, DateType, TimestampType, BinaryType)
@ -433,7 +431,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
command.executeCollect().map(_.head.toString)
case other =>
val result: Seq[Seq[Any]] = toRdd.collect().toSeq
val result: Seq[Seq[Any]] = toRdd.map(_.copy()).collect().toSeq
// We need the types so we can output struct field names
val types = analyzed.output.map(_.dataType)
// Reformat to match hive tab delimited output.