[SPARK-5445][SQL] Made DataFrame dsl usable in Java
Also removed the literal implicit transformation since it is pretty scary for API design. Instead, created a new lit method for creating literals. This doesn't break anything from a compatibility perspective because Literal was added two days ago. Author: Reynold Xin <rxin@databricks.com> Closes #4241 from rxin/df-docupdate and squashes the following commits: c0f4810 [Reynold Xin] Fix Python merge conflict. 094c7d7 [Reynold Xin] Minor style fix. Reset Python tests. 3c89f4a [Reynold Xin] Package. dfe6962 [Reynold Xin] Updated Python aggregate. 5dd4265 [Reynold Xin] Made dsl Java callable. 14b3c27 [Reynold Xin] Fix literal expression for symbols. 68b31cb [Reynold Xin] Literal. 4cfeb78 [Reynold Xin] [SPARK-5097][SQL] Address DataFrame code review feedback.
This commit is contained in:
parent
4ee79c71af
commit
5b9760de8d
|
@ -19,8 +19,7 @@ package org.apache.spark.examples.sql
|
|||
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
import org.apache.spark.sql.SQLContext
|
||||
import org.apache.spark.sql.dsl._
|
||||
import org.apache.spark.sql.dsl.literals._
|
||||
import org.apache.spark.sql.api.scala.dsl._
|
||||
|
||||
// One method for defining the schema of an RDD is to make a case class with the desired column
|
||||
// names and types.
|
||||
|
|
|
@ -24,7 +24,7 @@ import org.apache.spark.annotation.AlphaComponent
|
|||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.sql.DataFrame
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.dsl._
|
||||
import org.apache.spark.sql.api.scala.dsl._
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
/**
|
||||
|
|
|
@ -24,7 +24,7 @@ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
|
|||
import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT}
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.dsl._
|
||||
import org.apache.spark.sql.api.scala.dsl._
|
||||
import org.apache.spark.sql.catalyst.dsl._
|
||||
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
|
|
|
@ -23,7 +23,7 @@ import org.apache.spark.ml.param._
|
|||
import org.apache.spark.mllib.feature
|
||||
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.dsl._
|
||||
import org.apache.spark.sql.api.scala.dsl._
|
||||
import org.apache.spark.sql.catalyst.dsl._
|
||||
import org.apache.spark.sql.types.{StructField, StructType}
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ import org.apache.spark.ml.{Estimator, Model}
|
|||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.{Column, DataFrame}
|
||||
import org.apache.spark.sql.dsl._
|
||||
import org.apache.spark.sql.api.scala.dsl._
|
||||
import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType}
|
||||
import org.apache.spark.util.Utils
|
||||
import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
|
||||
|
|
|
@ -931,7 +931,7 @@ def _parse_schema_abstract(s):
|
|||
|
||||
def _infer_schema_type(obj, dataType):
|
||||
"""
|
||||
Fill the dataType with types infered from obj
|
||||
Fill the dataType with types inferred from obj
|
||||
|
||||
>>> schema = _parse_schema_abstract("a b c d")
|
||||
>>> row = (1, 1.0, "str", datetime.date(2014, 10, 10))
|
||||
|
@ -2140,7 +2140,7 @@ class DataFrame(object):
|
|||
return Column(self._jdf.apply(name))
|
||||
raise AttributeError
|
||||
|
||||
def As(self, name):
|
||||
def alias(self, name):
|
||||
""" Alias the current DataFrame """
|
||||
return DataFrame(getattr(self._jdf, "as")(name), self.sql_ctx)
|
||||
|
||||
|
@ -2216,7 +2216,7 @@ class DataFrame(object):
|
|||
"""
|
||||
return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)
|
||||
|
||||
def Except(self, other):
|
||||
def subtract(self, other):
|
||||
""" Return a new [[DataFrame]] containing rows in this frame
|
||||
but not in another frame.
|
||||
|
||||
|
@ -2234,7 +2234,7 @@ class DataFrame(object):
|
|||
|
||||
def addColumn(self, colName, col):
|
||||
""" Return a new [[DataFrame]] by adding a column. """
|
||||
return self.select('*', col.As(colName))
|
||||
return self.select('*', col.alias(colName))
|
||||
|
||||
def removeColumn(self, colName):
|
||||
raise NotImplemented
|
||||
|
@ -2342,7 +2342,7 @@ SCALA_METHOD_MAPPINGS = {
|
|||
|
||||
def _create_column_from_literal(literal):
|
||||
sc = SparkContext._active_spark_context
|
||||
return sc._jvm.Literal.apply(literal)
|
||||
return sc._jvm.org.apache.spark.sql.api.java.dsl.lit(literal)
|
||||
|
||||
|
||||
def _create_column_from_name(name):
|
||||
|
@ -2371,13 +2371,20 @@ def _unary_op(name):
|
|||
return _
|
||||
|
||||
|
||||
def _bin_op(name):
|
||||
""" Create a method for given binary operator """
|
||||
def _bin_op(name, pass_literal_through=False):
|
||||
""" Create a method for given binary operator
|
||||
|
||||
Keyword arguments:
|
||||
pass_literal_through -- whether to pass literal value directly through to the JVM.
|
||||
"""
|
||||
def _(self, other):
|
||||
if isinstance(other, Column):
|
||||
jc = other._jc
|
||||
else:
|
||||
jc = _create_column_from_literal(other)
|
||||
if pass_literal_through:
|
||||
jc = other
|
||||
else:
|
||||
jc = _create_column_from_literal(other)
|
||||
return Column(getattr(self._jc, _scalaMethod(name))(jc), self._jdf, self.sql_ctx)
|
||||
return _
|
||||
|
||||
|
@ -2458,10 +2465,10 @@ class Column(DataFrame):
|
|||
# __getattr__ = _bin_op("getField")
|
||||
|
||||
# string methods
|
||||
rlike = _bin_op("rlike")
|
||||
like = _bin_op("like")
|
||||
startswith = _bin_op("startsWith")
|
||||
endswith = _bin_op("endsWith")
|
||||
rlike = _bin_op("rlike", pass_literal_through=True)
|
||||
like = _bin_op("like", pass_literal_through=True)
|
||||
startswith = _bin_op("startsWith", pass_literal_through=True)
|
||||
endswith = _bin_op("endsWith", pass_literal_through=True)
|
||||
upper = _unary_op("upper")
|
||||
lower = _unary_op("lower")
|
||||
|
||||
|
@ -2487,7 +2494,7 @@ class Column(DataFrame):
|
|||
isNotNull = _unary_op("isNotNull")
|
||||
|
||||
# `as` is keyword
|
||||
def As(self, alias):
|
||||
def alias(self, alias):
|
||||
return Column(getattr(self._jsc, "as")(alias), self._jdf, self.sql_ctx)
|
||||
|
||||
def cast(self, dataType):
|
||||
|
@ -2501,15 +2508,14 @@ class Column(DataFrame):
|
|||
|
||||
|
||||
def _aggregate_func(name):
|
||||
""" Creat a function for aggregator by name"""
|
||||
""" Create a function for aggregator by name"""
|
||||
def _(col):
|
||||
sc = SparkContext._active_spark_context
|
||||
if isinstance(col, Column):
|
||||
jcol = col._jc
|
||||
else:
|
||||
jcol = _create_column_from_name(col)
|
||||
# FIXME: can not access dsl.min/max ...
|
||||
jc = getattr(sc._jvm.org.apache.spark.sql.dsl(), name)(jcol)
|
||||
jc = getattr(sc._jvm.org.apache.spark.sql.api.java.dsl, name)(jcol)
|
||||
return Column(jc)
|
||||
return staticmethod(_)
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.spark.sql
|
|||
|
||||
import scala.language.implicitConversions
|
||||
|
||||
import org.apache.spark.sql.api.scala.dsl.lit
|
||||
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, Star}
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr}
|
||||
|
@ -55,11 +56,11 @@ class Column(
|
|||
val expr: Expression)
|
||||
extends DataFrame(sqlContext, plan) with ExpressionApi {
|
||||
|
||||
/** Turn a Catalyst expression into a `Column`. */
|
||||
/** Turns a Catalyst expression into a `Column`. */
|
||||
protected[sql] def this(expr: Expression) = this(None, None, expr)
|
||||
|
||||
/**
|
||||
* Create a new `Column` expression based on a column or attribute name.
|
||||
* Creates a new `Column` expression based on a column or attribute name.
|
||||
* The resolution of this is the same as SQL. For example:
|
||||
*
|
||||
* - "colName" becomes an expression selecting the column named "colName".
|
||||
|
@ -108,7 +109,7 @@ class Column(
|
|||
override def unary_~ : Column = BitwiseNot(expr)
|
||||
|
||||
/**
|
||||
* Invert a boolean expression, i.e. NOT.
|
||||
* Inversion of boolean expression, i.e. NOT.
|
||||
* {{
|
||||
* // Select rows that are not active (isActive === false)
|
||||
* df.select( !df("isActive") )
|
||||
|
@ -135,7 +136,7 @@ class Column(
|
|||
* df.select( df("colA".equalTo("Zaharia") )
|
||||
* }}}
|
||||
*/
|
||||
override def === (literal: Any): Column = this === Literal.anyToLiteral(literal)
|
||||
override def === (literal: Any): Column = this === lit(literal)
|
||||
|
||||
/**
|
||||
* Equality test with an expression.
|
||||
|
@ -175,7 +176,7 @@ class Column(
|
|||
* df.select( !(df("colA") === 15) )
|
||||
* }}}
|
||||
*/
|
||||
override def !== (literal: Any): Column = this !== Literal.anyToLiteral(literal)
|
||||
override def !== (literal: Any): Column = this !== lit(literal)
|
||||
|
||||
/**
|
||||
* Greater than an expression.
|
||||
|
@ -193,7 +194,7 @@ class Column(
|
|||
* people.select( people("age") > 21 )
|
||||
* }}}
|
||||
*/
|
||||
override def > (literal: Any): Column = this > Literal.anyToLiteral(literal)
|
||||
override def > (literal: Any): Column = this > lit(literal)
|
||||
|
||||
/**
|
||||
* Less than an expression.
|
||||
|
@ -211,7 +212,7 @@ class Column(
|
|||
* people.select( people("age") < 21 )
|
||||
* }}}
|
||||
*/
|
||||
override def < (literal: Any): Column = this < Literal.anyToLiteral(literal)
|
||||
override def < (literal: Any): Column = this < lit(literal)
|
||||
|
||||
/**
|
||||
* Less than or equal to an expression.
|
||||
|
@ -229,7 +230,7 @@ class Column(
|
|||
* people.select( people("age") <= 21 )
|
||||
* }}}
|
||||
*/
|
||||
override def <= (literal: Any): Column = this <= Literal.anyToLiteral(literal)
|
||||
override def <= (literal: Any): Column = this <= lit(literal)
|
||||
|
||||
/**
|
||||
* Greater than or equal to an expression.
|
||||
|
@ -247,20 +248,20 @@ class Column(
|
|||
* people.select( people("age") >= 21 )
|
||||
* }}}
|
||||
*/
|
||||
override def >= (literal: Any): Column = this >= Literal.anyToLiteral(literal)
|
||||
override def >= (literal: Any): Column = this >= lit(literal)
|
||||
|
||||
/**
|
||||
* Equality test with an expression that is safe for null values.
|
||||
*/
|
||||
override def <=> (other: Column): Column = other match {
|
||||
case null => EqualNullSafe(expr, Literal.anyToLiteral(null).expr)
|
||||
case null => EqualNullSafe(expr, lit(null).expr)
|
||||
case _ => EqualNullSafe(expr, other.expr)
|
||||
}
|
||||
|
||||
/**
|
||||
* Equality test with a literal value that is safe for null values.
|
||||
*/
|
||||
override def <=> (literal: Any): Column = this <=> Literal.anyToLiteral(literal)
|
||||
override def <=> (literal: Any): Column = this <=> lit(literal)
|
||||
|
||||
/**
|
||||
* True if the current expression is null.
|
||||
|
@ -288,7 +289,7 @@ class Column(
|
|||
* people.select( people("inSchool") || true )
|
||||
* }}}
|
||||
*/
|
||||
override def || (literal: Boolean): Column = this || Literal.anyToLiteral(literal)
|
||||
override def || (literal: Boolean): Column = this || lit(literal)
|
||||
|
||||
/**
|
||||
* Boolean AND with an expression.
|
||||
|
@ -306,7 +307,7 @@ class Column(
|
|||
* people.select( people("inSchool") && true )
|
||||
* }}}
|
||||
*/
|
||||
override def && (literal: Boolean): Column = this && Literal.anyToLiteral(literal)
|
||||
override def && (literal: Boolean): Column = this && lit(literal)
|
||||
|
||||
/**
|
||||
* Bitwise AND with an expression.
|
||||
|
@ -316,7 +317,7 @@ class Column(
|
|||
/**
|
||||
* Bitwise AND with a literal value.
|
||||
*/
|
||||
override def & (literal: Any): Column = this & Literal.anyToLiteral(literal)
|
||||
override def & (literal: Any): Column = this & lit(literal)
|
||||
|
||||
/**
|
||||
* Bitwise OR with an expression.
|
||||
|
@ -326,7 +327,7 @@ class Column(
|
|||
/**
|
||||
* Bitwise OR with a literal value.
|
||||
*/
|
||||
override def | (literal: Any): Column = this | Literal.anyToLiteral(literal)
|
||||
override def | (literal: Any): Column = this | lit(literal)
|
||||
|
||||
/**
|
||||
* Bitwise XOR with an expression.
|
||||
|
@ -336,7 +337,7 @@ class Column(
|
|||
/**
|
||||
* Bitwise XOR with a literal value.
|
||||
*/
|
||||
override def ^ (literal: Any): Column = this ^ Literal.anyToLiteral(literal)
|
||||
override def ^ (literal: Any): Column = this ^ lit(literal)
|
||||
|
||||
/**
|
||||
* Sum of this expression and another expression.
|
||||
|
@ -354,10 +355,10 @@ class Column(
|
|||
* people.select( people("height") + 10 )
|
||||
* }}}
|
||||
*/
|
||||
override def + (literal: Any): Column = this + Literal.anyToLiteral(literal)
|
||||
override def + (literal: Any): Column = this + lit(literal)
|
||||
|
||||
/**
|
||||
* Subtraction. Substract the other expression from this expression.
|
||||
* Subtraction. Subtract the other expression from this expression.
|
||||
* {{{
|
||||
* // The following selects the difference between people's height and their weight.
|
||||
* people.select( people("height") - people("weight") )
|
||||
|
@ -366,16 +367,16 @@ class Column(
|
|||
override def - (other: Column): Column = Subtract(expr, other.expr)
|
||||
|
||||
/**
|
||||
* Subtraction. Substract a literal value from this expression.
|
||||
* Subtraction. Subtract a literal value from this expression.
|
||||
* {{{
|
||||
* // The following selects a person's height and substract it by 10.
|
||||
* // The following selects a person's height and subtract it by 10.
|
||||
* people.select( people("height") - 10 )
|
||||
* }}}
|
||||
*/
|
||||
override def - (literal: Any): Column = this - Literal.anyToLiteral(literal)
|
||||
override def - (literal: Any): Column = this - lit(literal)
|
||||
|
||||
/**
|
||||
* Multiply this expression and another expression.
|
||||
* Multiplication of this expression and another expression.
|
||||
* {{{
|
||||
* // The following multiplies a person's height by their weight.
|
||||
* people.select( people("height") * people("weight") )
|
||||
|
@ -384,16 +385,16 @@ class Column(
|
|||
override def * (other: Column): Column = Multiply(expr, other.expr)
|
||||
|
||||
/**
|
||||
* Multiply this expression and a literal value.
|
||||
* Multiplication this expression and a literal value.
|
||||
* {{{
|
||||
* // The following multiplies a person's height by 10.
|
||||
* people.select( people("height") * 10 )
|
||||
* }}}
|
||||
*/
|
||||
override def * (literal: Any): Column = this * Literal.anyToLiteral(literal)
|
||||
override def * (literal: Any): Column = this * lit(literal)
|
||||
|
||||
/**
|
||||
* Divide this expression by another expression.
|
||||
* Division this expression by another expression.
|
||||
* {{{
|
||||
* // The following divides a person's height by their weight.
|
||||
* people.select( people("height") / people("weight") )
|
||||
|
@ -402,13 +403,13 @@ class Column(
|
|||
override def / (other: Column): Column = Divide(expr, other.expr)
|
||||
|
||||
/**
|
||||
* Divide this expression by a literal value.
|
||||
* Division this expression by a literal value.
|
||||
* {{{
|
||||
* // The following divides a person's height by 10.
|
||||
* people.select( people("height") / 10 )
|
||||
* }}}
|
||||
*/
|
||||
override def / (literal: Any): Column = this / Literal.anyToLiteral(literal)
|
||||
override def / (literal: Any): Column = this / lit(literal)
|
||||
|
||||
/**
|
||||
* Modulo (a.k.a. remainder) expression.
|
||||
|
@ -418,7 +419,7 @@ class Column(
|
|||
/**
|
||||
* Modulo (a.k.a. remainder) expression.
|
||||
*/
|
||||
override def % (literal: Any): Column = this % Literal.anyToLiteral(literal)
|
||||
override def % (literal: Any): Column = this % lit(literal)
|
||||
|
||||
|
||||
/**
|
||||
|
@ -428,43 +429,67 @@ class Column(
|
|||
@scala.annotation.varargs
|
||||
override def in(list: Column*): Column = In(expr, list.map(_.expr))
|
||||
|
||||
override def like(other: Column): Column = Like(expr, other.expr)
|
||||
|
||||
override def like(literal: String): Column = this.like(Literal.anyToLiteral(literal))
|
||||
|
||||
override def rlike(other: Column): Column = RLike(expr, other.expr)
|
||||
|
||||
override def rlike(literal: String): Column = this.rlike(Literal.anyToLiteral(literal))
|
||||
override def like(literal: String): Column = Like(expr, lit(literal).expr)
|
||||
|
||||
override def rlike(literal: String): Column = RLike(expr, lit(literal).expr)
|
||||
|
||||
/**
|
||||
* An expression that gets an
|
||||
* @param ordinal
|
||||
* @return
|
||||
*/
|
||||
override def getItem(ordinal: Int): Column = GetItem(expr, LiteralExpr(ordinal))
|
||||
|
||||
override def getItem(ordinal: Column): Column = GetItem(expr, ordinal.expr)
|
||||
|
||||
/**
|
||||
* An expression that gets a field by name in a [[StructField]].
|
||||
*/
|
||||
override def getField(fieldName: String): Column = GetField(expr, fieldName)
|
||||
|
||||
|
||||
/**
|
||||
* An expression that returns a substring.
|
||||
* @param startPos expression for the starting position.
|
||||
* @param len expression for the length of the substring.
|
||||
*/
|
||||
override def substr(startPos: Column, len: Column): Column =
|
||||
Substring(expr, startPos.expr, len.expr)
|
||||
|
||||
override def substr(startPos: Int, len: Int): Column =
|
||||
this.substr(Literal.anyToLiteral(startPos), Literal.anyToLiteral(len))
|
||||
/**
|
||||
* An expression that returns a substring.
|
||||
* @param startPos starting position.
|
||||
* @param len length of the substring.
|
||||
*/
|
||||
override def substr(startPos: Int, len: Int): Column = this.substr(lit(startPos), lit(len))
|
||||
|
||||
override def contains(other: Column): Column = Contains(expr, other.expr)
|
||||
|
||||
override def contains(literal: Any): Column = this.contains(Literal.anyToLiteral(literal))
|
||||
override def contains(literal: Any): Column = this.contains(lit(literal))
|
||||
|
||||
|
||||
override def startsWith(other: Column): Column = StartsWith(expr, other.expr)
|
||||
|
||||
override def startsWith(literal: String): Column = this.startsWith(Literal.anyToLiteral(literal))
|
||||
override def startsWith(literal: String): Column = this.startsWith(lit(literal))
|
||||
|
||||
override def endsWith(other: Column): Column = EndsWith(expr, other.expr)
|
||||
|
||||
override def endsWith(literal: String): Column = this.endsWith(Literal.anyToLiteral(literal))
|
||||
override def endsWith(literal: String): Column = this.endsWith(lit(literal))
|
||||
|
||||
/**
|
||||
* Gives the column an alias.
|
||||
* {{{
|
||||
* // Renames colA to colB in select output.
|
||||
* df.select($"colA".as("colB"))
|
||||
* }}}
|
||||
*/
|
||||
override def as(alias: String): Column = Alias(expr, alias)()
|
||||
|
||||
/**
|
||||
* Casts the column to a different data type.
|
||||
* {{{
|
||||
* // Casts colA to IntegerType.
|
||||
* import org.apache.spark.sql.types.IntegerType
|
||||
* df.select(df("colA").as(IntegerType))
|
||||
* }}}
|
||||
*/
|
||||
override def cast(to: DataType): Column = Cast(expr, to)
|
||||
|
||||
override def desc: Column = SortOrder(expr, Descending)
|
||||
|
|
|
@ -17,24 +17,22 @@
|
|||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import java.util.{List => JList}
|
||||
|
||||
import scala.language.implicitConversions
|
||||
import scala.reflect.ClassTag
|
||||
import scala.collection.JavaConversions._
|
||||
|
||||
import java.util.{ArrayList, List => JList}
|
||||
|
||||
import com.fasterxml.jackson.core.JsonFactory
|
||||
import net.razorvine.pickle.Pickler
|
||||
|
||||
import org.apache.spark.annotation.Experimental
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.api.java.JavaRDD
|
||||
import org.apache.spark.api.python.SerDeUtil
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
import org.apache.spark.sql.catalyst.ScalaReflection
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr}
|
||||
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython}
|
||||
|
@ -53,7 +51,8 @@ import org.apache.spark.util.Utils
|
|||
* }}}
|
||||
*
|
||||
* Once created, it can be manipulated using the various domain-specific-language (DSL) functions
|
||||
* defined in: [[DataFrame]] (this class), [[Column]], and [[dsl]] for Scala DSL.
|
||||
* defined in: [[DataFrame]] (this class), [[Column]], [[api.scala.dsl]] for Scala DSL, and
|
||||
* [[api.java.dsl]] for Java DSL.
|
||||
*
|
||||
* To select a column from the data frame, use the apply method:
|
||||
* {{{
|
||||
|
@ -110,14 +109,14 @@ class DataFrame protected[sql](
|
|||
new DataFrame(sqlContext, logicalPlan, true)
|
||||
}
|
||||
|
||||
/** Return the list of numeric columns, useful for doing aggregation. */
|
||||
/** Returns the list of numeric columns, useful for doing aggregation. */
|
||||
protected[sql] def numericColumns: Seq[Expression] = {
|
||||
schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
|
||||
logicalPlan.resolve(n.name, sqlContext.analyzer.resolver).get
|
||||
}
|
||||
}
|
||||
|
||||
/** Resolve a column name into a Catalyst [[NamedExpression]]. */
|
||||
/** Resolves a column name into a Catalyst [[NamedExpression]]. */
|
||||
protected[sql] def resolve(colName: String): NamedExpression = {
|
||||
logicalPlan.resolve(colName, sqlContext.analyzer.resolver).getOrElse(
|
||||
throw new RuntimeException(s"""Cannot resolve column name "$colName""""))
|
||||
|
@ -128,22 +127,22 @@ class DataFrame protected[sql](
|
|||
def toSchemaRDD: DataFrame = this
|
||||
|
||||
/**
|
||||
* Return the object itself. Used to force an implicit conversion from RDD to DataFrame in Scala.
|
||||
* Returns the object itself. Used to force an implicit conversion from RDD to DataFrame in Scala.
|
||||
*/
|
||||
def toDataFrame: DataFrame = this
|
||||
|
||||
/** Return the schema of this [[DataFrame]]. */
|
||||
/** Returns the schema of this [[DataFrame]]. */
|
||||
override def schema: StructType = queryExecution.analyzed.schema
|
||||
|
||||
/** Return all column names and their data types as an array. */
|
||||
/** Returns all column names and their data types as an array. */
|
||||
override def dtypes: Array[(String, String)] = schema.fields.map { field =>
|
||||
(field.name, field.dataType.toString)
|
||||
}
|
||||
|
||||
/** Return all column names as an array. */
|
||||
/** Returns all column names as an array. */
|
||||
override def columns: Array[String] = schema.fields.map(_.name)
|
||||
|
||||
/** Print the schema to the console in a nice tree format. */
|
||||
/** Prints the schema to the console in a nice tree format. */
|
||||
override def printSchema(): Unit = println(schema.treeString)
|
||||
|
||||
/**
|
||||
|
@ -187,7 +186,7 @@ class DataFrame protected[sql](
|
|||
}
|
||||
|
||||
/**
|
||||
* Return a new [[DataFrame]] sorted by the specified column, in ascending column.
|
||||
* Returns a new [[DataFrame]] sorted by the specified column, in ascending column.
|
||||
* {{{
|
||||
* // The following 3 are equivalent
|
||||
* df.sort("sortcol")
|
||||
|
@ -200,7 +199,7 @@ class DataFrame protected[sql](
|
|||
}
|
||||
|
||||
/**
|
||||
* Return a new [[DataFrame]] sorted by the given expressions. For example:
|
||||
* Returns a new [[DataFrame]] sorted by the given expressions. For example:
|
||||
* {{{
|
||||
* df.sort($"col1", $"col2".desc)
|
||||
* }}}
|
||||
|
@ -219,7 +218,7 @@ class DataFrame protected[sql](
|
|||
}
|
||||
|
||||
/**
|
||||
* Return a new [[DataFrame]] sorted by the given expressions.
|
||||
* Returns a new [[DataFrame]] sorted by the given expressions.
|
||||
* This is an alias of the `sort` function.
|
||||
*/
|
||||
@scala.annotation.varargs
|
||||
|
@ -228,7 +227,7 @@ class DataFrame protected[sql](
|
|||
}
|
||||
|
||||
/**
|
||||
* Selecting a single column and return it as a [[Column]].
|
||||
* Selects a single column and return it as a [[Column]].
|
||||
*/
|
||||
override def apply(colName: String): Column = colName match {
|
||||
case "*" =>
|
||||
|
@ -239,7 +238,7 @@ class DataFrame protected[sql](
|
|||
}
|
||||
|
||||
/**
|
||||
* Selecting a set of expressions, wrapped in a Product.
|
||||
* Selects a set of expressions, wrapped in a Product.
|
||||
* {{{
|
||||
* // The following two are equivalent:
|
||||
* df.apply(($"colA", $"colB" + 1))
|
||||
|
@ -250,17 +249,17 @@ class DataFrame protected[sql](
|
|||
require(projection.productArity >= 1)
|
||||
select(projection.productIterator.map {
|
||||
case c: Column => c
|
||||
case o: Any => new Column(Some(sqlContext), None, LiteralExpr(o))
|
||||
case o: Any => new Column(Some(sqlContext), None, Literal(o))
|
||||
}.toSeq :_*)
|
||||
}
|
||||
|
||||
/**
|
||||
* Alias the current [[DataFrame]].
|
||||
* Returns a new [[DataFrame]] with an alias set.
|
||||
*/
|
||||
override def as(name: String): DataFrame = Subquery(name, logicalPlan)
|
||||
|
||||
/**
|
||||
* Selecting a set of expressions.
|
||||
* Selects a set of expressions.
|
||||
* {{{
|
||||
* df.select($"colA", $"colB" + 1)
|
||||
* }}}
|
||||
|
@ -277,7 +276,7 @@ class DataFrame protected[sql](
|
|||
}
|
||||
|
||||
/**
|
||||
* Selecting a set of columns. This is a variant of `select` that can only select
|
||||
* Selects a set of columns. This is a variant of `select` that can only select
|
||||
* existing columns using column names (i.e. cannot construct expressions).
|
||||
*
|
||||
* {{{
|
||||
|
@ -292,7 +291,7 @@ class DataFrame protected[sql](
|
|||
}
|
||||
|
||||
/**
|
||||
* Filtering rows using the given condition.
|
||||
* Filters rows using the given condition.
|
||||
* {{{
|
||||
* // The following are equivalent:
|
||||
* peopleDf.filter($"age" > 15)
|
||||
|
@ -305,7 +304,7 @@ class DataFrame protected[sql](
|
|||
}
|
||||
|
||||
/**
|
||||
* Filtering rows using the given condition. This is an alias for `filter`.
|
||||
* Filters rows using the given condition. This is an alias for `filter`.
|
||||
* {{{
|
||||
* // The following are equivalent:
|
||||
* peopleDf.filter($"age" > 15)
|
||||
|
@ -316,7 +315,7 @@ class DataFrame protected[sql](
|
|||
override def where(condition: Column): DataFrame = filter(condition)
|
||||
|
||||
/**
|
||||
* Filtering rows using the given condition. This is a shorthand meant for Scala.
|
||||
* Filters rows using the given condition. This is a shorthand meant for Scala.
|
||||
* {{{
|
||||
* // The following are equivalent:
|
||||
* peopleDf.filter($"age" > 15)
|
||||
|
@ -327,7 +326,7 @@ class DataFrame protected[sql](
|
|||
override def apply(condition: Column): DataFrame = filter(condition)
|
||||
|
||||
/**
|
||||
* Group the [[DataFrame]] using the specified columns, so we can run aggregation on them.
|
||||
* Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
|
||||
* See [[GroupedDataFrame]] for all the available aggregate functions.
|
||||
*
|
||||
* {{{
|
||||
|
@ -347,7 +346,7 @@ class DataFrame protected[sql](
|
|||
}
|
||||
|
||||
/**
|
||||
* Group the [[DataFrame]] using the specified columns, so we can run aggregation on them.
|
||||
* Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
|
||||
* See [[GroupedDataFrame]] for all the available aggregate functions.
|
||||
*
|
||||
* This is a variant of groupBy that can only group by existing columns using column names
|
||||
|
@ -371,7 +370,7 @@ class DataFrame protected[sql](
|
|||
}
|
||||
|
||||
/**
|
||||
* Aggregate on the entire [[DataFrame]] without groups.
|
||||
* Aggregates on the entire [[DataFrame]] without groups.
|
||||
* {{
|
||||
* // df.agg(...) is a shorthand for df.groupBy().agg(...)
|
||||
* df.agg(Map("age" -> "max", "salary" -> "avg"))
|
||||
|
@ -381,7 +380,7 @@ class DataFrame protected[sql](
|
|||
override def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs)
|
||||
|
||||
/**
|
||||
* Aggregate on the entire [[DataFrame]] without groups.
|
||||
* Aggregates on the entire [[DataFrame]] without groups.
|
||||
* {{
|
||||
* // df.agg(...) is a shorthand for df.groupBy().agg(...)
|
||||
* df.agg(max($"age"), avg($"salary"))
|
||||
|
@ -392,31 +391,31 @@ class DataFrame protected[sql](
|
|||
override def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs :_*)
|
||||
|
||||
/**
|
||||
* Return a new [[DataFrame]] by taking the first `n` rows. The difference between this function
|
||||
* Returns a new [[DataFrame]] by taking the first `n` rows. The difference between this function
|
||||
* and `head` is that `head` returns an array while `limit` returns a new [[DataFrame]].
|
||||
*/
|
||||
override def limit(n: Int): DataFrame = Limit(LiteralExpr(n), logicalPlan)
|
||||
override def limit(n: Int): DataFrame = Limit(Literal(n), logicalPlan)
|
||||
|
||||
/**
|
||||
* Return a new [[DataFrame]] containing union of rows in this frame and another frame.
|
||||
* Returns a new [[DataFrame]] containing union of rows in this frame and another frame.
|
||||
* This is equivalent to `UNION ALL` in SQL.
|
||||
*/
|
||||
override def unionAll(other: DataFrame): DataFrame = Union(logicalPlan, other.logicalPlan)
|
||||
|
||||
/**
|
||||
* Return a new [[DataFrame]] containing rows only in both this frame and another frame.
|
||||
* Returns a new [[DataFrame]] containing rows only in both this frame and another frame.
|
||||
* This is equivalent to `INTERSECT` in SQL.
|
||||
*/
|
||||
override def intersect(other: DataFrame): DataFrame = Intersect(logicalPlan, other.logicalPlan)
|
||||
|
||||
/**
|
||||
* Return a new [[DataFrame]] containing rows in this frame but not in another frame.
|
||||
* Returns a new [[DataFrame]] containing rows in this frame but not in another frame.
|
||||
* This is equivalent to `EXCEPT` in SQL.
|
||||
*/
|
||||
override def except(other: DataFrame): DataFrame = Except(logicalPlan, other.logicalPlan)
|
||||
|
||||
/**
|
||||
* Return a new [[DataFrame]] by sampling a fraction of rows.
|
||||
* Returns a new [[DataFrame]] by sampling a fraction of rows.
|
||||
*
|
||||
* @param withReplacement Sample with replacement or not.
|
||||
* @param fraction Fraction of rows to generate.
|
||||
|
@ -427,7 +426,7 @@ class DataFrame protected[sql](
|
|||
}
|
||||
|
||||
/**
|
||||
* Return a new [[DataFrame]] by sampling a fraction of rows, using a random seed.
|
||||
* Returns a new [[DataFrame]] by sampling a fraction of rows, using a random seed.
|
||||
*
|
||||
* @param withReplacement Sample with replacement or not.
|
||||
* @param fraction Fraction of rows to generate.
|
||||
|
@ -439,57 +438,63 @@ class DataFrame protected[sql](
|
|||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* Return a new [[DataFrame]] by adding a column.
|
||||
* Returns a new [[DataFrame]] by adding a column.
|
||||
*/
|
||||
override def addColumn(colName: String, col: Column): DataFrame = {
|
||||
select(Column("*"), col.as(colName))
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the first `n` rows.
|
||||
* Returns the first `n` rows.
|
||||
*/
|
||||
override def head(n: Int): Array[Row] = limit(n).collect()
|
||||
|
||||
/**
|
||||
* Return the first row.
|
||||
* Returns the first row.
|
||||
*/
|
||||
override def head(): Row = head(1).head
|
||||
|
||||
/**
|
||||
* Return the first row. Alias for head().
|
||||
* Returns the first row. Alias for head().
|
||||
*/
|
||||
override def first(): Row = head()
|
||||
|
||||
/**
|
||||
* Returns a new RDD by applying a function to all rows of this DataFrame.
|
||||
*/
|
||||
override def map[R: ClassTag](f: Row => R): RDD[R] = {
|
||||
rdd.map(f)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a new RDD by applying a function to each partition of this DataFrame.
|
||||
*/
|
||||
override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = {
|
||||
rdd.mapPartitions(f)
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the first `n` rows in the [[DataFrame]].
|
||||
* Returns the first `n` rows in the [[DataFrame]].
|
||||
*/
|
||||
override def take(n: Int): Array[Row] = head(n)
|
||||
|
||||
/**
|
||||
* Return an array that contains all of [[Row]]s in this [[DataFrame]].
|
||||
* Returns an array that contains all of [[Row]]s in this [[DataFrame]].
|
||||
*/
|
||||
override def collect(): Array[Row] = rdd.collect()
|
||||
|
||||
/**
|
||||
* Return a Java list that contains all of [[Row]]s in this [[DataFrame]].
|
||||
* Returns a Java list that contains all of [[Row]]s in this [[DataFrame]].
|
||||
*/
|
||||
override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() :_*)
|
||||
|
||||
/**
|
||||
* Return the number of rows in the [[DataFrame]].
|
||||
* Returns the number of rows in the [[DataFrame]].
|
||||
*/
|
||||
override def count(): Long = groupBy().count().rdd.collect().head.getLong(0)
|
||||
|
||||
/**
|
||||
* Return a new [[DataFrame]] that has exactly `numPartitions` partitions.
|
||||
* Returns a new [[DataFrame]] that has exactly `numPartitions` partitions.
|
||||
*/
|
||||
override def repartition(numPartitions: Int): DataFrame = {
|
||||
sqlContext.applySchema(rdd.repartition(numPartitions), schema)
|
||||
|
@ -546,7 +551,7 @@ class DataFrame protected[sql](
|
|||
* Creates a table from the the contents of this DataFrame. This will fail if the table already
|
||||
* exists.
|
||||
*
|
||||
* Note that this currently only works with DataFrame that are created from a HiveContext as
|
||||
* Note that this currently only works with DataFrames that are created from a HiveContext as
|
||||
* there is no notion of a persisted catalog in a standard SQL context. Instead you can write
|
||||
* an RDD out to a parquet file, and then register that file as a table. This "table" can then
|
||||
* be the target of an `insertInto`.
|
||||
|
@ -568,7 +573,7 @@ class DataFrame protected[sql](
|
|||
}
|
||||
|
||||
/**
|
||||
* Return the content of the [[DataFrame]] as a RDD of JSON strings.
|
||||
* Returns the content of the [[DataFrame]] as a RDD of JSON strings.
|
||||
*/
|
||||
override def toJSON: RDD[String] = {
|
||||
val rowSchema = this.schema
|
||||
|
|
|
@ -1,98 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr}
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
object Literal {
|
||||
|
||||
/** Return a new boolean literal. */
|
||||
def apply(literal: Boolean): Column = new Column(LiteralExpr(literal))
|
||||
|
||||
/** Return a new byte literal. */
|
||||
def apply(literal: Byte): Column = new Column(LiteralExpr(literal))
|
||||
|
||||
/** Return a new short literal. */
|
||||
def apply(literal: Short): Column = new Column(LiteralExpr(literal))
|
||||
|
||||
/** Return a new int literal. */
|
||||
def apply(literal: Int): Column = new Column(LiteralExpr(literal))
|
||||
|
||||
/** Return a new long literal. */
|
||||
def apply(literal: Long): Column = new Column(LiteralExpr(literal))
|
||||
|
||||
/** Return a new float literal. */
|
||||
def apply(literal: Float): Column = new Column(LiteralExpr(literal))
|
||||
|
||||
/** Return a new double literal. */
|
||||
def apply(literal: Double): Column = new Column(LiteralExpr(literal))
|
||||
|
||||
/** Return a new string literal. */
|
||||
def apply(literal: String): Column = new Column(LiteralExpr(literal))
|
||||
|
||||
/** Return a new decimal literal. */
|
||||
def apply(literal: BigDecimal): Column = new Column(LiteralExpr(literal))
|
||||
|
||||
/** Return a new decimal literal. */
|
||||
def apply(literal: java.math.BigDecimal): Column = new Column(LiteralExpr(literal))
|
||||
|
||||
/** Return a new timestamp literal. */
|
||||
def apply(literal: java.sql.Timestamp): Column = new Column(LiteralExpr(literal))
|
||||
|
||||
/** Return a new date literal. */
|
||||
def apply(literal: java.sql.Date): Column = new Column(LiteralExpr(literal))
|
||||
|
||||
/** Return a new binary (byte array) literal. */
|
||||
def apply(literal: Array[Byte]): Column = new Column(LiteralExpr(literal))
|
||||
|
||||
/** Return a new null literal. */
|
||||
def apply(literal: Null): Column = new Column(LiteralExpr(null))
|
||||
|
||||
/**
|
||||
* Return a Column expression representing the literal value. Throws an exception if the
|
||||
* data type is not supported by SparkSQL.
|
||||
*/
|
||||
protected[sql] def anyToLiteral(literal: Any): Column = {
|
||||
// If the literal is a symbol, convert it into a Column.
|
||||
if (literal.isInstanceOf[Symbol]) {
|
||||
return dsl.symbolToColumn(literal.asInstanceOf[Symbol])
|
||||
}
|
||||
|
||||
val literalExpr = literal match {
|
||||
case v: Int => LiteralExpr(v, IntegerType)
|
||||
case v: Long => LiteralExpr(v, LongType)
|
||||
case v: Double => LiteralExpr(v, DoubleType)
|
||||
case v: Float => LiteralExpr(v, FloatType)
|
||||
case v: Byte => LiteralExpr(v, ByteType)
|
||||
case v: Short => LiteralExpr(v, ShortType)
|
||||
case v: String => LiteralExpr(v, StringType)
|
||||
case v: Boolean => LiteralExpr(v, BooleanType)
|
||||
case v: BigDecimal => LiteralExpr(Decimal(v), DecimalType.Unlimited)
|
||||
case v: java.math.BigDecimal => LiteralExpr(Decimal(v), DecimalType.Unlimited)
|
||||
case v: Decimal => LiteralExpr(v, DecimalType.Unlimited)
|
||||
case v: java.sql.Timestamp => LiteralExpr(v, TimestampType)
|
||||
case v: java.sql.Date => LiteralExpr(v, DateType)
|
||||
case v: Array[Byte] => LiteralExpr(v, BinaryType)
|
||||
case null => LiteralExpr(null, NullType)
|
||||
case _ =>
|
||||
throw new RuntimeException("Unsupported literal type " + literal.getClass + " " + literal)
|
||||
}
|
||||
new Column(literalExpr)
|
||||
}
|
||||
}
|
|
@ -135,19 +135,19 @@ class SQLContext(@transient val sparkContext: SparkContext)
|
|||
* The following example registers a UDF in Java:
|
||||
* {{{
|
||||
* sqlContext.udf().register("myUDF",
|
||||
* new UDF2<Integer, String, String>() {
|
||||
* @Override
|
||||
* public String call(Integer arg1, String arg2) {
|
||||
* return arg2 + arg1;
|
||||
* }
|
||||
* }, DataTypes.StringType);
|
||||
* new UDF2<Integer, String, String>() {
|
||||
* @Override
|
||||
* public String call(Integer arg1, String arg2) {
|
||||
* return arg2 + arg1;
|
||||
* }
|
||||
* }, DataTypes.StringType);
|
||||
* }}}
|
||||
*
|
||||
* Or, to use Java 8 lambda syntax:
|
||||
* {{{
|
||||
* sqlContext.udf().register("myUDF",
|
||||
* (Integer arg1, String arg2) -> arg2 + arg1),
|
||||
* DataTypes.StringType);
|
||||
* (Integer arg1, String arg2) -> arg2 + arg1),
|
||||
* DataTypes.StringType);
|
||||
* }}}
|
||||
*/
|
||||
val udf: UDFRegistration = new UDFRegistration(this)
|
||||
|
|
|
@ -30,7 +30,7 @@ import org.apache.spark.storage.StorageLevel
|
|||
* An internal interface defining the RDD-like methods for [[DataFrame]].
|
||||
* Please use [[DataFrame]] directly, and do NOT use this.
|
||||
*/
|
||||
trait RDDApi[T] {
|
||||
private[sql] trait RDDApi[T] {
|
||||
|
||||
def cache(): this.type = persist()
|
||||
|
||||
|
@ -64,7 +64,7 @@ trait RDDApi[T] {
|
|||
* An internal interface defining data frame related methods in [[DataFrame]].
|
||||
* Please use [[DataFrame]] directly, and do NOT use this.
|
||||
*/
|
||||
trait DataFrameSpecificApi {
|
||||
private[sql] trait DataFrameSpecificApi {
|
||||
|
||||
def schema: StructType
|
||||
|
||||
|
@ -181,7 +181,7 @@ trait DataFrameSpecificApi {
|
|||
* An internal interface defining expression APIs for [[DataFrame]].
|
||||
* Please use [[DataFrame]] and [[Column]] directly, and do NOT use this.
|
||||
*/
|
||||
trait ExpressionApi {
|
||||
private[sql] trait ExpressionApi {
|
||||
|
||||
def isComputable: Boolean
|
||||
|
||||
|
@ -231,9 +231,7 @@ trait ExpressionApi {
|
|||
@scala.annotation.varargs
|
||||
def in(list: Column*): Column
|
||||
|
||||
def like(other: Column): Column
|
||||
def like(other: String): Column
|
||||
def rlike(other: Column): Column
|
||||
def rlike(other: String): Column
|
||||
|
||||
def contains(other: Column): Column
|
||||
|
@ -249,7 +247,6 @@ trait ExpressionApi {
|
|||
def isNull: Column
|
||||
def isNotNull: Column
|
||||
|
||||
def getItem(ordinal: Column): Column
|
||||
def getItem(ordinal: Int): Column
|
||||
def getField(fieldName: String): Column
|
||||
|
||||
|
@ -266,7 +263,7 @@ trait ExpressionApi {
|
|||
* An internal interface defining aggregation APIs for [[DataFrame]].
|
||||
* Please use [[DataFrame]] and [[GroupedDataFrame]] directly, and do NOT use this.
|
||||
*/
|
||||
trait GroupedDataFrameApi {
|
||||
private[sql] trait GroupedDataFrameApi {
|
||||
|
||||
def agg(exprs: Map[String, String]): DataFrame
|
||||
|
||||
|
|
|
@ -0,0 +1,85 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.api.java;
|
||||
|
||||
import org.apache.spark.sql.Column;
|
||||
import org.apache.spark.sql.DataFrame;
|
||||
import org.apache.spark.sql.api.scala.dsl.package$;
|
||||
|
||||
|
||||
/**
|
||||
* Java version of the domain-specific functions available for {@link DataFrame}.
|
||||
*
|
||||
* The Scala version is at {@link org.apache.spark.sql.api.scala.dsl}.
|
||||
*/
|
||||
public class dsl {
|
||||
// NOTE: Update also the Scala version when we update this version.
|
||||
|
||||
private static package$ scalaDsl = package$.MODULE$;
|
||||
|
||||
/**
|
||||
* Creates a column of literal value.
|
||||
*/
|
||||
public static Column lit(Object literalValue) {
|
||||
return scalaDsl.lit(literalValue);
|
||||
}
|
||||
|
||||
public static Column sum(Column e) {
|
||||
return scalaDsl.sum(e);
|
||||
}
|
||||
|
||||
public static Column sumDistinct(Column e) {
|
||||
return scalaDsl.sumDistinct(e);
|
||||
}
|
||||
|
||||
public static Column avg(Column e) {
|
||||
return scalaDsl.avg(e);
|
||||
}
|
||||
|
||||
public static Column first(Column e) {
|
||||
return scalaDsl.first(e);
|
||||
}
|
||||
|
||||
public static Column last(Column e) {
|
||||
return scalaDsl.last(e);
|
||||
}
|
||||
|
||||
public static Column min(Column e) {
|
||||
return scalaDsl.min(e);
|
||||
}
|
||||
|
||||
public static Column max(Column e) {
|
||||
return scalaDsl.max(e);
|
||||
}
|
||||
|
||||
public static Column upper(Column e) {
|
||||
return scalaDsl.upper(e);
|
||||
}
|
||||
|
||||
public static Column lower(Column e) {
|
||||
return scalaDsl.lower(e);
|
||||
}
|
||||
|
||||
public static Column sqrt(Column e) {
|
||||
return scalaDsl.sqrt(e);
|
||||
}
|
||||
|
||||
public static Column abs(Column e) {
|
||||
return scalaDsl.abs(e);
|
||||
}
|
||||
}
|
|
@ -15,20 +15,26 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import java.sql.{Timestamp, Date}
|
||||
package org.apache.spark.sql.api.scala
|
||||
|
||||
import scala.language.implicitConversions
|
||||
import scala.reflect.runtime.universe.{TypeTag, typeTag}
|
||||
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.catalyst.ScalaReflection
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.types.DataType
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
|
||||
/**
|
||||
* Scala version of the domain specific functions available for [[DataFrame]].
|
||||
*
|
||||
* The Java-version is at [[api.java.dsl]].
|
||||
*/
|
||||
package object dsl {
|
||||
// NOTE: Update also the Java version when we update this version.
|
||||
|
||||
/** An implicit conversion that turns a Scala `Symbol` into a [[Column]]. */
|
||||
implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
|
||||
|
||||
/** Converts $"col name" into an [[Column]]. */
|
||||
|
@ -40,11 +46,40 @@ package object dsl {
|
|||
|
||||
private[this] implicit def toColumn(expr: Expression): Column = new Column(expr)
|
||||
|
||||
/**
|
||||
* Creates a [[Column]] of literal value.
|
||||
*/
|
||||
def lit(literal: Any): Column = {
|
||||
if (literal.isInstanceOf[Symbol]) {
|
||||
return new ColumnName(literal.asInstanceOf[Symbol].name)
|
||||
}
|
||||
|
||||
val literalExpr = literal match {
|
||||
case v: Boolean => Literal(v, BooleanType)
|
||||
case v: Byte => Literal(v, ByteType)
|
||||
case v: Short => Literal(v, ShortType)
|
||||
case v: Int => Literal(v, IntegerType)
|
||||
case v: Long => Literal(v, LongType)
|
||||
case v: Float => Literal(v, FloatType)
|
||||
case v: Double => Literal(v, DoubleType)
|
||||
case v: String => Literal(v, StringType)
|
||||
case v: BigDecimal => Literal(Decimal(v), DecimalType.Unlimited)
|
||||
case v: java.math.BigDecimal => Literal(Decimal(v), DecimalType.Unlimited)
|
||||
case v: Decimal => Literal(v, DecimalType.Unlimited)
|
||||
case v: java.sql.Timestamp => Literal(v, TimestampType)
|
||||
case v: java.sql.Date => Literal(v, DateType)
|
||||
case v: Array[Byte] => Literal(v, BinaryType)
|
||||
case null => Literal(null, NullType)
|
||||
case _ =>
|
||||
throw new RuntimeException("Unsupported literal type " + literal.getClass + " " + literal)
|
||||
}
|
||||
new Column(literalExpr)
|
||||
}
|
||||
|
||||
def sum(e: Column): Column = Sum(e.expr)
|
||||
def sumDistinct(e: Column): Column = SumDistinct(e.expr)
|
||||
def count(e: Column): Column = Count(e.expr)
|
||||
|
||||
@scala.annotation.varargs
|
||||
def countDistinct(expr: Column, exprs: Column*): Column =
|
||||
CountDistinct((expr +: exprs).map(_.expr))
|
||||
|
||||
|
@ -59,38 +94,9 @@ package object dsl {
|
|||
def sqrt(e: Column): Column = Sqrt(e.expr)
|
||||
def abs(e: Column): Column = Abs(e.expr)
|
||||
|
||||
|
||||
// scalastyle:off
|
||||
|
||||
object literals {
|
||||
|
||||
implicit def booleanToLiteral(b: Boolean): Column = Literal(b)
|
||||
|
||||
implicit def byteToLiteral(b: Byte): Column = Literal(b)
|
||||
|
||||
implicit def shortToLiteral(s: Short): Column = Literal(s)
|
||||
|
||||
implicit def intToLiteral(i: Int): Column = Literal(i)
|
||||
|
||||
implicit def longToLiteral(l: Long): Column = Literal(l)
|
||||
|
||||
implicit def floatToLiteral(f: Float): Column = Literal(f)
|
||||
|
||||
implicit def doubleToLiteral(d: Double): Column = Literal(d)
|
||||
|
||||
implicit def stringToLiteral(s: String): Column = Literal(s)
|
||||
|
||||
implicit def dateToLiteral(d: Date): Column = Literal(d)
|
||||
|
||||
implicit def bigDecimalToLiteral(d: BigDecimal): Column = Literal(d.underlying())
|
||||
|
||||
implicit def bigDecimalToLiteral(d: java.math.BigDecimal): Column = Literal(d)
|
||||
|
||||
implicit def timestampToLiteral(t: Timestamp): Column = Literal(t)
|
||||
|
||||
implicit def binaryToLiteral(a: Array[Byte]): Column = Literal(a)
|
||||
}
|
||||
|
||||
|
||||
/* Use the following code to generate:
|
||||
(0 to 22).map { x =>
|
||||
val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"})
|
|
@ -19,7 +19,7 @@ package org.apache.spark.sql
|
|||
|
||||
import org.apache.spark.sql.TestData._
|
||||
import org.apache.spark.sql.columnar._
|
||||
import org.apache.spark.sql.dsl._
|
||||
import org.apache.spark.sql.api.scala.dsl._
|
||||
import org.apache.spark.sql.test.TestSQLContext._
|
||||
import org.apache.spark.storage.{StorageLevel, RDDBlockId}
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import org.apache.spark.sql.dsl._
|
||||
import org.apache.spark.sql.api.scala.dsl._
|
||||
import org.apache.spark.sql.test.TestSQLContext
|
||||
import org.apache.spark.sql.types.{BooleanType, IntegerType, StructField, StructType}
|
||||
|
||||
|
@ -244,7 +244,7 @@ class ColumnExpressionSuite extends QueryTest {
|
|||
)
|
||||
|
||||
checkAnswer(
|
||||
testData.select(sqrt(Literal(null))),
|
||||
testData.select(sqrt(lit(null))),
|
||||
(1 to 100).map(_ => Row(null))
|
||||
)
|
||||
}
|
||||
|
@ -261,7 +261,7 @@ class ColumnExpressionSuite extends QueryTest {
|
|||
)
|
||||
|
||||
checkAnswer(
|
||||
testData.select(abs(Literal(null))),
|
||||
testData.select(abs(lit(null))),
|
||||
(1 to 100).map(_ => Row(null))
|
||||
)
|
||||
}
|
||||
|
@ -278,7 +278,7 @@ class ColumnExpressionSuite extends QueryTest {
|
|||
)
|
||||
|
||||
checkAnswer(
|
||||
testData.select(upper(Literal(null))),
|
||||
testData.select(upper(lit(null))),
|
||||
(1 to 100).map(n => Row(null))
|
||||
)
|
||||
}
|
||||
|
@ -295,7 +295,7 @@ class ColumnExpressionSuite extends QueryTest {
|
|||
)
|
||||
|
||||
checkAnswer(
|
||||
testData.select(lower(Literal(null))),
|
||||
testData.select(lower(lit(null))),
|
||||
(1 to 100).map(n => Row(null))
|
||||
)
|
||||
}
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import org.apache.spark.sql.dsl._
|
||||
import org.apache.spark.sql.api.scala.dsl._
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
/* Implicits */
|
||||
|
@ -57,13 +57,13 @@ class DataFrameSuite extends QueryTest {
|
|||
|
||||
test("convert $\"attribute name\" into unresolved attribute") {
|
||||
checkAnswer(
|
||||
testData.where($"key" === Literal(1)).select($"value"),
|
||||
testData.where($"key" === lit(1)).select($"value"),
|
||||
Row("1"))
|
||||
}
|
||||
|
||||
test("convert Scala Symbol 'attrname into unresolved attribute") {
|
||||
checkAnswer(
|
||||
testData.where('key === Literal(1)).select('value),
|
||||
testData.where('key === lit(1)).select('value),
|
||||
Row("1"))
|
||||
}
|
||||
|
||||
|
@ -75,13 +75,13 @@ class DataFrameSuite extends QueryTest {
|
|||
|
||||
test("simple select") {
|
||||
checkAnswer(
|
||||
testData.where('key === Literal(1)).select('value),
|
||||
testData.where('key === lit(1)).select('value),
|
||||
Row("1"))
|
||||
}
|
||||
|
||||
test("select with functions") {
|
||||
checkAnswer(
|
||||
testData.select(sum('value), avg('value), count(Literal(1))),
|
||||
testData.select(sum('value), avg('value), count(lit(1))),
|
||||
Row(5050.0, 50.5, 100))
|
||||
|
||||
checkAnswer(
|
||||
|
@ -215,7 +215,7 @@ class DataFrameSuite extends QueryTest {
|
|||
)
|
||||
|
||||
checkAnswer(
|
||||
testData3.agg(count('a), count('b), count(Literal(1)), countDistinct('a), countDistinct('b)),
|
||||
testData3.agg(count('a), count('b), count(lit(1)), countDistinct('a), countDistinct('b)),
|
||||
Row(2, 1, 2, 2, 1)
|
||||
)
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.sql
|
|||
import org.scalatest.BeforeAndAfterEach
|
||||
|
||||
import org.apache.spark.sql.TestData._
|
||||
import org.apache.spark.sql.dsl._
|
||||
import org.apache.spark.sql.api.scala.dsl._
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
|
||||
import org.apache.spark.sql.execution.joins._
|
||||
import org.apache.spark.sql.test.TestSQLContext._
|
||||
|
@ -136,8 +136,8 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|
|||
}
|
||||
|
||||
test("inner join, where, multiple matches") {
|
||||
val x = testData2.where($"a" === Literal(1)).as("x")
|
||||
val y = testData2.where($"a" === Literal(1)).as("y")
|
||||
val x = testData2.where($"a" === 1).as("x")
|
||||
val y = testData2.where($"a" === 1).as("y")
|
||||
checkAnswer(
|
||||
x.join(y).where($"x.a" === $"y.a"),
|
||||
Row(1,1,1,1) ::
|
||||
|
@ -148,8 +148,8 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|
|||
}
|
||||
|
||||
test("inner join, no matches") {
|
||||
val x = testData2.where($"a" === Literal(1)).as("x")
|
||||
val y = testData2.where($"a" === Literal(2)).as("y")
|
||||
val x = testData2.where($"a" === 1).as("x")
|
||||
val y = testData2.where($"a" === 2).as("y")
|
||||
checkAnswer(
|
||||
x.join(y).where($"x.a" === $"y.a"),
|
||||
Nil)
|
||||
|
@ -185,7 +185,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|
|||
Row(6, "F", null, null) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > Literal(1), "left"),
|
||||
upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > 1, "left"),
|
||||
Row(1, "A", null, null) ::
|
||||
Row(2, "B", 2, "b") ::
|
||||
Row(3, "C", 3, "c") ::
|
||||
|
@ -194,7 +194,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|
|||
Row(6, "F", null, null) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > Literal(1), "left"),
|
||||
upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > 1, "left"),
|
||||
Row(1, "A", null, null) ::
|
||||
Row(2, "B", 2, "b") ::
|
||||
Row(3, "C", 3, "c") ::
|
||||
|
@ -247,7 +247,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|
|||
Row(null, null, 5, "E") ::
|
||||
Row(null, null, 6, "F") :: Nil)
|
||||
checkAnswer(
|
||||
lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > Literal(1), "right"),
|
||||
lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "right"),
|
||||
Row(null, null, 1, "A") ::
|
||||
Row(2, "b", 2, "B") ::
|
||||
Row(3, "c", 3, "C") ::
|
||||
|
@ -255,7 +255,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|
|||
Row(null, null, 5, "E") ::
|
||||
Row(null, null, 6, "F") :: Nil)
|
||||
checkAnswer(
|
||||
lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > Literal(1), "right"),
|
||||
lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "right"),
|
||||
Row(null, null, 1, "A") ::
|
||||
Row(2, "b", 2, "B") ::
|
||||
Row(3, "c", 3, "C") ::
|
||||
|
@ -298,8 +298,8 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|
|||
}
|
||||
|
||||
test("full outer join") {
|
||||
upperCaseData.where('N <= Literal(4)).registerTempTable("left")
|
||||
upperCaseData.where('N >= Literal(3)).registerTempTable("right")
|
||||
upperCaseData.where('N <= 4).registerTempTable("left")
|
||||
upperCaseData.where('N >= 3).registerTempTable("right")
|
||||
|
||||
val left = UnresolvedRelation(Seq("left"), None)
|
||||
val right = UnresolvedRelation(Seq("right"), None)
|
||||
|
@ -314,7 +314,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|
|||
Row(null, null, 6, "F") :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
left.join(right, ($"left.N" === $"right.N") && ($"left.N" !== Literal(3)), "full"),
|
||||
left.join(right, ($"left.N" === $"right.N") && ($"left.N" !== 3), "full"),
|
||||
Row(1, "A", null, null) ::
|
||||
Row(2, "B", null, null) ::
|
||||
Row(3, "C", null, null) ::
|
||||
|
@ -324,7 +324,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|
|||
Row(null, null, 6, "F") :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
left.join(right, ($"left.N" === $"right.N") && ($"right.N" !== Literal(3)), "full"),
|
||||
left.join(right, ($"left.N" === $"right.N") && ($"right.N" !== 3), "full"),
|
||||
Row(1, "A", null, null) ::
|
||||
Row(2, "B", null, null) ::
|
||||
Row(3, "C", null, null) ::
|
||||
|
|
|
@ -21,7 +21,7 @@ import java.util.TimeZone
|
|||
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
|
||||
import org.apache.spark.sql.dsl._
|
||||
import org.apache.spark.sql.api.scala.dsl._
|
||||
import org.apache.spark.sql.catalyst.errors.TreeNodeException
|
||||
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
|
||||
import org.apache.spark.sql.types._
|
||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.sql
|
|||
import java.sql.Timestamp
|
||||
|
||||
import org.apache.spark.sql.catalyst.plans.logical
|
||||
import org.apache.spark.sql.dsl._
|
||||
import org.apache.spark.sql.api.scala.dsl._
|
||||
import org.apache.spark.sql.test._
|
||||
|
||||
/* Implicits */
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import org.apache.spark.sql.dsl.StringToColumn
|
||||
import org.apache.spark.sql.api.scala.dsl.StringToColumn
|
||||
import org.apache.spark.sql.test._
|
||||
|
||||
/* Implicits */
|
||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.sql
|
|||
import scala.beans.{BeanInfo, BeanProperty}
|
||||
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.dsl._
|
||||
import org.apache.spark.sql.api.scala.dsl._
|
||||
import org.apache.spark.sql.test.TestSQLContext._
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
package org.apache.spark.sql.columnar
|
||||
|
||||
import org.apache.spark.sql.dsl._
|
||||
import org.apache.spark.sql.api.scala.dsl._
|
||||
import org.apache.spark.sql.TestData._
|
||||
import org.apache.spark.sql.catalyst.expressions.Row
|
||||
import org.apache.spark.sql.test.TestSQLContext._
|
||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
|
|||
import org.scalatest.FunSuite
|
||||
|
||||
import org.apache.spark.sql.{SQLConf, execution}
|
||||
import org.apache.spark.sql.dsl._
|
||||
import org.apache.spark.sql.api.scala.dsl._
|
||||
import org.apache.spark.sql.TestData._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans._
|
||||
|
|
|
@ -21,15 +21,16 @@ import java.sql.{Date, Timestamp}
|
|||
|
||||
import org.apache.spark.sql.TestData._
|
||||
import org.apache.spark.sql.catalyst.util._
|
||||
import org.apache.spark.sql.dsl._
|
||||
import org.apache.spark.sql.api.scala.dsl._
|
||||
import org.apache.spark.sql.json.JsonRDD.{compatibleType, enforceCorrectType}
|
||||
import org.apache.spark.sql.test.TestSQLContext
|
||||
import org.apache.spark.sql.test.TestSQLContext._
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.{Literal, QueryTest, Row, SQLConf}
|
||||
import org.apache.spark.sql.{QueryTest, Row, SQLConf}
|
||||
|
||||
class JsonSuite extends QueryTest {
|
||||
import org.apache.spark.sql.json.TestJsonData._
|
||||
|
||||
TestJsonData
|
||||
|
||||
test("Type promotion") {
|
||||
|
@ -464,8 +465,8 @@ class JsonSuite extends QueryTest {
|
|||
// in the Project.
|
||||
checkAnswer(
|
||||
jsonDF.
|
||||
where('num_str > Literal(BigDecimal("92233720368547758060"))).
|
||||
select(('num_str + Literal(1.2)).as("num")),
|
||||
where('num_str > BigDecimal("92233720368547758060")).
|
||||
select(('num_str + 1.2).as("num")),
|
||||
Row(new java.math.BigDecimal("92233720368547758061.2"))
|
||||
)
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ import parquet.schema.{MessageType, MessageTypeParser}
|
|||
import org.apache.hadoop.conf.Configuration
|
||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||
import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf}
|
||||
import org.apache.spark.sql.dsl._
|
||||
import org.apache.spark.sql.api.scala.dsl._
|
||||
import org.apache.spark.sql.catalyst.ScalaReflection
|
||||
import org.apache.spark.sql.catalyst.expressions.Row
|
||||
import org.apache.spark.sql.test.TestSQLContext
|
||||
|
|
|
@ -15,4 +15,4 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.hive;
|
||||
package org.apache.spark.sql.hive;
|
||||
|
|
|
@ -29,7 +29,7 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars
|
|||
import org.apache.spark.{SparkFiles, SparkException}
|
||||
import org.apache.spark.sql.{DataFrame, Row}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.Project
|
||||
import org.apache.spark.sql.dsl._
|
||||
import org.apache.spark.sql.api.scala.dsl._
|
||||
import org.apache.spark.sql.hive._
|
||||
import org.apache.spark.sql.hive.test.TestHive
|
||||
import org.apache.spark.sql.hive.test.TestHive._
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
package org.apache.spark.sql.hive.execution
|
||||
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.dsl._
|
||||
import org.apache.spark.sql.api.scala.dsl._
|
||||
import org.apache.spark.sql.hive.test.TestHive
|
||||
import org.apache.spark.sql.hive.test.TestHive._
|
||||
|
||||
|
|
Loading…
Reference in a new issue