[SPARK-7927] whitespace fixes for SQL core.

So we can enable a whitespace enforcement rule in the style checker to save code review time.

Author: Reynold Xin <rxin@databricks.com>

Closes #6477 from rxin/whitespace-sql-core and squashes the following commits:

ce6e369 [Reynold Xin] Fixed tests.
6095fed [Reynold Xin] [SPARK-7927] whitespace fixes for SQL core.
Reynold Xin 2015-05-28 20:10:21 -07:00
parent 04616b1a2f
commit ff44c711ab
37 changed files with 160 additions and 158 deletions
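
For reference, the conventions this patch normalizes are the kind a mechanical whitespace rule can check. The sketch below is illustrative only; it is ordinary Scala written for this summary, not code taken from this patch or from the style checker configuration (which this change only prepares for):

// Illustrative sketch of the spacing conventions applied throughout this diff.
object WhitespaceStyle {
  // A space after the colon in type ascriptions and return types.
  def label(value: Any): String = value.toString

  // A space after every comma in argument and tuple lists.
  val pairs: Seq[(Int, Int)] = Seq((1, 1), (1, 2))

  // A single space on both sides of `=` in definitions.
  val total = pairs.map(_._2).sum

  // Vararg expansion written as `: _*` rather than `:_*`.
  def joinAll(first: String, rest: String*): String = (first +: rest).mkString(", ")
  val joined = joinAll("a", Seq("b", "c") : _*)
}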

@@ -349,7 +349,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
* @group expr_ops
* @since 1.4.0
*/
-def when(condition: Column, value: Any):Column = this.expr match {
+def when(condition: Column, value: Any): Column = this.expr match {
case CaseWhen(branches: Seq[Expression]) =>
CaseWhen(branches ++ Seq(lit(condition).expr, lit(value).expr))
case _ =>
@@ -378,7 +378,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
* @group expr_ops
* @since 1.4.0
*/
-def otherwise(value: Any):Column = this.expr match {
+def otherwise(value: Any): Column = this.expr match {
case CaseWhen(branches: Seq[Expression]) =>
if (branches.size % 2 == 0) {
CaseWhen(branches :+ lit(value).expr)

@@ -255,7 +255,7 @@ class DataFrame private[sql](
val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) =>
Column(oldAttribute).as(newName)
}
-select(newCols :_*)
+select(newCols : _*)
}
/**
@@ -500,7 +500,7 @@ class DataFrame private[sql](
*/
@scala.annotation.varargs
def sort(sortCol: String, sortCols: String*): DataFrame = {
-sort((sortCol +: sortCols).map(apply) :_*)
+sort((sortCol +: sortCols).map(apply) : _*)
}
/**
@@ -531,7 +531,7 @@ class DataFrame private[sql](
* @since 1.3.0
*/
@scala.annotation.varargs
-def orderBy(sortCol: String, sortCols: String*): DataFrame = sort(sortCol, sortCols :_*)
+def orderBy(sortCol: String, sortCols: String*): DataFrame = sort(sortCol, sortCols : _*)
/**
* Returns a new [[DataFrame]] sorted by the given expressions.
@@ -540,7 +540,7 @@ class DataFrame private[sql](
* @since 1.3.0
*/
@scala.annotation.varargs
-def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs :_*)
+def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs : _*)
/**
* Selects column based on the column name and return it as a [[Column]].
@@ -611,7 +611,7 @@ class DataFrame private[sql](
* @since 1.3.0
*/
@scala.annotation.varargs
-def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) :_*)
+def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) : _*)
/**
* Selects a set of SQL expressions. This is a variant of `select` that accepts
@@ -825,7 +825,7 @@ class DataFrame private[sql](
* @since 1.3.0
*/
def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = {
-groupBy().agg(aggExpr, aggExprs :_*)
+groupBy().agg(aggExpr, aggExprs : _*)
}
/**
@@ -863,7 +863,7 @@ class DataFrame private[sql](
* @since 1.3.0
*/
@scala.annotation.varargs
-def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs :_*)
+def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs : _*)
/**
* Returns a new [[DataFrame]] by taking the first `n` rows. The difference between this function
@@ -1039,7 +1039,7 @@ class DataFrame private[sql](
val name = field.name
if (resolver(name, colName)) col.as(colName) else Column(name)
}
-select(colNames :_*)
+select(colNames : _*)
} else {
select(Column("*"), col.as(colName))
}
@@ -1262,7 +1262,7 @@ class DataFrame private[sql](
* @group action
* @since 1.3.0
*/
-override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() :_*)
+override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() : _*)
/**
* Returns the number of rows in the [[DataFrame]].

@@ -28,5 +28,5 @@ private[sql] case class DataFrameHolder(df: DataFrame) {
// `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
def toDF(): DataFrame = df
-def toDF(colNames: String*): DataFrame = df.toDF(colNames :_*)
+def toDF(colNames: String*): DataFrame = df.toDF(colNames : _*)
}

@@ -247,7 +247,7 @@ class GroupedData protected[sql](
*/
@scala.annotation.varargs
def mean(colNames: String*): DataFrame = {
-aggregateNumericColumns(colNames:_*)(Average)
+aggregateNumericColumns(colNames : _*)(Average)
}
/**
@@ -259,7 +259,7 @@ class GroupedData protected[sql](
*/
@scala.annotation.varargs
def max(colNames: String*): DataFrame = {
-aggregateNumericColumns(colNames:_*)(Max)
+aggregateNumericColumns(colNames : _*)(Max)
}
/**
@@ -271,7 +271,7 @@ class GroupedData protected[sql](
*/
@scala.annotation.varargs
def avg(colNames: String*): DataFrame = {
-aggregateNumericColumns(colNames:_*)(Average)
+aggregateNumericColumns(colNames : _*)(Average)
}
/**
@@ -283,7 +283,7 @@ class GroupedData protected[sql](
*/
@scala.annotation.varargs
def min(colNames: String*): DataFrame = {
-aggregateNumericColumns(colNames:_*)(Min)
+aggregateNumericColumns(colNames : _*)(Min)
}
/**
@@ -295,6 +295,6 @@ class GroupedData protected[sql](
*/
@scala.annotation.varargs
def sum(colNames: String*): DataFrame = {
-aggregateNumericColumns(colNames:_*)(Sum)
+aggregateNumericColumns(colNames : _*)(Sum)
}
}

@@ -298,7 +298,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
implicit class StringToColumn(val sc: StringContext) {
def $(args: Any*): ColumnName = {
-new ColumnName(sc.s(args :_*))
+new ColumnName(sc.s(args : _*))
}
}

@@ -54,15 +54,15 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr
}
}
-protected val AS = Keyword("AS")
-protected val CACHE = Keyword("CACHE")
-protected val CLEAR = Keyword("CLEAR")
-protected val IN = Keyword("IN")
-protected val LAZY = Keyword("LAZY")
-protected val SET = Keyword("SET")
-protected val SHOW = Keyword("SHOW")
-protected val TABLE = Keyword("TABLE")
-protected val TABLES = Keyword("TABLES")
+protected val AS = Keyword("AS")
+protected val CACHE = Keyword("CACHE")
+protected val CLEAR = Keyword("CLEAR")
+protected val IN = Keyword("IN")
+protected val LAZY = Keyword("LAZY")
+protected val SET = Keyword("SET")
+protected val SHOW = Keyword("SHOW")
+protected val TABLE = Keyword("TABLE")
+protected val TABLES = Keyword("TABLES")
protected val UNCACHE = Keyword("UNCACHE")
override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | show | others

@@ -236,7 +236,7 @@ private[sql] case class InMemoryColumnarTableScan(
case GreaterThanOrEqual(a: AttributeReference, l: Literal) => l <= statsFor(a).upperBound
case GreaterThanOrEqual(l: Literal, a: AttributeReference) => statsFor(a).lowerBound <= l
-case IsNull(a: Attribute) => statsFor(a).nullCount > 0
+case IsNull(a: Attribute) => statsFor(a).nullCount > 0
case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0
}

@@ -296,7 +296,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
.sliding(2)
.map {
case Seq(a) => true
-case Seq(a,b) => a compatibleWith b
+case Seq(a, b) => a.compatibleWith(b)
}.exists(!_)
// Adds Exchange or Sort operators as required

@@ -243,8 +243,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case (predicate, None) => predicate
// Filter needs to be applied above when it contains partitioning
// columns
-case (predicate, _) if(!predicate.references.map(_.name).toSet
-.intersect (partitionColNames).isEmpty) => predicate
+case (predicate, _)
+if !predicate.references.map(_.name).toSet.intersect(partitionColNames).isEmpty =>
+predicate
}
}
} else {
@@ -270,7 +271,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
projectList,
filters,
identity[Seq[Expression]], // All filters still need to be evaluated.
-InMemoryColumnarTableScan(_, filters, mem)) :: Nil
+InMemoryColumnarTableScan(_, filters, mem)) :: Nil
case _ => Nil
}
}

@@ -39,7 +39,7 @@ case class BroadcastLeftSemiJoinHash(
override def output: Seq[Attribute] = left.output
protected override def doExecute(): RDD[Row] = {
-val buildIter= buildPlan.execute().map(_.copy()).collect().toIterator
+val buildIter = buildPlan.execute().map(_.copy()).collect().toIterator
val hashSet = new java.util.HashSet[Row]()
var currentRow: Row = null

@@ -89,7 +89,7 @@ private[sql] object FrequentItems extends Logging {
(name, originalSchema.fields(index).dataType)
}
-val freqItems = df.select(cols.map(Column(_)):_*).rdd.aggregate(countMaps)(
+val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)(
seqOp = (counts, row) => {
var i = 0
while (i < numCols) {
@@ -110,7 +110,7 @@ private[sql] object FrequentItems extends Logging {
}
)
val justItems = freqItems.map(m => m.baseMap.keys.toSeq)
-val resultRow = Row(justItems:_*)
+val resultRow = Row(justItems : _*)
// append frequent Items to the column name for easy debugging
val outputCols = colInfo.map { v =>
StructField(v._1 + "_freqItems", ArrayType(v._2, false))

@@ -187,7 +187,7 @@ object functions {
*/
@scala.annotation.varargs
def countDistinct(columnName: String, columnNames: String*): Column =
-countDistinct(Column(columnName), columnNames.map(Column.apply) :_*)
+countDistinct(Column(columnName), columnNames.map(Column.apply) : _*)
/**
* Aggregate function: returns the approximate number of distinct items in a group.

@@ -52,6 +52,7 @@ private[sql] object JDBCRDD extends Logging {
scale: Int,
signed: Boolean): DataType = {
val answer = sqlType match {
+// scalastyle:off
case java.sql.Types.ARRAY => null
case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType.Unlimited }
case java.sql.Types.BINARY => BinaryType
@@ -92,7 +93,8 @@ private[sql] object JDBCRDD extends Logging {
case java.sql.Types.TINYINT => IntegerType
case java.sql.Types.VARBINARY => BinaryType
case java.sql.Types.VARCHAR => StringType
-case _ => null
+case _ => null
+// scalastyle:on
}
if (answer == null) throw new SQLException("Unsupported type " + sqlType)
@@ -323,19 +325,19 @@ private[sql] class JDBCRDD(
*/
def getConversions(schema: StructType): Array[JDBCConversion] = {
schema.fields.map(sf => sf.dataType match {
-case BooleanType => BooleanConversion
-case DateType => DateConversion
+case BooleanType => BooleanConversion
+case DateType => DateConversion
case DecimalType.Unlimited => DecimalConversion(None)
-case DecimalType.Fixed(d) => DecimalConversion(Some(d))
-case DoubleType => DoubleConversion
-case FloatType => FloatConversion
-case IntegerType => IntegerConversion
-case LongType =>
+case DecimalType.Fixed(d) => DecimalConversion(Some(d))
+case DoubleType => DoubleConversion
+case FloatType => FloatConversion
+case IntegerType => IntegerConversion
+case LongType =>
if (sf.metadata.contains("binarylong")) BinaryLongConversion else LongConversion
-case StringType => StringConversion
-case TimestampType => TimestampConversion
-case BinaryType => BinaryConversion
-case _ => throw new IllegalArgumentException(s"Unsupported field $sf")
+case StringType => StringConversion
+case TimestampType => TimestampConversion
+case BinaryType => BinaryConversion
+case _ => throw new IllegalArgumentException(s"Unsupported field $sf")
}).toArray
}
@@ -376,8 +378,8 @@ private[sql] class JDBCRDD(
while (i < conversions.length) {
val pos = i + 1
conversions(i) match {
-case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos))
-case DateConversion =>
+case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos))
+case DateConversion =>
// DateUtils.fromJavaDate does not handle null value, so we need to check it.
val dateVal = rs.getDate(pos)
if (dateVal != null) {
@@ -407,14 +409,14 @@ private[sql] class JDBCRDD(
} else {
mutableRow.update(i, Decimal(decimalVal))
}
-case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos))
-case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos))
-case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos))
-case LongConversion => mutableRow.setLong(i, rs.getLong(pos))
+case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos))
+case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos))
+case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos))
+case LongConversion => mutableRow.setLong(i, rs.getLong(pos))
// TODO(davies): use getBytes for better performance, if the encoding is UTF-8
-case StringConversion => mutableRow.setString(i, rs.getString(pos))
-case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos))
-case BinaryConversion => mutableRow.update(i, rs.getBytes(pos))
+case StringConversion => mutableRow.setString(i, rs.getString(pos))
+case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos))
+case BinaryConversion => mutableRow.update(i, rs.getBytes(pos))
case BinaryLongConversion => {
val bytes = rs.getBytes(pos)
var ans = 0L

@@ -124,7 +124,7 @@ private[sql] object InferSchema {
case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull)
case ArrayType(struct: StructType, containsNull) =>
ArrayType(nullTypeToStringType(struct), containsNull)
-case struct: StructType =>nullTypeToStringType(struct)
+case struct: StructType => nullTypeToStringType(struct)
case other: DataType => other
}

@@ -33,7 +33,7 @@ private[sql] object JacksonGenerator {
*/
def apply(rowSchema: StructType, gen: JsonGenerator)(row: Row): Unit = {
def valWriter: (DataType, Any) => Unit = {
-case (_, null) | (NullType, _) => gen.writeNull()
+case (_, null) | (NullType, _) => gen.writeNull()
case (StringType, v: String) => gen.writeString(v)
case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString)
case (IntegerType, v: Int) => gen.writeNumber(v)
@@ -48,16 +48,16 @@ private[sql] object JacksonGenerator {
case (DateType, v) => gen.writeString(v.toString)
case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, udt.serialize(v))
-case (ArrayType(ty, _), v: Seq[_] ) =>
+case (ArrayType(ty, _), v: Seq[_]) =>
gen.writeStartArray()
-v.foreach(valWriter(ty,_))
+v.foreach(valWriter(ty, _))
gen.writeEndArray()
-case (MapType(kv,vv, _), v: Map[_,_]) =>
+case (MapType(kv, vv, _), v: Map[_, _]) =>
gen.writeStartObject()
v.foreach { p =>
gen.writeFieldName(p._1.toString)
-valWriter(vv,p._2)
+valWriter(vv, p._2)
}
gen.writeEndObject()

@@ -141,7 +141,7 @@ private[sql] object JsonRDD extends Logging {
case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull)
case ArrayType(struct: StructType, containsNull) =>
ArrayType(nullTypeToStringType(struct), containsNull)
-case struct: StructType =>nullTypeToStringType(struct)
+case struct: StructType => nullTypeToStringType(struct)
case other: DataType => other
}
StructField(fieldName, newType, nullable)
@@ -216,7 +216,7 @@ private[sql] object JsonRDD extends Logging {
case map: Map[_, _] => StructType(Nil)
// We have an array of arrays. If those element arrays do not have the same
// element types, we will return ArrayType[StringType].
-case seq: Seq[_] => typeOfArray(seq)
+case seq: Seq[_] => typeOfArray(seq)
case value => typeOfPrimitiveValue(value)
}
}.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2))
@@ -406,7 +406,7 @@ private[sql] object JsonRDD extends Logging {
}
}
-private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any ={
+private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any = {
if (value == null) {
null
} else {
@@ -434,7 +434,7 @@ private[sql] object JsonRDD extends Logging {
}
}
-private def asRow(json: Map[String,Any], schema: StructType): Row = {
+private def asRow(json: Map[String, Any], schema: StructType): Row = {
// TODO: Reuse the row instead of creating a new one for every record.
val row = new GenericMutableRow(schema.fields.length)
schema.fields.zipWithIndex.foreach {

@@ -480,7 +480,7 @@ private[parquet] class CatalystPrimitiveStringConverter(parent: CatalystConverte
override def hasDictionarySupport: Boolean = true
-override def setDictionary(dictionary: Dictionary):Unit =
+override def setDictionary(dictionary: Dictionary): Unit =
dict = Array.tabulate(dictionary.getMaxId + 1) { dictionary.decodeToBinary(_).getBytes }
override def addValueFromDictionary(dictionaryId: Int): Unit =
@@ -591,8 +591,8 @@ private[parquet] class CatalystArrayConverter(
CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME,
elementType,
false),
-fieldIndex=0,
-parent=this)
+fieldIndex = 0,
+parent = this)
override def getConverter(fieldIndex: Int): Converter = converter
@@ -601,7 +601,7 @@ private[parquet] class CatalystArrayConverter(
override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = {
// fieldIndex is ignored (assumed to be zero but not checked)
-if(value == null) {
+if (value == null) {
throw new IllegalArgumentException("Null values inside Parquet arrays are not supported!")
}
buffer += value
@@ -654,8 +654,8 @@ private[parquet] class CatalystNativeArrayConverter(
CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME,
elementType,
false),
-fieldIndex=0,
-parent=this)
+fieldIndex = 0,
+parent = this)
override def getConverter(fieldIndex: Int): Converter = converter

@@ -541,7 +541,7 @@ private[parquet] class FilteringParquetRowInputFormat
val splits = mutable.ArrayBuffer.empty[ParquetInputSplit]
val filter: Filter = ParquetInputFormat.getFilter(configuration)
var rowGroupsDropped: Long = 0
-var totalRowGroups: Long = 0
+var totalRowGroups: Long = 0
// Ugly hack, stuck with it until PR:
// https://github.com/apache/incubator-parquet-mr/pull/17
@@ -664,7 +664,7 @@ private[parquet] object FileSystemHelper {
s"ParquetTableOperations: path $path does not exist or is not a directory")
}
fs.globStatus(path)
-.flatMap { status => if(status.isDir) fs.listStatus(status.getPath) else List(status) }
+.flatMap { status => if (status.isDir) fs.listStatus(status.getPath) else List(status) }
.map(_.getPath)
}

@@ -489,7 +489,7 @@ private[parquet] object ParquetTypesConverter extends Logging {
val children =
fs
.globStatus(path)
-.flatMap { status => if(status.isDir) fs.listStatus(status.getPath) else List(status) }
+.flatMap { status => if (status.isDir) fs.listStatus(status.getPath) else List(status) }
.filterNot { status =>
val name = status.getPath.getName
(name(0) == '.' || name(0) == '_') && name != ParquetFileWriter.PARQUET_METADATA_FILE

@@ -130,7 +130,7 @@ private[sql] class DDLParser(
}
}
-protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")"
+protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")"
/*
* describe [extended] table avroTable
@@ -138,7 +138,7 @@ private[sql] class DDLParser(
*/
protected lazy val describeTable: Parser[LogicalPlan] =
(DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? ~ ident ^^ {
-case e ~ db ~ tbl =>
+case e ~ db ~ tbl =>
val tblIdentifier = db match {
case Some(dbName) =>
Seq(dbName, tbl)
@@ -171,7 +171,7 @@ private[sql] class DDLParser(
}
protected lazy val pair: Parser[(String, String)] =
-optionName ~ stringLit ^^ { case k ~ v => (k,v) }
+optionName ~ stringLit ^^ { case k ~ v => (k, v) }
protected lazy val column: Parser[StructField] =
ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm =>
@@ -239,7 +239,7 @@ private[sql] object ResolvedDataSource {
Some(partitionColumnsSchema(schema, partitionColumns))
}
-val caseInsensitiveOptions= new CaseInsensitiveMap(options)
+val caseInsensitiveOptions = new CaseInsensitiveMap(options)
val paths = {
val patternPath = new Path(caseInsensitiveOptions("path"))
SparkHadoopUtil.get.globPath(patternPath).map(_.toString).toArray

@@ -28,14 +28,14 @@ class ColumnExpressionSuite extends QueryTest {
import org.apache.spark.sql.TestData._
test("single explode") {
-val df = Seq((1, Seq(1,2,3))).toDF("a", "intList")
+val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
checkAnswer(
df.select(explode('intList)),
Row(1) :: Row(2) :: Row(3) :: Nil)
}
test("explode and other columns") {
-val df = Seq((1, Seq(1,2,3))).toDF("a", "intList")
+val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
checkAnswer(
df.select($"a", explode('intList)),
@@ -45,13 +45,13 @@ class ColumnExpressionSuite extends QueryTest {
checkAnswer(
df.select($"*", explode('intList)),
-Row(1, Seq(1,2,3), 1) ::
-Row(1, Seq(1,2,3), 2) ::
-Row(1, Seq(1,2,3), 3) :: Nil)
+Row(1, Seq(1, 2, 3), 1) ::
+Row(1, Seq(1, 2, 3), 2) ::
+Row(1, Seq(1, 2, 3), 3) :: Nil)
}
test("aliased explode") {
-val df = Seq((1, Seq(1,2,3))).toDF("a", "intList")
+val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
checkAnswer(
df.select(explode('intList).as('int)).select('int),
@@ -79,7 +79,7 @@ class ColumnExpressionSuite extends QueryTest {
}
test("self join explode") {
-val df = Seq((1, Seq(1,2,3))).toDF("a", "intList")
+val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
val exploded = df.select(explode('intList).as('i))
checkAnswer(

@@ -148,12 +148,12 @@ class DataFrameAggregateSuite extends QueryTest {
test("null count") {
checkAnswer(
testData3.groupBy('a).agg(count('b)),
-Seq(Row(1,0), Row(2, 1))
+Seq(Row(1, 0), Row(2, 1))
)
checkAnswer(
testData3.groupBy('a).agg(count('a + 'b)),
-Seq(Row(1,0), Row(2, 1))
+Seq(Row(1, 0), Row(2, 1))
)
checkAnswer(

@@ -59,7 +59,7 @@ class DataFrameSuite extends QueryTest {
}
test("rename nested groupby") {
-val df = Seq((1,(1,1))).toDF()
+val df = Seq((1, (1, 1))).toDF()
checkAnswer(
df.groupBy("_1").agg(sum("_2._1")).toDF("key", "total"),
@@ -211,23 +211,23 @@ class DataFrameSuite extends QueryTest {
test("global sorting") {
checkAnswer(
testData2.orderBy('a.asc, 'b.asc),
-Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2)))
+Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2)))
checkAnswer(
testData2.orderBy(asc("a"), desc("b")),
-Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1)))
+Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1)))
checkAnswer(
testData2.orderBy('a.asc, 'b.desc),
-Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1)))
+Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1)))
checkAnswer(
testData2.orderBy('a.desc, 'b.desc),
-Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1)))
+Seq(Row(3, 2), Row(3, 1), Row(2, 2), Row(2, 1), Row(1, 2), Row(1, 1)))
checkAnswer(
testData2.orderBy('a.desc, 'b.asc),
-Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2)))
+Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2)))
checkAnswer(
arrayData.toDF().orderBy('data.getItem(0).asc),
@@ -331,7 +331,7 @@ class DataFrameSuite extends QueryTest {
checkAnswer(
df,
testData.collect().toSeq)
-assert(df.schema.map(_.name) === Seq("key","value"))
+assert(df.schema.map(_.name) === Seq("key", "value"))
}
test("withColumnRenamed") {
@@ -364,24 +364,24 @@ class DataFrameSuite extends QueryTest {
test("describe") {
val describeTestData = Seq(
-("Bob", 16, 176),
+("Bob", 16, 176),
("Alice", 32, 164),
("David", 60, 192),
-("Amy", 24, 180)).toDF("name", "age", "height")
+("Amy", 24, 180)).toDF("name", "age", "height")
val describeResult = Seq(
-Row("count", "4", "4"),
-Row("mean", "33.0", "178.0"),
-Row("stddev", "16.583123951777", "10.0"),
-Row("min", "16", "164"),
-Row("max", "60", "192"))
+Row("count", "4", "4"),
+Row("mean", "33.0", "178.0"),
+Row("stddev", "16.583123951777", "10.0"),
+Row("min", "16", "164"),
+Row("max", "60", "192"))
val emptyDescribeResult = Seq(
-Row("count", "0", "0"),
-Row("mean", null, null),
-Row("stddev", null, null),
-Row("min", null, null),
-Row("max", null, null))
+Row("count", "0", "0"),
+Row("mean", null, null),
+Row("stddev", null, null),
+Row("min", null, null),
+Row("max", null, null))
def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name)

@@ -167,10 +167,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
val y = testData2.where($"a" === 1).as("y")
checkAnswer(
x.join(y).where($"x.a" === $"y.a"),
-Row(1,1,1,1) ::
-Row(1,1,1,2) ::
-Row(1,2,1,1) ::
-Row(1,2,1,2) :: Nil
+Row(1, 1, 1, 1) ::
+Row(1, 1, 1, 2) ::
+Row(1, 2, 1, 1) ::
+Row(1, 2, 1, 2) :: Nil
)
}

@@ -28,7 +28,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter {
import org.apache.spark.sql.test.TestSQLContext.implicits._
val df =
-sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDF("key", "value")
+sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value")
before {
df.registerTempTable("ListTablesSuiteTable")

@@ -53,7 +53,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}
test("self join with aliases") {
-Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str").registerTempTable("df")
+Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str").registerTempTable("df")
checkAnswer(
sql(
@@ -76,7 +76,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}
test("self join with alias in agg") {
-Seq(1,2,3)
+Seq(1, 2, 3)
.map(i => (i, i.toString))
.toDF("int", "str")
.groupBy("str")
@@ -113,7 +113,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") {
checkAnswer(
sql("SELECT a FROM testData2 SORT BY a"),
-Seq(1, 1, 2 ,2 ,3 ,3).map(Row(_))
+Seq(1, 1, 2, 2, 3, 3).map(Row(_))
)
}
@@ -354,7 +354,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("left semi greater than predicate") {
checkAnswer(
sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"),
-Seq(Row(3,1), Row(3,2))
+Seq(Row(3, 1), Row(3, 2))
)
}
@@ -371,16 +371,16 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("agg") {
checkAnswer(
sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"),
-Seq(Row(1,3), Row(2,3), Row(3,3)))
+Seq(Row(1, 3), Row(2, 3), Row(3, 3)))
}
test("literal in agg grouping expressions") {
checkAnswer(
sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"),
-Seq(Row(1,2), Row(2,2), Row(3,2)))
+Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
checkAnswer(
sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"),
-Seq(Row(1,2), Row(2,2), Row(3,2)))
+Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
}
test("aggregates with nulls") {
@@ -405,19 +405,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
def sortTest(): Unit = {
checkAnswer(
sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"),
-Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2)))
+Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2)))
checkAnswer(
sql("SELECT * FROM testData2 ORDER BY a ASC, b DESC"),
-Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1)))
+Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1)))
checkAnswer(
sql("SELECT * FROM testData2 ORDER BY a DESC, b DESC"),
-Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1)))
+Seq(Row(3, 2), Row(3, 1), Row(2, 2), Row(2, 1), Row(1, 2), Row(1, 1)))
checkAnswer(
sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"),
-Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2)))
+Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2)))
checkAnswer(
sql("SELECT b FROM binaryData ORDER BY a ASC"),
@@ -552,7 +552,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("average overflow") {
checkAnswer(
sql("SELECT AVG(a),b FROM largeAndSmallInts group by b"),
-Seq(Row(2147483645.0,1), Row(2.0,2)))
+Seq(Row(2147483645.0, 1), Row(2.0, 2)))
}
test("count") {
@@ -619,10 +619,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
| (SELECT * FROM testData2 WHERE a = 1) x JOIN
| (SELECT * FROM testData2 WHERE a = 1) y
|WHERE x.a = y.a""".stripMargin),
-Row(1,1,1,1) ::
-Row(1,1,1,2) ::
-Row(1,2,1,1) ::
-Row(1,2,1,2) :: Nil)
+Row(1, 1, 1, 1) ::
+Row(1, 1, 1, 2) ::
+Row(1, 2, 1, 1) ::
+Row(1, 2, 1, 2) :: Nil)
}
test("inner join, no matches") {
@@ -1266,22 +1266,22 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("SPARK-4432 Fix attribute reference resolution error when using ORDER BY") {
checkAnswer(
sql("SELECT a + b FROM testData2 ORDER BY a"),
-Seq(2, 3, 3 ,4 ,4 ,5).map(Row(_))
+Seq(2, 3, 3, 4, 4, 5).map(Row(_))
)
}
test("oder by asc by default when not specify ascending and descending") {
checkAnswer(
sql("SELECT a, b FROM testData2 ORDER BY a desc, b"),
-Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2,2), Row(1, 1), Row(1, 2))
+Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2))
)
}
test("Supporting relational operator '<=>' in Spark SQL") {
-val nullCheckData1 = TestData(1,"1") :: TestData(2,null) :: Nil
+val nullCheckData1 = TestData(1, "1") :: TestData(2, null) :: Nil
val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i)))
rdd1.toDF().registerTempTable("nulldata1")
-val nullCheckData2 = TestData(1,"1") :: TestData(2,null) :: Nil
+val nullCheckData2 = TestData(1, "1") :: TestData(2, null) :: Nil
val rdd2 = sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i)))
rdd2.toDF().registerTempTable("nulldata2")
checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " +
@@ -1290,7 +1290,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}
test("Multi-column COUNT(DISTINCT ...)") {
-val data = TestData(1,"val_1") :: TestData(2,"val_2") :: Nil
+val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil
val rdd = sparkContext.parallelize((0 to 1).map(i => data(i)))
rdd.toDF().registerTempTable("distinctData")
checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2))

@@ -80,14 +80,14 @@ class ScalaReflectionRelationSuite extends FunSuite {
test("query case class RDD") {
val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
-new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3))
+new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1, 2, 3))
val rdd = sparkContext.parallelize(data :: Nil)
rdd.toDF().registerTempTable("reflectData")
assert(sql("SELECT * FROM reflectData").collect().head ===
Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
new java.math.BigDecimal(1), Date.valueOf("1970-01-01"),
-new Timestamp(12345), Seq(1,2,3)))
+new Timestamp(12345), Seq(1, 2, 3)))
}
test("query case class RDD with nulls") {

@@ -109,8 +109,8 @@ object TestData {
case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])
val arrayData =
TestSQLContext.sparkContext.parallelize(
-ArrayData(Seq(1,2,3), Seq(Seq(1,2,3))) ::
-ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil)
+ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) ::
+ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil)
arrayData.toDF().registerTempTable("arrayData")
case class MapData(data: scala.collection.Map[Int, String])

@@ -38,7 +38,7 @@ class UDFSuite extends QueryTest {
}
test("TwoArgument UDF") {
-udf.register("strLenScala", (_: String).length + (_:Int))
+udf.register("strLenScala", (_: String).length + (_: Int))
assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5)
}

@@ -73,7 +73,7 @@ class ColumnTypeSuite extends FunSuite with Logging {
checkActualSize(TIMESTAMP, new Timestamp(0L), 12)
val binary = Array.fill[Byte](4)(0: Byte)
-checkActualSize(BINARY, binary, 4 + 4)
+checkActualSize(BINARY, binary, 4 + 4)
val generic = Map(1 -> "a")
checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8)
@@ -167,7 +167,7 @@ class ColumnTypeSuite extends FunSuite with Logging {
val serializer = new SparkSqlSerializer(conf).newInstance()
val buffer = ByteBuffer.allocate(512)
-val obj = CustomClass(Int.MaxValue,Long.MaxValue)
+val obj = CustomClass(Int.MaxValue, Long.MaxValue)
val serializedObj = serializer.serialize(obj).array()
GENERIC.append(serializer.serialize(obj).array(), buffer)
@@ -278,7 +278,7 @@ private[columnar] object CustomerSerializer extends Serializer[CustomClass] {
override def read(kryo: Kryo, input: Input, aClass: Class[CustomClass]): CustomClass = {
val a = input.readInt()
val b = input.readLong()
-CustomClass(a,b)
+CustomClass(a, b)
}
}

@@ -27,8 +27,8 @@ import org.apache.spark.sql.columnar.ColumnarTestUtils._
import org.apache.spark.sql.types.AtomicType
class DictionaryEncodingSuite extends FunSuite {
-testDictionaryEncoding(new IntColumnStats, INT)
-testDictionaryEncoding(new LongColumnStats, LONG)
+testDictionaryEncoding(new IntColumnStats, INT)
+testDictionaryEncoding(new LongColumnStats, LONG)
testDictionaryEncoding(new StringColumnStats, STRING)
def testDictionaryEncoding[T <: AtomicType](

@@ -25,7 +25,7 @@ import org.apache.spark.sql.columnar.ColumnarTestUtils._
import org.apache.spark.sql.types.IntegralType
class IntegralDeltaSuite extends FunSuite {
-testIntegralDelta(new IntColumnStats, INT, IntDelta)
+testIntegralDelta(new IntColumnStats, INT, IntDelta)
testIntegralDelta(new LongColumnStats, LONG, LongDelta)
def testIntegralDelta[I <: IntegralType](
@@ -116,7 +116,7 @@ class IntegralDeltaSuite extends FunSuite {
test(s"$scheme: simple case") {
val input = columnType match {
-case INT => Seq(2: Int, 1: Int, 2: Int, 130: Int)
+case INT => Seq(2: Int, 1: Int, 2: Int, 130: Int)
case LONG => Seq(2: Long, 1: Long, 2: Long, 130: Long)
}

@@ -26,11 +26,11 @@ import org.apache.spark.sql.types.AtomicType
class RunLengthEncodingSuite extends FunSuite {
testRunLengthEncoding(new NoopColumnStats, BOOLEAN)
-testRunLengthEncoding(new ByteColumnStats, BYTE)
-testRunLengthEncoding(new ShortColumnStats, SHORT)
-testRunLengthEncoding(new IntColumnStats, INT)
-testRunLengthEncoding(new LongColumnStats, LONG)
-testRunLengthEncoding(new StringColumnStats, STRING)
+testRunLengthEncoding(new ByteColumnStats, BYTE)
+testRunLengthEncoding(new ShortColumnStats, SHORT)
+testRunLengthEncoding(new IntColumnStats, INT)
+testRunLengthEncoding(new LongColumnStats, LONG)
+testRunLengthEncoding(new StringColumnStats, STRING)
def testRunLengthEncoding[T <: AtomicType](
columnStats: ColumnStats,

@@ -429,8 +429,8 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
}, testH2Dialect))
assert(agg.canHandle("jdbc:h2:xxx"))
assert(!agg.canHandle("jdbc:h2"))
-assert(agg.getCatalystType(0,"",1,null) == Some(LongType))
-assert(agg.getCatalystType(1,"",1,null) == Some(StringType))
+assert(agg.getCatalystType(0, "", 1, null) == Some(LongType))
+assert(agg.getCatalystType(1, "", 1, null) == Some(StringType))
}
}

@@ -522,7 +522,7 @@ class JsonSuite extends QueryTest {
Row(Seq(), "11", "[1,2,3]", Row(null), "[]") ::
Row(null, """{"field":false}""", null, null, "{}") ::
Row(Seq(4, 5, 6), null, "str", Row(null), "[7,8,9]") ::
-Row(Seq(7), "{}","""["str1","str2",33]""", Row("str"), """{"field":true}""") :: Nil
+Row(Seq(7), "{}", """["str1","str2",33]""", Row("str"), """{"field":true}""") :: Nil
)
}

@@ -43,7 +43,7 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo
StructField("bigintType", LongType, nullable = false),
StructField("tinyintType", ByteType, nullable = false),
StructField("decimalType", DecimalType.Unlimited, nullable = false),
-StructField("fixedDecimalType", DecimalType(5,1), nullable = false),
+StructField("fixedDecimalType", DecimalType(5, 1), nullable = false),
StructField("binaryType", BinaryType, nullable = false),
StructField("booleanType", BooleanType, nullable = false),
StructField("smallIntType", ShortType, nullable = false),
@@ -51,8 +51,7 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo
StructField("mapType", MapType(StringType, StringType)),
StructField("arrayType", ArrayType(StringType)),
StructField("structType",
-StructType(StructField("f1",StringType) ::
-(StructField("f2",IntegerType)) :: Nil
+StructType(StructField("f1", StringType) :: StructField("f2", IntegerType) :: Nil
)
)
))

@@ -154,7 +154,7 @@ class FilteredScanSuite extends DataSourceTest {
sqlTest(
"SELECT a, b FROM oneToTenFiltered WHERE a IN (1,3,5)",
-Seq(1,3,5).map(i => Row(i, i * 2)))
+Seq(1, 3, 5).map(i => Row(i, i * 2)))
sqlTest(
"SELECT a, b FROM oneToTenFiltered WHERE A = 1",