[SPARK-3572] [SQL] Internal API for User-Defined Types
This PR adds User-Defined Types (UDTs) to SQL. It is a precursor to using SchemaRDD as a Dataset for the new MLlib API. Currently, the UDT API is private since there is incomplete support (e.g., no Java or Python support yet). Author: Joseph K. Bradley <joseph@databricks.com> Author: Michael Armbrust <michael@databricks.com> Author: Xiangrui Meng <meng@databricks.com> Closes #3063 from marmbrus/udts and squashes the following commits: 7ccfc0d [Michael Armbrust] remove println 46a3aee [Michael Armbrust] Slightly easier to read test output. 6cc434d [Michael Armbrust] Recursively convert rows. e369b91 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into udts 15c10a6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into sql-udt2 f3c72fe [Joseph K. Bradley] Fixing merge e13cd8a [Joseph K. Bradley] Removed Vector UDTs 5817b2b [Joseph K. Bradley] style edits 30ce5b2 [Joseph K. Bradley] updates based on code review d063380 [Joseph K. Bradley] Cleaned up Java UDT Suite, and added warning about element ordering when creating schema from Java Bean a571bb6 [Joseph K. Bradley] Removed old UDT code (registry and Java UDTs). Cleaned up other code. Extended JavaUserDefinedTypeSuite 6fddc1c [Joseph K. Bradley] Made MyLabeledPoint into a Java Bean 20630bc [Joseph K. Bradley] fixed scalastyle fa86b20 [Joseph K. Bradley] Removed Java UserDefinedType, and made UDTs private[spark] for now 8de957c [Joseph K. Bradley] Modified UserDefinedType to store Java class of user type so that registerUDT takes only the udt argument. 8b242ea [Joseph K. Bradley] Fixed merge error after last merge. Note: Last merge commit also removed SQL UDT examples from mllib. 7f29656 [Joseph K. Bradley] Moved udt case to top of all matches. Small cleanups b028675 [Xiangrui Meng] allow any type in UDT 4500d8a [Xiangrui Meng] update example code 87264a5 [Xiangrui Meng] remove debug code 3143ac3 [Xiangrui Meng] remove unnecessary changes cfbc321 [Xiangrui Meng] support UDT in parquet db16139 [Joseph K. Bradley] Added more doc for UserDefinedType. Removed unused code in Suite 759af7a [Joseph K. Bradley] Added more doc to UserDefineType 63626a4 [Joseph K. Bradley] Updated ScalaReflectionsSuite per @marmbrus suggestions 51e5282 [Joseph K. Bradley] fixed 1 test f025035 [Joseph K. Bradley] Cleanups before PR. Added new tests 85872f6 [Michael Armbrust] Allow schema calculation to be lazy, but ensure its available on executors. dff99d6 [Joseph K. Bradley] Added UDTs for Vectors in MLlib, plus DatasetExample using the UDTs cd60cb4 [Joseph K. Bradley] Trying to get other SQL tests to run 34a5831 [Joseph K. Bradley] Added MLlib dependency on SQL. e1f7b9c [Joseph K. Bradley] blah 2f40c02 [Joseph K. Bradley] renamed UDT types 3579035 [Joseph K. Bradley] udt annotation now working b226b9e [Joseph K. Bradley] Changing UDT to annotation fea04af [Joseph K. Bradley] more cleanups 964b32e [Joseph K. Bradley] some cleanups 893ee4c [Joseph K. Bradley] udt finallly working 50f9726 [Joseph K. Bradley] udts 04303c9 [Joseph K. Bradley] udts 39f8707 [Joseph K. Bradley] removed old udt suite 273ac96 [Joseph K. Bradley] basic UDT is working, but deserialization has yet to be done 8bebf24 [Joseph K. Bradley] commented out convertRowToScala for debugging 53de70f [Joseph K. Bradley] more udts... 982c035 [Joseph K. Bradley] still working on UDTs 19b2f60 [Joseph K. Bradley] still working on UDTs 0eaeb81 [Joseph K. Bradley] Still working on UDTs 105c5a3 [Joseph K. Bradley] Adding UserDefinedType to SQL, not done yet.
This commit is contained in:
parent
2ebd1df3f1
commit
ebd6480587
|
@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst
|
|||
|
||||
import java.sql.{Date, Timestamp}
|
||||
|
||||
import org.apache.spark.util.Utils
|
||||
import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
|
||||
import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
|
||||
import org.apache.spark.sql.catalyst.types._
|
||||
|
@ -35,25 +37,46 @@ object ScalaReflection {
|
|||
|
||||
case class Schema(dataType: DataType, nullable: Boolean)
|
||||
|
||||
/** Converts Scala objects to catalyst rows / types */
|
||||
def convertToCatalyst(a: Any): Any = a match {
|
||||
case o: Option[_] => o.map(convertToCatalyst).orNull
|
||||
case s: Seq[_] => s.map(convertToCatalyst)
|
||||
case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
|
||||
case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
|
||||
case d: BigDecimal => Decimal(d)
|
||||
case other => other
|
||||
/**
|
||||
* Converts Scala objects to catalyst rows / types.
|
||||
* Note: This is always called after schemaFor has been called.
|
||||
* This ordering is important for UDT registration.
|
||||
*/
|
||||
def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match {
|
||||
// Check UDT first since UDTs can override other types
|
||||
case (obj, udt: UserDefinedType[_]) => udt.serialize(obj)
|
||||
case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull
|
||||
case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType))
|
||||
case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
|
||||
convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType)
|
||||
}
|
||||
case (p: Product, structType: StructType) =>
|
||||
new GenericRow(
|
||||
p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) =>
|
||||
convertToCatalyst(elem, field.dataType)
|
||||
}.toArray)
|
||||
case (d: BigDecimal, _) => Decimal(d)
|
||||
case (other, _) => other
|
||||
}
|
||||
|
||||
/** Converts Catalyst types used internally in rows to standard Scala types */
|
||||
def convertToScala(a: Any): Any = a match {
|
||||
case s: Seq[_] => s.map(convertToScala)
|
||||
case m: Map[_, _] => m.map { case (k, v) => convertToScala(k) -> convertToScala(v) }
|
||||
case d: Decimal => d.toBigDecimal
|
||||
case other => other
|
||||
def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match {
|
||||
// Check UDT first since UDTs can override other types
|
||||
case (d, udt: UserDefinedType[_]) => udt.deserialize(d)
|
||||
case (s: Seq[_], arrayType: ArrayType) => s.map(convertToScala(_, arrayType.elementType))
|
||||
case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
|
||||
convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType)
|
||||
}
|
||||
case (r: Row, s: StructType) => convertRowToScala(r, s)
|
||||
case (d: Decimal, _: DecimalType) => d.toBigDecimal
|
||||
case (other, _) => other
|
||||
}
|
||||
|
||||
def convertRowToScala(r: Row): Row = new GenericRow(r.toArray.map(convertToScala))
|
||||
def convertRowToScala(r: Row, schema: StructType): Row = {
|
||||
new GenericRow(
|
||||
r.zip(schema.fields.map(_.dataType))
|
||||
.map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray)
|
||||
}
|
||||
|
||||
/** Returns a Sequence of attributes for the given case class type. */
|
||||
def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
|
||||
|
@ -65,52 +88,64 @@ object ScalaReflection {
|
|||
def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T])
|
||||
|
||||
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
|
||||
def schemaFor(tpe: `Type`): Schema = tpe match {
|
||||
case t if t <:< typeOf[Option[_]] =>
|
||||
val TypeRef(_, _, Seq(optType)) = t
|
||||
Schema(schemaFor(optType).dataType, nullable = true)
|
||||
case t if t <:< typeOf[Product] =>
|
||||
val formalTypeArgs = t.typeSymbol.asClass.typeParams
|
||||
val TypeRef(_, _, actualTypeArgs) = t
|
||||
val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
|
||||
Schema(StructType(
|
||||
params.head.map { p =>
|
||||
val Schema(dataType, nullable) =
|
||||
schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
|
||||
StructField(p.name.toString, dataType, nullable)
|
||||
}), nullable = true)
|
||||
// Need to decide if we actually need a special type here.
|
||||
case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
|
||||
case t if t <:< typeOf[Array[_]] =>
|
||||
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
|
||||
case t if t <:< typeOf[Seq[_]] =>
|
||||
val TypeRef(_, _, Seq(elementType)) = t
|
||||
val Schema(dataType, nullable) = schemaFor(elementType)
|
||||
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
|
||||
case t if t <:< typeOf[Map[_,_]] =>
|
||||
val TypeRef(_, _, Seq(keyType, valueType)) = t
|
||||
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
|
||||
Schema(MapType(schemaFor(keyType).dataType,
|
||||
valueDataType, valueContainsNull = valueNullable), nullable = true)
|
||||
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
|
||||
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
|
||||
case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
|
||||
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
|
||||
case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
|
||||
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
|
||||
case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
|
||||
case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
|
||||
case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
|
||||
case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
|
||||
case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
|
||||
case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
|
||||
case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
|
||||
case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
|
||||
case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
|
||||
case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
|
||||
case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
|
||||
case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
|
||||
case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
|
||||
def schemaFor(tpe: `Type`): Schema = {
|
||||
val className: String = tpe.erasure.typeSymbol.asClass.fullName
|
||||
tpe match {
|
||||
case t if Utils.classIsLoadable(className) &&
|
||||
Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
|
||||
// Note: We check for classIsLoadable above since Utils.classForName uses Java reflection,
|
||||
// whereas className is from Scala reflection. This can make it hard to find classes
|
||||
// in some cases, such as when a class is enclosed in an object (in which case
|
||||
// Java appends a '$' to the object name but Scala does not).
|
||||
val udt = Utils.classForName(className)
|
||||
.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
|
||||
Schema(udt, nullable = true)
|
||||
case t if t <:< typeOf[Option[_]] =>
|
||||
val TypeRef(_, _, Seq(optType)) = t
|
||||
Schema(schemaFor(optType).dataType, nullable = true)
|
||||
case t if t <:< typeOf[Product] =>
|
||||
val formalTypeArgs = t.typeSymbol.asClass.typeParams
|
||||
val TypeRef(_, _, actualTypeArgs) = t
|
||||
val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
|
||||
Schema(StructType(
|
||||
params.head.map { p =>
|
||||
val Schema(dataType, nullable) =
|
||||
schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
|
||||
StructField(p.name.toString, dataType, nullable)
|
||||
}), nullable = true)
|
||||
// Need to decide if we actually need a special type here.
|
||||
case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
|
||||
case t if t <:< typeOf[Array[_]] =>
|
||||
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
|
||||
case t if t <:< typeOf[Seq[_]] =>
|
||||
val TypeRef(_, _, Seq(elementType)) = t
|
||||
val Schema(dataType, nullable) = schemaFor(elementType)
|
||||
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
|
||||
case t if t <:< typeOf[Map[_, _]] =>
|
||||
val TypeRef(_, _, Seq(keyType, valueType)) = t
|
||||
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
|
||||
Schema(MapType(schemaFor(keyType).dataType,
|
||||
valueDataType, valueContainsNull = valueNullable), nullable = true)
|
||||
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
|
||||
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
|
||||
case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
|
||||
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
|
||||
case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
|
||||
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
|
||||
case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
|
||||
case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
|
||||
case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
|
||||
case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
|
||||
case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
|
||||
case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
|
||||
case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
|
||||
case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
|
||||
case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
|
||||
case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
|
||||
case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
|
||||
case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
|
||||
case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
|
||||
}
|
||||
}
|
||||
|
||||
def typeOfObject: PartialFunction[Any, DataType] = {
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.annotation;
|
||||
|
||||
import java.lang.annotation.*;
|
||||
|
||||
import org.apache.spark.annotation.DeveloperApi;
|
||||
import org.apache.spark.sql.catalyst.types.UserDefinedType;
|
||||
|
||||
/**
|
||||
* ::DeveloperApi::
|
||||
* A user-defined type which can be automatically recognized by a SQLContext and registered.
|
||||
*
|
||||
* WARNING: This annotation will only work if both Java and Scala reflection return the same class
|
||||
* names (after erasure) for the UDT. This will NOT be the case when, e.g., the UDT class
|
||||
* is enclosed in an object (a singleton).
|
||||
*
|
||||
* WARNING: UDTs are currently only supported from Scala.
|
||||
*/
|
||||
// TODO: Should I used @Documented ?
|
||||
@DeveloperApi
|
||||
@Retention(RetentionPolicy.RUNTIME)
|
||||
@Target(ElementType.TYPE)
|
||||
public @interface SQLUserDefinedType {
|
||||
|
||||
/**
|
||||
* Returns an instance of the UserDefinedType which can serialize and deserialize the user
|
||||
* class to and from Catalyst built-in types.
|
||||
*/
|
||||
Class<? extends UserDefinedType<?> > udt();
|
||||
}
|
|
@ -21,6 +21,10 @@ import org.apache.spark.sql.catalyst.ScalaReflection
|
|||
import org.apache.spark.sql.catalyst.types.DataType
|
||||
import org.apache.spark.util.ClosureCleaner
|
||||
|
||||
/**
|
||||
* User-defined function.
|
||||
* @param dataType Return type of function.
|
||||
*/
|
||||
case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression])
|
||||
extends Expression {
|
||||
|
||||
|
@ -347,6 +351,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
|
|||
}
|
||||
// scalastyle:on
|
||||
|
||||
ScalaReflection.convertToCatalyst(result)
|
||||
ScalaReflection.convertToCatalyst(result, dataType)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,11 +29,12 @@ import org.json4s.JsonAST.JValue
|
|||
import org.json4s.JsonDSL._
|
||||
import org.json4s.jackson.JsonMethods._
|
||||
|
||||
import org.apache.spark.annotation.DeveloperApi
|
||||
import org.apache.spark.sql.catalyst.ScalaReflectionLock
|
||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
|
||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Row}
|
||||
import org.apache.spark.sql.catalyst.types.decimal._
|
||||
import org.apache.spark.sql.catalyst.util.Metadata
|
||||
import org.apache.spark.util.Utils
|
||||
import org.apache.spark.sql.catalyst.types.decimal._
|
||||
|
||||
object DataType {
|
||||
def fromJson(json: String): DataType = parseDataType(parse(json))
|
||||
|
@ -67,6 +68,11 @@ object DataType {
|
|||
("fields", JArray(fields)),
|
||||
("type", JString("struct"))) =>
|
||||
StructType(fields.map(parseStructField))
|
||||
|
||||
case JSortedObject(
|
||||
("class", JString(udtClass)),
|
||||
("type", JString("udt"))) =>
|
||||
Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]]
|
||||
}
|
||||
|
||||
private def parseStructField(json: JValue): StructField = json match {
|
||||
|
@ -342,6 +348,7 @@ object FractionalType {
|
|||
case _ => false
|
||||
}
|
||||
}
|
||||
|
||||
abstract class FractionalType extends NumericType {
|
||||
private[sql] val fractional: Fractional[JvmType]
|
||||
private[sql] val asIntegral: Integral[JvmType]
|
||||
|
@ -565,3 +572,45 @@ case class MapType(
|
|||
("valueType" -> valueType.jsonValue) ~
|
||||
("valueContainsNull" -> valueContainsNull)
|
||||
}
|
||||
|
||||
/**
|
||||
* ::DeveloperApi::
|
||||
* The data type for User Defined Types (UDTs).
|
||||
*
|
||||
* This interface allows a user to make their own classes more interoperable with SparkSQL;
|
||||
* e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create
|
||||
* a SchemaRDD which has class X in the schema.
|
||||
*
|
||||
* For SparkSQL to recognize UDTs, the UDT must be annotated with
|
||||
* [[org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType]].
|
||||
*
|
||||
* The conversion via `serialize` occurs when instantiating a `SchemaRDD` from another RDD.
|
||||
* The conversion via `deserialize` occurs when reading from a `SchemaRDD`.
|
||||
*/
|
||||
@DeveloperApi
|
||||
abstract class UserDefinedType[UserType] extends DataType with Serializable {
|
||||
|
||||
/** Underlying storage type for this UDT */
|
||||
def sqlType: DataType
|
||||
|
||||
/**
|
||||
* Convert the user type to a SQL datum
|
||||
*
|
||||
* TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst,
|
||||
* where we need to convert Any to UserType.
|
||||
*/
|
||||
def serialize(obj: Any): Any
|
||||
|
||||
/** Convert a SQL datum to the user type */
|
||||
def deserialize(datum: Any): UserType
|
||||
|
||||
override private[sql] def jsonValue: JValue = {
|
||||
("type" -> "udt") ~
|
||||
("class" -> this.getClass.getName)
|
||||
}
|
||||
|
||||
/**
|
||||
* Class object for the UserType
|
||||
*/
|
||||
def userClass: java.lang.Class[UserType]
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
|
|||
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.Row
|
||||
import org.apache.spark.sql.catalyst.types._
|
||||
|
||||
case class PrimitiveData(
|
||||
|
@ -239,13 +240,17 @@ class ScalaReflectionSuite extends FunSuite {
|
|||
test("convert PrimitiveData to catalyst") {
|
||||
val data = PrimitiveData(1, 1, 1, 1, 1, 1, true)
|
||||
val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true)
|
||||
assert(convertToCatalyst(data) === convertedData)
|
||||
val dataType = schemaFor[PrimitiveData].dataType
|
||||
assert(convertToCatalyst(data, dataType) === convertedData)
|
||||
}
|
||||
|
||||
test("convert Option[Product] to catalyst") {
|
||||
val primitiveData = PrimitiveData(1, 1, 1, 1, 1, 1, true)
|
||||
val data = OptionalData(Some(1), Some(1), Some(1), Some(1), Some(1), Some(1), Some(true), Some(primitiveData))
|
||||
val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true, convertToCatalyst(primitiveData))
|
||||
assert(convertToCatalyst(data) === convertedData)
|
||||
val data = OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true),
|
||||
Some(primitiveData))
|
||||
val dataType = schemaFor[OptionalData].dataType
|
||||
val convertedData = Row(2, 2.toLong, 2.toDouble, 2.toFloat, 2.toShort, 2.toByte, true,
|
||||
Row(1, 1, 1, 1, 1, 1, true))
|
||||
assert(convertToCatalyst(data, dataType) === convertedData)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.api.java;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
import org.apache.spark.annotation.DeveloperApi;
|
||||
|
||||
/**
|
||||
* ::DeveloperApi::
|
||||
* The data type representing User-Defined Types (UDTs).
|
||||
* UDTs may use any other DataType for an underlying representation.
|
||||
*/
|
||||
@DeveloperApi
|
||||
public abstract class UserDefinedType<UserType> extends DataType implements Serializable {
|
||||
|
||||
protected UserDefinedType() { }
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
UserDefinedType<UserType> that = (UserDefinedType<UserType>) o;
|
||||
return this.sqlType().equals(that.sqlType());
|
||||
}
|
||||
|
||||
/** Underlying storage type for this UDT */
|
||||
public abstract DataType sqlType();
|
||||
|
||||
/** Convert the user type to a SQL datum */
|
||||
public abstract Object serialize(Object obj);
|
||||
|
||||
/** Convert a SQL datum to the user type */
|
||||
public abstract UserType deserialize(Object datum);
|
||||
|
||||
/** Class object for the UserType */
|
||||
public abstract Class<UserType> userClass();
|
||||
}
|
|
@ -107,8 +107,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
|
|||
*/
|
||||
implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = {
|
||||
SparkPlan.currentContext.set(self)
|
||||
new SchemaRDD(this,
|
||||
LogicalRDD(ScalaReflection.attributesFor[A], RDDConversions.productToRowRdd(rdd))(self))
|
||||
val attributeSeq = ScalaReflection.attributesFor[A]
|
||||
val schema = StructType.fromAttributes(attributeSeq)
|
||||
val rowRDD = RDDConversions.productToRowRdd(rdd, schema)
|
||||
new SchemaRDD(this, LogicalRDD(attributeSeq, rowRDD)(self))
|
||||
}
|
||||
|
||||
implicit def baseRelationToSchemaRDD(baseRelation: BaseRelation): SchemaRDD = {
|
||||
|
|
|
@ -17,26 +17,24 @@
|
|||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import java.util.{Map => JMap, List => JList}
|
||||
|
||||
import org.apache.spark.sql.catalyst.ScalaReflection
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
import java.util.{List => JList}
|
||||
|
||||
import scala.collection.JavaConversions._
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import net.razorvine.pickle.Pickler
|
||||
|
||||
import org.apache.spark.{Dependency, OneToOneDependency, Partition, Partitioner, TaskContext}
|
||||
import org.apache.spark.annotation.{AlphaComponent, Experimental}
|
||||
import org.apache.spark.api.java.JavaRDD
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.api.java.JavaSchemaRDD
|
||||
import org.apache.spark.sql.catalyst.ScalaReflection
|
||||
import org.apache.spark.sql.catalyst.analysis._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython}
|
||||
import org.apache.spark.api.java.JavaRDD
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
|
||||
/**
|
||||
* :: AlphaComponent ::
|
||||
|
@ -114,18 +112,22 @@ class SchemaRDD(
|
|||
// =========================================================================================
|
||||
|
||||
override def compute(split: Partition, context: TaskContext): Iterator[Row] =
|
||||
firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala)
|
||||
firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala(_, this.schema))
|
||||
|
||||
override def getPartitions: Array[Partition] = firstParent[Row].partitions
|
||||
|
||||
override protected def getDependencies: Seq[Dependency[_]] =
|
||||
List(new OneToOneDependency(queryExecution.toRdd))
|
||||
override protected def getDependencies: Seq[Dependency[_]] = {
|
||||
schema // Force reification of the schema so it is available on executors.
|
||||
|
||||
/** Returns the schema of this SchemaRDD (represented by a [[StructType]]).
|
||||
*
|
||||
* @group schema
|
||||
*/
|
||||
def schema: StructType = queryExecution.analyzed.schema
|
||||
List(new OneToOneDependency(queryExecution.toRdd))
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the schema of this SchemaRDD (represented by a [[StructType]]).
|
||||
*
|
||||
* @group schema
|
||||
*/
|
||||
lazy val schema: StructType = queryExecution.analyzed.schema
|
||||
|
||||
// =======================================================================
|
||||
// Query DSL
|
||||
|
|
|
@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.LogicalRDD
|
|||
* Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java)
|
||||
*/
|
||||
private[sql] trait SchemaRDDLike {
|
||||
@transient val sqlContext: SQLContext
|
||||
@transient def sqlContext: SQLContext
|
||||
@transient val baseLogicalPlan: LogicalPlan
|
||||
|
||||
private[sql] def baseSchemaRDD: SchemaRDD
|
||||
|
|
|
@ -78,7 +78,7 @@ private[sql] trait UDFRegistration {
|
|||
s"""
|
||||
def registerFunction[T: TypeTag](name: String, func: Function$x[$types, T]): Unit = {
|
||||
def builder(e: Seq[Expression]) =
|
||||
ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
|
||||
ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
|
||||
functionRegistry.registerFunction(name, builder)
|
||||
}
|
||||
"""
|
||||
|
@ -87,112 +87,112 @@ private[sql] trait UDFRegistration {
|
|||
|
||||
// scalastyle:off
|
||||
def registerFunction[T: TypeTag](name: String, func: Function1[_, T]): Unit = {
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
|
||||
functionRegistry.registerFunction(name, builder)
|
||||
}
|
||||
|
||||
def registerFunction[T: TypeTag](name: String, func: Function2[_, _, T]): Unit = {
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
|
||||
functionRegistry.registerFunction(name, builder)
|
||||
}
|
||||
|
||||
def registerFunction[T: TypeTag](name: String, func: Function3[_, _, _, T]): Unit = {
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
|
||||
functionRegistry.registerFunction(name, builder)
|
||||
}
|
||||
|
||||
def registerFunction[T: TypeTag](name: String, func: Function4[_, _, _, _, T]): Unit = {
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
|
||||
functionRegistry.registerFunction(name, builder)
|
||||
}
|
||||
|
||||
def registerFunction[T: TypeTag](name: String, func: Function5[_, _, _, _, _, T]): Unit = {
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
|
||||
functionRegistry.registerFunction(name, builder)
|
||||
}
|
||||
|
||||
def registerFunction[T: TypeTag](name: String, func: Function6[_, _, _, _, _, _, T]): Unit = {
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
|
||||
functionRegistry.registerFunction(name, builder)
|
||||
}
|
||||
|
||||
def registerFunction[T: TypeTag](name: String, func: Function7[_, _, _, _, _, _, _, T]): Unit = {
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
|
||||
functionRegistry.registerFunction(name, builder)
|
||||
}
|
||||
|
||||
def registerFunction[T: TypeTag](name: String, func: Function8[_, _, _, _, _, _, _, _, T]): Unit = {
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
|
||||
functionRegistry.registerFunction(name, builder)
|
||||
}
|
||||
|
||||
def registerFunction[T: TypeTag](name: String, func: Function9[_, _, _, _, _, _, _, _, _, T]): Unit = {
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
|
||||
functionRegistry.registerFunction(name, builder)
|
||||
}
|
||||
|
||||
def registerFunction[T: TypeTag](name: String, func: Function10[_, _, _, _, _, _, _, _, _, _, T]): Unit = {
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
|
||||
functionRegistry.registerFunction(name, builder)
|
||||
}
|
||||
|
||||
def registerFunction[T: TypeTag](name: String, func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
|
||||
functionRegistry.registerFunction(name, builder)
|
||||
}
|
||||
|
||||
def registerFunction[T: TypeTag](name: String, func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
|
||||
functionRegistry.registerFunction(name, builder)
|
||||
}
|
||||
|
||||
def registerFunction[T: TypeTag](name: String, func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
|
||||
functionRegistry.registerFunction(name, builder)
|
||||
}
|
||||
|
||||
def registerFunction[T: TypeTag](name: String, func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
|
||||
functionRegistry.registerFunction(name, builder)
|
||||
}
|
||||
|
||||
def registerFunction[T: TypeTag](name: String, func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
|
||||
functionRegistry.registerFunction(name, builder)
|
||||
}
|
||||
|
||||
def registerFunction[T: TypeTag](name: String, func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
|
||||
functionRegistry.registerFunction(name, builder)
|
||||
}
|
||||
|
||||
def registerFunction[T: TypeTag](name: String, func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
|
||||
functionRegistry.registerFunction(name, builder)
|
||||
}
|
||||
|
||||
def registerFunction[T: TypeTag](name: String, func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
|
||||
functionRegistry.registerFunction(name, builder)
|
||||
}
|
||||
|
||||
def registerFunction[T: TypeTag](name: String, func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
|
||||
functionRegistry.registerFunction(name, builder)
|
||||
}
|
||||
|
||||
def registerFunction[T: TypeTag](name: String, func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
|
||||
functionRegistry.registerFunction(name, builder)
|
||||
}
|
||||
|
||||
def registerFunction[T: TypeTag](name: String, func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
|
||||
functionRegistry.registerFunction(name, builder)
|
||||
}
|
||||
|
||||
def registerFunction[T: TypeTag](name: String, func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
|
||||
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
|
||||
functionRegistry.registerFunction(name, builder)
|
||||
}
|
||||
// scalastyle:on
|
||||
|
|
|
@ -23,13 +23,14 @@ import org.apache.hadoop.conf.Configuration
|
|||
|
||||
import org.apache.spark.annotation.{DeveloperApi, Experimental}
|
||||
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
|
||||
import org.apache.spark.sql.{SQLContext, StructType => SStructType}
|
||||
import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
|
||||
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow}
|
||||
import org.apache.spark.sql.execution.LogicalRDD
|
||||
import org.apache.spark.sql.json.JsonRDD
|
||||
import org.apache.spark.sql.parquet.ParquetRelation
|
||||
import org.apache.spark.sql.sources.{LogicalRelation, BaseRelation}
|
||||
import org.apache.spark.sql.types.util.DataTypeConversions
|
||||
import org.apache.spark.sql.{SQLContext, StructType => SStructType}
|
||||
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow}
|
||||
import org.apache.spark.sql.parquet.ParquetRelation
|
||||
import org.apache.spark.sql.execution.LogicalRDD
|
||||
import org.apache.spark.sql.types.util.DataTypeConversions.asScalaDataType
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
|
@ -91,9 +92,12 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
|
|||
|
||||
/**
|
||||
* Applies a schema to an RDD of Java Beans.
|
||||
*
|
||||
* WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
|
||||
* SELECT * queries will return the columns in an undefined order.
|
||||
*/
|
||||
def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): JavaSchemaRDD = {
|
||||
val schema = getSchema(beanClass)
|
||||
val attributeSeq = getSchema(beanClass)
|
||||
val className = beanClass.getName
|
||||
val rowRdd = rdd.rdd.mapPartitions { iter =>
|
||||
// BeanInfo is not serializable so we must rediscover it remotely for each partition.
|
||||
|
@ -104,11 +108,13 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
|
|||
|
||||
iter.map { row =>
|
||||
new GenericRow(
|
||||
extractors.map(e => DataTypeConversions.convertJavaToCatalyst(e.invoke(row))).toArray[Any]
|
||||
extractors.zip(attributeSeq).map { case (e, attr) =>
|
||||
DataTypeConversions.convertJavaToCatalyst(e.invoke(row), attr.dataType)
|
||||
}.toArray[Any]
|
||||
): ScalaRow
|
||||
}
|
||||
}
|
||||
new JavaSchemaRDD(sqlContext, LogicalRDD(schema, rowRdd)(sqlContext))
|
||||
new JavaSchemaRDD(sqlContext, LogicalRDD(attributeSeq, rowRdd)(sqlContext))
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -195,14 +201,21 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
|
|||
sqlContext.registerRDDAsTable(rdd.baseSchemaRDD, tableName)
|
||||
}
|
||||
|
||||
/** Returns a Catalyst Schema for the given java bean class. */
|
||||
/**
|
||||
* Returns a Catalyst Schema for the given java bean class.
|
||||
*/
|
||||
protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = {
|
||||
// TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
|
||||
val beanInfo = Introspector.getBeanInfo(beanClass)
|
||||
|
||||
// Note: The ordering of elements may differ from when the schema is inferred in Scala.
|
||||
// This is because beanInfo.getPropertyDescriptors gives no guarantees about
|
||||
// element ordering.
|
||||
val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
|
||||
fields.map { property =>
|
||||
val (dataType, nullable) = property.getPropertyType match {
|
||||
case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
|
||||
(c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
|
||||
case c: Class[_] if c == classOf[java.lang.String] =>
|
||||
(org.apache.spark.sql.StringType, true)
|
||||
case c: Class[_] if c == java.lang.Short.TYPE =>
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.api.java
|
||||
|
||||
import org.apache.spark.sql.catalyst.types.{UserDefinedType => ScalaUserDefinedType}
|
||||
import org.apache.spark.sql.{DataType => ScalaDataType}
|
||||
import org.apache.spark.sql.types.util.DataTypeConversions
|
||||
|
||||
/**
|
||||
* Scala wrapper for a Java UserDefinedType
|
||||
*/
|
||||
private[sql] class JavaToScalaUDTWrapper[UserType](val javaUDT: UserDefinedType[UserType])
|
||||
extends ScalaUserDefinedType[UserType] with Serializable {
|
||||
|
||||
/** Underlying storage type for this UDT */
|
||||
val sqlType: ScalaDataType = DataTypeConversions.asScalaDataType(javaUDT.sqlType())
|
||||
|
||||
/** Convert the user type to a SQL datum */
|
||||
def serialize(obj: Any): Any = javaUDT.serialize(obj)
|
||||
|
||||
/** Convert a SQL datum to the user type */
|
||||
def deserialize(datum: Any): UserType = javaUDT.deserialize(datum)
|
||||
|
||||
val userClass: java.lang.Class[UserType] = javaUDT.userClass()
|
||||
}
|
||||
|
||||
/**
|
||||
* Java wrapper for a Scala UserDefinedType
|
||||
*/
|
||||
private[sql] class ScalaToJavaUDTWrapper[UserType](val scalaUDT: ScalaUserDefinedType[UserType])
|
||||
extends UserDefinedType[UserType] with Serializable {
|
||||
|
||||
/** Underlying storage type for this UDT */
|
||||
val sqlType: DataType = DataTypeConversions.asJavaDataType(scalaUDT.sqlType)
|
||||
|
||||
/** Convert the user type to a SQL datum */
|
||||
def serialize(obj: Any): java.lang.Object = scalaUDT.serialize(obj).asInstanceOf[java.lang.Object]
|
||||
|
||||
/** Convert a SQL datum to the user type */
|
||||
def deserialize(datum: Any): UserType = scalaUDT.deserialize(datum)
|
||||
|
||||
val userClass: java.lang.Class[UserType] = scalaUDT.userClass
|
||||
}
|
||||
|
||||
private[sql] object UDTWrappers {
|
||||
|
||||
def wrapAsScala(udtType: UserDefinedType[_]): ScalaUserDefinedType[_] = {
|
||||
udtType match {
|
||||
case t: ScalaToJavaUDTWrapper[_] => t.scalaUDT
|
||||
case _ => new JavaToScalaUDTWrapper(udtType)
|
||||
}
|
||||
}
|
||||
|
||||
def wrapAsJava(udtType: ScalaUserDefinedType[_]): UserDefinedType[_] = {
|
||||
udtType match {
|
||||
case t: JavaToScalaUDTWrapper[_] => t.javaUDT
|
||||
case _ => new ScalaToJavaUDTWrapper(udtType)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -19,29 +19,32 @@ package org.apache.spark.sql.execution
|
|||
|
||||
import org.apache.spark.annotation.DeveloperApi
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.{DataType, StructType, Row, SQLContext}
|
||||
import org.apache.spark.sql.catalyst.ScalaReflection
|
||||
import org.apache.spark.sql.catalyst.ScalaReflection.Schema
|
||||
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
|
||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
|
||||
import org.apache.spark.sql.{Row, SQLContext}
|
||||
import org.apache.spark.sql.catalyst.types.UserDefinedType
|
||||
|
||||
/**
|
||||
* :: DeveloperApi ::
|
||||
*/
|
||||
@DeveloperApi
|
||||
object RDDConversions {
|
||||
def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = {
|
||||
def productToRowRdd[A <: Product](data: RDD[A], schema: StructType): RDD[Row] = {
|
||||
data.mapPartitions { iterator =>
|
||||
if (iterator.isEmpty) {
|
||||
Iterator.empty
|
||||
} else {
|
||||
val bufferedIterator = iterator.buffered
|
||||
val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity)
|
||||
|
||||
val schemaFields = schema.fields.toArray
|
||||
bufferedIterator.map { r =>
|
||||
var i = 0
|
||||
while (i < mutableRow.length) {
|
||||
mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i))
|
||||
mutableRow(i) =
|
||||
ScalaReflection.convertToCatalyst(r.productElement(i), schemaFields(i).dataType)
|
||||
i += 1
|
||||
}
|
||||
|
||||
|
|
|
@ -20,8 +20,6 @@ package org.apache.spark.sql.execution
|
|||
import org.apache.spark.annotation.DeveloperApi
|
||||
import org.apache.spark.Logging
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
|
||||
import org.apache.spark.sql.SQLContext
|
||||
import org.apache.spark.sql.catalyst.{ScalaReflection, trees}
|
||||
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
|
||||
|
@ -82,7 +80,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
|
|||
/**
|
||||
* Runs this query returning the result as an array.
|
||||
*/
|
||||
def executeCollect(): Array[Row] = execute().map(ScalaReflection.convertRowToScala).collect()
|
||||
def executeCollect(): Array[Row] =
|
||||
execute().map(ScalaReflection.convertRowToScala(_, schema)).collect()
|
||||
|
||||
protected def newProjection(
|
||||
expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
|
||||
|
|
|
@ -280,7 +280,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
val nPartitions = if (data.isEmpty) 1 else numPartitions
|
||||
PhysicalRDD(
|
||||
output,
|
||||
RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions))) :: Nil
|
||||
RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions),
|
||||
StructType.fromAttributes(output))) :: Nil
|
||||
case logical.Limit(IntegerLiteral(limit), child) =>
|
||||
execution.Limit(limit, planLater(child)) :: Nil
|
||||
case Unions(unionChildren) =>
|
||||
|
|
|
@ -143,7 +143,7 @@ case class Limit(limit: Int, child: SparkPlan)
|
|||
partsScanned += numPartsToTry
|
||||
}
|
||||
|
||||
buf.toArray.map(ScalaReflection.convertRowToScala)
|
||||
buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema))
|
||||
}
|
||||
|
||||
override def execute() = {
|
||||
|
@ -179,8 +179,8 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
|
|||
val ord = new RowOrdering(sortOrder, child.output)
|
||||
|
||||
// TODO: Is this copying for no reason?
|
||||
override def executeCollect() =
|
||||
child.execute().map(_.copy()).takeOrdered(limit)(ord).map(ScalaReflection.convertRowToScala)
|
||||
override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ord)
|
||||
.map(ScalaReflection.convertRowToScala(_, this.schema))
|
||||
|
||||
// TODO: Terminal split should be implemented differently from non-terminal split.
|
||||
// TODO: Pick num splits based on |limit|.
|
||||
|
|
|
@ -77,6 +77,9 @@ private[sql] object CatalystConverter {
|
|||
parent: CatalystConverter): Converter = {
|
||||
val fieldType: DataType = field.dataType
|
||||
fieldType match {
|
||||
case udt: UserDefinedType[_] => {
|
||||
createConverter(field.copy(dataType = udt.sqlType), fieldIndex, parent)
|
||||
}
|
||||
// For native JVM types we use a converter with native arrays
|
||||
case ArrayType(elementType: NativeType, false) => {
|
||||
new CatalystNativeArrayConverter(elementType, fieldIndex, parent)
|
||||
|
@ -255,8 +258,8 @@ private[parquet] class CatalystGroupConverter(
|
|||
schema,
|
||||
index,
|
||||
parent,
|
||||
current=null,
|
||||
buffer=new ArrayBuffer[Row](
|
||||
current = null,
|
||||
buffer = new ArrayBuffer[Row](
|
||||
CatalystArrayConverter.INITIAL_ARRAY_SIZE))
|
||||
|
||||
/**
|
||||
|
@ -301,7 +304,7 @@ private[parquet] class CatalystGroupConverter(
|
|||
|
||||
override def end(): Unit = {
|
||||
if (!isRootConverter) {
|
||||
assert(current!=null) // there should be no empty groups
|
||||
assert(current != null) // there should be no empty groups
|
||||
buffer.append(new GenericRow(current.toArray))
|
||||
parent.updateField(index, new GenericRow(buffer.toArray.asInstanceOf[Array[Any]]))
|
||||
}
|
||||
|
@ -358,7 +361,7 @@ private[parquet] class CatalystPrimitiveRowConverter(
|
|||
|
||||
override def end(): Unit = {}
|
||||
|
||||
// Overriden here to avoid auto-boxing for primitive types
|
||||
// Overridden here to avoid auto-boxing for primitive types
|
||||
override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit =
|
||||
current.setBoolean(fieldIndex, value)
|
||||
|
||||
|
@ -533,7 +536,7 @@ private[parquet] class CatalystNativeArrayConverter(
|
|||
override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit =
|
||||
throw new UnsupportedOperationException
|
||||
|
||||
// Overriden here to avoid auto-boxing for primitive types
|
||||
// Overridden here to avoid auto-boxing for primitive types
|
||||
override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = {
|
||||
checkGrowBuffer()
|
||||
buffer(elements) = value.asInstanceOf[NativeType]
|
||||
|
|
|
@ -20,7 +20,6 @@ package org.apache.spark.sql.parquet
|
|||
import java.util.{HashMap => JHashMap}
|
||||
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
import org.apache.spark.sql.catalyst.types.decimal.Decimal
|
||||
import parquet.column.ParquetProperties
|
||||
import parquet.hadoop.ParquetOutputFormat
|
||||
import parquet.hadoop.api.ReadSupport.ReadContext
|
||||
|
@ -31,6 +30,7 @@ import parquet.schema.MessageType
|
|||
import org.apache.spark.Logging
|
||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, Row}
|
||||
import org.apache.spark.sql.catalyst.types._
|
||||
import org.apache.spark.sql.catalyst.types.decimal.Decimal
|
||||
|
||||
/**
|
||||
* A `parquet.io.api.RecordMaterializer` for Rows.
|
||||
|
@ -174,6 +174,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
|
|||
private[parquet] def writeValue(schema: DataType, value: Any): Unit = {
|
||||
if (value != null) {
|
||||
schema match {
|
||||
case t: UserDefinedType[_] => writeValue(t.sqlType, value)
|
||||
case t @ ArrayType(_, _) => writeArray(
|
||||
t,
|
||||
value.asInstanceOf[CatalystConverter.ArrayScalaType[_]])
|
||||
|
|
|
@ -290,6 +290,9 @@ private[parquet] object ParquetTypesConverter extends Logging {
|
|||
builder.named(name)
|
||||
}.getOrElse {
|
||||
ctype match {
|
||||
case udt: UserDefinedType[_] => {
|
||||
fromDataType(udt.sqlType, name, nullable, inArray)
|
||||
}
|
||||
case ArrayType(elementType, false) => {
|
||||
val parquetElementType = fromDataType(
|
||||
elementType,
|
||||
|
|
|
@ -17,12 +17,16 @@
|
|||
|
||||
package org.apache.spark.sql.types.util
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField, MetadataBuilder => JMetaDataBuilder}
|
||||
import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField,
|
||||
MetadataBuilder => JMetaDataBuilder, UDTWrappers, JavaToScalaUDTWrapper}
|
||||
import org.apache.spark.sql.api.java.{DecimalType => JDecimalType}
|
||||
import org.apache.spark.sql.catalyst.types.decimal.Decimal
|
||||
import org.apache.spark.sql.catalyst.ScalaReflection
|
||||
import org.apache.spark.sql.catalyst.types.UserDefinedType
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
protected[sql] object DataTypeConversions {
|
||||
|
||||
|
@ -41,6 +45,9 @@ protected[sql] object DataTypeConversions {
|
|||
* Returns the equivalent DataType in Java for the given DataType in Scala.
|
||||
*/
|
||||
def asJavaDataType(scalaDataType: DataType): JDataType = scalaDataType match {
|
||||
case udtType: UserDefinedType[_] =>
|
||||
UDTWrappers.wrapAsJava(udtType)
|
||||
|
||||
case StringType => JDataType.StringType
|
||||
case BinaryType => JDataType.BinaryType
|
||||
case BooleanType => JDataType.BooleanType
|
||||
|
@ -80,6 +87,9 @@ protected[sql] object DataTypeConversions {
|
|||
* Returns the equivalent DataType in Scala for the given DataType in Java.
|
||||
*/
|
||||
def asScalaDataType(javaDataType: JDataType): DataType = javaDataType match {
|
||||
case udtType: org.apache.spark.sql.api.java.UserDefinedType[_] =>
|
||||
UDTWrappers.wrapAsScala(udtType)
|
||||
|
||||
case stringType: org.apache.spark.sql.api.java.StringType =>
|
||||
StringType
|
||||
case binaryType: org.apache.spark.sql.api.java.BinaryType =>
|
||||
|
@ -121,9 +131,11 @@ protected[sql] object DataTypeConversions {
|
|||
}
|
||||
|
||||
/** Converts Java objects to catalyst rows / types */
|
||||
def convertJavaToCatalyst(a: Any): Any = a match {
|
||||
case d: java.math.BigDecimal => Decimal(BigDecimal(d))
|
||||
case other => other
|
||||
def convertJavaToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match {
|
||||
case (obj, udt: UserDefinedType[_]) => ScalaReflection.convertToCatalyst(obj, udt) // Scala type
|
||||
case (d: java.math.BigDecimal, _) => Decimal(BigDecimal(d))
|
||||
case (d: java.math.BigDecimal, _) => BigDecimal(d)
|
||||
case (other, _) => other
|
||||
}
|
||||
|
||||
/** Converts Java objects to catalyst rows / types */
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.api.java;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.*;
|
||||
|
||||
import org.junit.After;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import org.apache.spark.api.java.JavaRDD;
|
||||
import org.apache.spark.api.java.JavaSparkContext;
|
||||
import org.apache.spark.sql.MyDenseVector;
|
||||
import org.apache.spark.sql.MyLabeledPoint;
|
||||
|
||||
public class JavaUserDefinedTypeSuite implements Serializable {
|
||||
private transient JavaSparkContext javaCtx;
|
||||
private transient JavaSQLContext javaSqlCtx;
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
javaCtx = new JavaSparkContext("local", "JavaUserDefinedTypeSuite");
|
||||
javaSqlCtx = new JavaSQLContext(javaCtx);
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() {
|
||||
javaCtx.stop();
|
||||
javaCtx = null;
|
||||
javaSqlCtx = null;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void useScalaUDT() {
|
||||
List<MyLabeledPoint> points = Arrays.asList(
|
||||
new MyLabeledPoint(1.0, new MyDenseVector(new double[]{0.1, 1.0})),
|
||||
new MyLabeledPoint(0.0, new MyDenseVector(new double[]{0.2, 2.0})));
|
||||
JavaRDD<MyLabeledPoint> pointsRDD = javaCtx.parallelize(points);
|
||||
|
||||
JavaSchemaRDD schemaRDD = javaSqlCtx.applySchema(pointsRDD, MyLabeledPoint.class);
|
||||
schemaRDD.registerTempTable("points");
|
||||
|
||||
List<Row> actualLabelRows = javaSqlCtx.sql("SELECT label FROM points").collect();
|
||||
List<Double> actualLabels = new LinkedList<Double>();
|
||||
for (Row r : actualLabelRows) {
|
||||
actualLabels.add(r.getDouble(0));
|
||||
}
|
||||
for (MyLabeledPoint lp : points) {
|
||||
Assert.assertTrue(actualLabels.contains(lp.label()));
|
||||
}
|
||||
|
||||
List<Row> actualFeatureRows = javaSqlCtx.sql("SELECT features FROM points").collect();
|
||||
List<MyDenseVector> actualFeatures = new LinkedList<MyDenseVector>();
|
||||
for (Row r : actualFeatureRows) {
|
||||
actualFeatures.add((MyDenseVector)r.get(0));
|
||||
}
|
||||
for (MyLabeledPoint lp : points) {
|
||||
Assert.assertTrue(actualFeatures.contains(lp.features()));
|
||||
}
|
||||
|
||||
List<Row> actual = javaSqlCtx.sql("SELECT label, features FROM points").collect();
|
||||
List<MyLabeledPoint> actualPoints =
|
||||
new LinkedList<MyLabeledPoint>();
|
||||
for (Row r : actual) {
|
||||
actualPoints.add(new MyLabeledPoint(r.getDouble(0), (MyDenseVector)r.get(1)));
|
||||
}
|
||||
for (MyLabeledPoint lp : points) {
|
||||
Assert.assertTrue(actualPoints.contains(lp));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,83 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import scala.beans.{BeanInfo, BeanProperty}
|
||||
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
|
||||
import org.apache.spark.sql.catalyst.types.UserDefinedType
|
||||
import org.apache.spark.sql.test.TestSQLContext._
|
||||
|
||||
@SQLUserDefinedType(udt = classOf[MyDenseVectorUDT])
|
||||
private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable {
|
||||
override def equals(other: Any): Boolean = other match {
|
||||
case v: MyDenseVector =>
|
||||
java.util.Arrays.equals(this.data, v.data)
|
||||
case _ => false
|
||||
}
|
||||
}
|
||||
|
||||
@BeanInfo
|
||||
private[sql] case class MyLabeledPoint(
|
||||
@BeanProperty label: Double,
|
||||
@BeanProperty features: MyDenseVector)
|
||||
|
||||
private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
|
||||
|
||||
override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
|
||||
|
||||
override def serialize(obj: Any): Seq[Double] = {
|
||||
obj match {
|
||||
case features: MyDenseVector =>
|
||||
features.data.toSeq
|
||||
}
|
||||
}
|
||||
|
||||
override def deserialize(datum: Any): MyDenseVector = {
|
||||
datum match {
|
||||
case data: Seq[_] =>
|
||||
new MyDenseVector(data.asInstanceOf[Seq[Double]].toArray)
|
||||
}
|
||||
}
|
||||
|
||||
override def userClass = classOf[MyDenseVector]
|
||||
}
|
||||
|
||||
class UserDefinedTypeSuite extends QueryTest {
|
||||
|
||||
test("register user type: MyDenseVector for MyLabeledPoint") {
|
||||
val points = Seq(
|
||||
MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),
|
||||
MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0))))
|
||||
val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points)
|
||||
|
||||
val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v }
|
||||
val labelsArrays: Array[Double] = labels.collect()
|
||||
assert(labelsArrays.size === 2)
|
||||
assert(labelsArrays.contains(1.0))
|
||||
assert(labelsArrays.contains(0.0))
|
||||
|
||||
val features: RDD[MyDenseVector] =
|
||||
pointsRDD.select('features).map { case Row(v: MyDenseVector) => v }
|
||||
val featuresArrays: Array[MyDenseVector] = features.collect()
|
||||
assert(featuresArrays.size === 2)
|
||||
assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0))))
|
||||
assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0))))
|
||||
}
|
||||
}
|
|
@ -21,8 +21,7 @@ import org.apache.spark.sql.catalyst.types._
|
|||
import org.apache.spark.sql.catalyst.types.decimal.Decimal
|
||||
import org.apache.spark.sql.catalyst.util._
|
||||
import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType}
|
||||
import org.apache.spark.sql.QueryTest
|
||||
import org.apache.spark.sql.SQLConf
|
||||
import org.apache.spark.sql.{Row, SQLConf, QueryTest}
|
||||
import org.apache.spark.sql.test.TestSQLContext
|
||||
import org.apache.spark.sql.test.TestSQLContext._
|
||||
|
||||
|
@ -233,8 +232,8 @@ class JsonSuite extends QueryTest {
|
|||
StructField("field2", StringType, true) ::
|
||||
StructField("field3", StringType, true) :: Nil), false), true) ::
|
||||
StructField("struct", StructType(
|
||||
StructField("field1", BooleanType, true) ::
|
||||
StructField("field2", DecimalType.Unlimited, true) :: Nil), true) ::
|
||||
StructField("field1", BooleanType, true) ::
|
||||
StructField("field2", DecimalType.Unlimited, true) :: Nil), true) ::
|
||||
StructField("structWithArrayFields", StructType(
|
||||
StructField("field1", ArrayType(IntegerType, false), true) ::
|
||||
StructField("field2", ArrayType(StringType, false), true) :: Nil), true) :: Nil)
|
||||
|
@ -292,8 +291,8 @@ class JsonSuite extends QueryTest {
|
|||
// Access a struct and fields inside of it.
|
||||
checkAnswer(
|
||||
sql("select struct, struct.field1, struct.field2 from jsonTable"),
|
||||
(
|
||||
Seq(true, BigDecimal("92233720368547758070")),
|
||||
Row(
|
||||
Row(true, BigDecimal("92233720368547758070")),
|
||||
true,
|
||||
BigDecimal("92233720368547758070")) :: Nil
|
||||
)
|
||||
|
|
|
@ -374,8 +374,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
|
|||
/** Extends QueryExecution with hive specific features. */
|
||||
protected[sql] abstract class QueryExecution extends super.QueryExecution {
|
||||
|
||||
override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy())
|
||||
|
||||
protected val primitiveTypes =
|
||||
Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType,
|
||||
ShortType, DateType, TimestampType, BinaryType)
|
||||
|
@ -433,7 +431,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
|
|||
command.executeCollect().map(_.head.toString)
|
||||
|
||||
case other =>
|
||||
val result: Seq[Seq[Any]] = toRdd.collect().toSeq
|
||||
val result: Seq[Seq[Any]] = toRdd.map(_.copy()).collect().toSeq
|
||||
// We need the types so we can output struct field names
|
||||
val types = analyzed.output.map(_.dataType)
|
||||
// Reformat to match hive tab delimited output.
|
||||
|
|
Loading…
Reference in a new issue