[SPARK-7982][SQL] DataFrame.stat.crosstab should use 0 instead of null for pairs that don't appear
Author: Reynold Xin <rxin@databricks.com>
Closes #6566 from rxin/crosstab and squashes the following commits:
e0ace1c [Reynold Xin] [SPARK-7982][SQL] DataFrame.stat.crosstab should use 0 instead of null for pairs that don't appear
(cherry picked from commit 6396cc0303)
Signed-off-by: Reynold Xin <rxin@databricks.com>
This commit is contained in:
parent
cbfb682ab9
commit
efc0e05323
|
@ -18,7 +18,7 @@
|
|||
package org.apache.spark.sql.execution.stat
|
||||
|
||||
import org.apache.spark.Logging
|
||||
import org.apache.spark.sql.{Column, DataFrame}
|
||||
import org.apache.spark.sql.{Row, Column, DataFrame}
|
||||
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
|
||||
import org.apache.spark.sql.functions._
|
||||
|
@ -116,7 +116,10 @@ private[sql] object StatFunctions extends Logging {
|
|||
s"exceed 1e4. Currently $columnSize")
|
||||
val table = counts.groupBy(_.get(0)).map { case (col1Item, rows) =>
|
||||
val countsRow = new GenericMutableRow(columnSize + 1)
|
||||
rows.foreach { row =>
|
||||
rows.foreach { (row: Row) =>
|
||||
// row.get(0) is column 1
|
||||
// row.get(1) is column 2
|
||||
// row.get(2) is the frequency
|
||||
countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2))
|
||||
}
|
||||
// the value of col1 is the first value, the rest are the counts
|
||||
|
@ -126,6 +129,6 @@ private[sql] object StatFunctions extends Logging {
|
|||
val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq
|
||||
val schema = StructType(StructField(tableName, StringType) +: headerNames)
|
||||
|
||||
new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table))
|
||||
new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -74,10 +74,10 @@ class DataFrameStatSuite extends FunSuite {
|
|||
val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0))
|
||||
assert(rows(0).get(0).toString === "0")
|
||||
assert(rows(0).getLong(1) === 2L)
|
||||
assert(rows(0).get(2) === null)
|
||||
assert(rows(0).get(2) === 0L)
|
||||
assert(rows(1).get(0).toString === "1")
|
||||
assert(rows(1).getLong(1) === 1L)
|
||||
assert(rows(1).get(2) === null)
|
||||
assert(rows(1).get(2) === 0L)
|
||||
assert(rows(2).get(0).toString === "2")
|
||||
assert(rows(2).getLong(1) === 2L)
|
||||
assert(rows(2).getLong(2) === 1L)
|
||||
|
|
Loading…
Reference in a new issue