[SPARK-8803] handle special characters in elements in crosstab

cc rxin

Having back ticks or null as elements causes problems.
Since elements become column names, we have to drop back ticks from the elements, because back ticks are special characters in column names.
Having null throws exceptions; we replace nulls with the string "null".

Handling of back ticks should be improved for 1.5.
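Rough illustration of the failure mode this patch targets (the data and column names here are made up, assuming a `sqlContext` with implicits imported): the distinct elements of the two columns become the row labels and column names of the crosstab result, so a back-ticked element ends up inside a result column name and a null element blows up when converted to a string.

    import sqlContext.implicits._

    val df = Seq(("a", "`ha`"), (null, "ho")).toDF("k", "v")
    // distinct values of "v" become result column names, distinct values of "k" become row labels;
    // before this patch a back-ticked element produced an unusable column name and a null element
    // could throw when turned into a string
    df.stat.crosstab("k", "v")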

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #7201 from brkyvz/weird-ct-elements and squashes the following commits:

e06b840 [Burak Yavuz] fix scalastyle
93a0d3f [Burak Yavuz] added tests for NaN and Infinity
9dba6ce [Burak Yavuz] address cr1
db71dbd [Burak Yavuz] handle special characters in elements in crosstab

(cherry picked from commit 9b23e92c72)
Signed-off-by: Reynold Xin <rxin@databricks.com>

Conflicts:
	sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
Authored by Burak Yavuz on 2015-07-02 22:10:24 -07:00; committed by Reynold Xin
parent f142867ece
commit ff76b33b67
4 changed files with 50 additions and 5 deletions

sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala

@@ -391,7 +391,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    * Returns a [[Column]] expression that replaces null value in `col` with `replacement`.
    */
   private def fillCol[T](col: StructField, replacement: T): Column = {
-    coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name)
+    coalesce(df.col("`" + col.name + "`"), lit(replacement).cast(col.dataType)).as(col.name)
   }

   /**
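Why fillCol now wraps the name in back ticks: `df.col` parses dots as nested-field access, so a plain column whose name contains a `.` (which crosstab can now produce) would not resolve without quoting. A small sketch with made-up data:

    val df = Seq((1, 2)).toDF("a.b", "c")
    df.col("`a.b`")   // resolves the column literally named "a.b"
    // df.col("a.b")  // would be parsed as field "b" of a struct column "a" and fail here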

sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala

@@ -78,6 +78,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    * The first column of each row will be the distinct values of `col1` and the column names will
    * be the distinct values of `col2`. The name of the first column will be `$col1_$col2`. Counts
    * will be returned as `Long`s. Pairs that have no occurrences will have `null` as their counts.
+   * Null elements will be replaced by "null", and back ticks will be dropped from elements if they
+   * exist.
+   *
    *
    * @param col1 The name of the first column. Distinct items will make the first item of
    *             each row.
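For reference, a sketch of the documented output shape (data made up, implicits assumed in scope as above; row and column order may differ):

    val df = Seq(("a", 1), ("a", 2), ("b", 1)).toDF("key", "value")
    df.stat.crosstab("key", "value").show()
    // +---------+---+---+
    // |key_value|  1|  2|     <- first column named `$col1_$col2`, other columns are values of col2
    // +---------+---+---+
    // |        a|  1|  1|
    // |        b|  1|  0|     <- missing pairs filled with 0
    // +---------+---+---+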

sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala

@@ -109,8 +109,12 @@ private[sql] object StatFunctions extends Logging {
       logWarning("The maximum limit of 1e6 pairs have been collected, which may not be all of " +
         "the pairs. Please try reducing the amount of distinct items in your columns.")
     }
+    def cleanElement(element: Any): String = {
+      if (element == null) "null" else element.toString
+    }
     // get the distinct values of column 2, so that we can make them the column names
-    val distinctCol2: Map[Any, Int] = counts.map(_.get(1)).distinct.zipWithIndex.toMap
+    val distinctCol2: Map[Any, Int] =
+      counts.map(e => cleanElement(e.get(1))).distinct.zipWithIndex.toMap
     val columnSize = distinctCol2.size
     require(columnSize < 1e4, s"The number of distinct values for $col2, can't " +
       s"exceed 1e4. Currently $columnSize")
@@ -120,15 +124,23 @@ private[sql] object StatFunctions extends Logging {
         // row.get(0) is column 1
         // row.get(1) is column 2
         // row.get(2) is the frequency
-        countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2))
+        val columnIndex = distinctCol2.get(cleanElement(row.get(1))).get
+        countsRow.setLong(columnIndex + 1, row.getLong(2))
       }
       // the value of col1 is the first value, the rest are the counts
-      countsRow.setString(0, col1Item.toString)
+      countsRow.setString(0, cleanElement(col1Item.toString))
       countsRow
     }.toSeq
+    // Back ticks can't exist in DataFrame column names, therefore drop them. To be able to accept
+    // special keywords and `.`, wrap the column names in ``.
+    def cleanColumnName(name: String): String = {
+      name.replace("`", "")
+    }
     // In the map, the column names (._1) are not ordered by the index (._2). This was the bug in
     // SPARK-8681. We need to explicitly sort by the column index and assign the column names.
-    val headerNames = distinctCol2.toSeq.sortBy(_._2).map(r => StructField(r._1.toString, LongType))
+    val headerNames = distinctCol2.toSeq.sortBy(_._2).map { r =>
+      StructField(cleanColumnName(r._1.toString), LongType)
+    }
     val schema = StructType(StructField(tableName, StringType) +: headerNames)

     new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
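Note: cleanColumnName only strips back ticks; dots and reserved words are kept in the column name and are handled downstream by wrapping the name in back ticks (see the fillCol change above). A standalone sketch:

    def cleanColumnName(name: String): String =
      name.replace("`", "")

    cleanColumnName("`ha`")  // "ha"  -- matches the fieldNames check in the new test
    cleanColumnName("a.b")   // "a.b" -- kept as-is; escaped with back ticks when referenced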

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

@@ -85,6 +85,36 @@ class DataFrameStatSuite extends SparkFunSuite {
     }
   }

+  test("special crosstab elements (., '', null, ``)") {
+    val data = Seq(
+      ("a", Double.NaN, "ho"),
+      (null, 2.0, "ho"),
+      ("a.b", Double.NegativeInfinity, ""),
+      ("b", Double.PositiveInfinity, "`ha`"),
+      ("a", 1.0, null)
+    )
+    val df = data.toDF("1", "2", "3")
+    val ct1 = df.stat.crosstab("1", "2")
+    // column fields should be 1 + distinct elements of second column
+    assert(ct1.schema.fields.length === 6)
+    assert(ct1.collect().length === 4)
+    val ct2 = df.stat.crosstab("1", "3")
+    assert(ct2.schema.fields.length === 5)
+    assert(ct2.schema.fieldNames.contains("ha"))
+    assert(ct2.collect().length === 4)
+    val ct3 = df.stat.crosstab("3", "2")
+    assert(ct3.schema.fields.length === 6)
+    assert(ct3.schema.fieldNames.contains("NaN"))
+    assert(ct3.schema.fieldNames.contains("Infinity"))
+    assert(ct3.schema.fieldNames.contains("-Infinity"))
+    assert(ct3.collect().length === 4)
+    val ct4 = df.stat.crosstab("3", "1")
+    assert(ct4.schema.fields.length === 5)
+    assert(ct4.schema.fieldNames.contains("null"))
+    assert(ct4.schema.fieldNames.contains("a.b"))
+    assert(ct4.collect().length === 4)
+  }
+
   test("Frequent Items") {
     val rows = Seq.tabulate(1000) { i =>
       if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)