diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index c2fa6c8738..4eaa8d9c57 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -755,8 +755,6 @@ class DataFrame(object):
         jdf = self._jdf.groupBy(self._jcols(*cols))
         return GroupedData(jdf, self.sql_ctx)
 
-    groupby = groupBy
-
     def agg(self, *exprs):
         """ Aggregate on the entire :class:`DataFrame` without groups
         (shorthand for ``df.groupBy.agg()``).
@@ -793,6 +791,36 @@ class DataFrame(object):
         """
         return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
 
+    def dropDuplicates(self, subset=None):
+        """Return a new :class:`DataFrame` with duplicate rows removed,
+        optionally only considering certain columns.
+
+        >>> from pyspark.sql import Row
+        >>> df = sc.parallelize([ \
+            Row(name='Alice', age=5, height=80), \
+            Row(name='Alice', age=5, height=80), \
+            Row(name='Alice', age=10, height=80)]).toDF()
+        >>> df.dropDuplicates().show()
+        +---+------+-----+
+        |age|height| name|
+        +---+------+-----+
+        |  5|    80|Alice|
+        | 10|    80|Alice|
+        +---+------+-----+
+
+        >>> df.dropDuplicates(['name', 'height']).show()
+        +---+------+-----+
+        |age|height| name|
+        +---+------+-----+
+        |  5|    80|Alice|
+        +---+------+-----+
+        """
+        if subset is None:
+            jdf = self._jdf.dropDuplicates()
+        else:
+            jdf = self._jdf.dropDuplicates(self._jseq(subset))
+        return DataFrame(jdf, self.sql_ctx)
+
     def dropna(self, how='any', thresh=None, subset=None):
         """Returns a new :class:`DataFrame` omitting rows with null values.
 
@@ -1012,6 +1040,10 @@ class DataFrame(object):
         import pandas as pd
         return pd.DataFrame.from_records(self.collect(), columns=self.columns)
 
+    # Pandas compatibility
+    groupby = groupBy
+    drop_duplicates = dropDuplicates
+
 
 # Having SchemaRDD for backward compatibility (for docs)
 class SchemaRDD(DataFrame):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 2472999de3..265a61592b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql
 
 import java.io.CharArrayWriter
 import java.sql.DriverManager
-
 import scala.collection.JavaConversions._
 import scala.language.implicitConversions
 import scala.reflect.ClassTag
@@ -42,7 +41,7 @@ import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
 import org.apache.spark.sql.jdbc.JDBCWriteDetails
-import org.apache.spark.sql.json.{JacksonGenerator, JsonRDD}
+import org.apache.spark.sql.json.JacksonGenerator
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
 import org.apache.spark.util.Utils
@@ -932,6 +931,40 @@ class DataFrame private[sql](
     }
   }
 
+  /**
+   * Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]].
+   * This is an alias for `distinct`.
+   * @group dfops
+   */
+  def dropDuplicates(): DataFrame = dropDuplicates(this.columns)
+
+  /**
+   * (Scala-specific) Returns a new [[DataFrame]] with duplicate rows removed, considering only
+   * the subset of columns.
+   *
+   * @group dfops
+   */
+  def dropDuplicates(colNames: Seq[String]): DataFrame = {
+    val groupCols = colNames.map(resolve)
+    val groupColExprIds = groupCols.map(_.exprId)
+    val aggCols = logicalPlan.output.map { attr =>
+      if (groupColExprIds.contains(attr.exprId)) {
+        attr
+      } else {
+        Alias(First(attr), attr.name)()
+      }
+    }
+    Aggregate(groupCols, aggCols, logicalPlan)
+  }
+
+  /**
+   * Returns a new [[DataFrame]] with duplicate rows removed, considering only
+   * the subset of columns.
+   *
+   * @group dfops
+   */
+  def dropDuplicates(colNames: Array[String]): DataFrame = dropDuplicates(colNames.toSeq)
+
   /**
    * Computes statistics for numeric columns, including count, mean, stddev, min, and max.
    * If no columns are given, this function computes statistics for all numerical columns.
@@ -1089,6 +1122,7 @@ class DataFrame private[sql](
 
   /**
    * Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]].
+   * This is an alias for `dropDuplicates`.
    * @group dfops
    */
   override def distinct: DataFrame = Distinct(logicalPlan)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 7552c12881..2ade955864 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -457,4 +457,39 @@ class DataFrameSuite extends QueryTest {
     assert(complexData.filter(complexData("m")("1") === 1).count() == 1)
     assert(complexData.filter(complexData("s")("key") === 1).count() == 1)
   }
+
+  test("SPARK-7324 dropDuplicates") {
+    val testData = TestSQLContext.sparkContext.parallelize(
+      (2, 1, 2) :: (1, 1, 1) ::
+      (1, 2, 1) :: (2, 1, 2) ::
+      (2, 2, 2) :: (2, 2, 1) ::
+      (2, 1, 1) :: (1, 1, 2) ::
+      (1, 2, 2) :: (1, 2, 1) :: Nil).toDF("key", "value1", "value2")
+
+    checkAnswer(
+      testData.dropDuplicates(),
+      Seq(Row(2, 1, 2), Row(1, 1, 1), Row(1, 2, 1),
+        Row(2, 2, 2), Row(2, 1, 1), Row(2, 2, 1),
+        Row(1, 1, 2), Row(1, 2, 2)))
+
+    checkAnswer(
+      testData.dropDuplicates(Seq("key", "value1")),
+      Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2)))
+
+    checkAnswer(
+      testData.dropDuplicates(Seq("value1", "value2")),
+      Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2)))
+
+    checkAnswer(
+      testData.dropDuplicates(Seq("key")),
+      Seq(Row(2, 1, 2), Row(1, 1, 1)))
+
+    checkAnswer(
+      testData.dropDuplicates(Seq("value1")),
+      Seq(Row(2, 1, 2), Row(1, 2, 1)))
+
+    checkAnswer(
+      testData.dropDuplicates(Seq("value2")),
+      Seq(Row(2, 1, 2), Row(1, 1, 1)))
+  }
 }
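A note on the Scala implementation above: `dropDuplicates(colNames)` rewrites the query into a single `Aggregate` that groups on the requested columns and keeps `First()` of every other column, aliased back to its original name. Expressed purely through the public API, that is roughly equivalent to a `groupBy(...).agg(first(...))`. Below is a minimal, self-contained sketch of that equivalence, assuming a local-mode Spark 1.x setup; the object name `DropDuplicatesSketch` and the sample data are made up for illustration:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.first

object DropDuplicatesSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local").setAppName("dropDuplicates-sketch"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val df = sc.parallelize(
      (1, 1, 1) :: (1, 1, 2) :: (2, 1, 2) :: (2, 1, 2) :: Nil)
      .toDF("key", "value1", "value2")

    // Hand-rolled analogue of df.dropDuplicates(Seq("key", "value1")):
    // grouping columns pass through unchanged, every remaining column
    // collapses to first(), aliased back to its original name.
    val manual = df.groupBy($"key", $"value1")
      .agg(first($"value2").as("value2"))

    manual.show()
    df.dropDuplicates(Seq("key", "value1")).show()  // same row set, modulo order

    sc.stop()
  }
}
```

Because the surviving values come from `First(attr)`, the patch keeps one arbitrary row per duplicate key rather than promising any particular one, which matches `distinct`-style semantics without imposing an ordering.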
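The three Scala overloads are deliberately thin layers over one another: the no-arg version passes `this.columns`, which is why `distinct` and `dropDuplicates()` are documented as aliases, and the `Array[String]` version exists for Java callers and simply delegates via `toSeq`. A hedged sanity check of those relationships, reusing the same hypothetical local-mode setup as the sketch above (`OverloadCheck` is a made-up name, and the asserts lean on single-threaded local mode keeping partition order stable):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object OverloadCheck {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local").setAppName("overload-check"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val df = sc.parallelize(
      (2, 1, 2) :: (1, 1, 1) :: (2, 1, 2) :: (1, 2, 1) :: Nil)
      .toDF("key", "value1", "value2")

    // dropDuplicates() considers every column, so its result set matches
    // distinct (the patch documents the two as aliases of each other).
    assert(df.dropDuplicates().collect().toSet == df.distinct.collect().toSet)

    // The Array[String] overload delegates to the Seq[String] one via toSeq,
    // so both spellings build the same logical plan over the same data.
    assert(df.dropDuplicates(Array("key")).collect().toSet ==
      df.dropDuplicates(Seq("key")).collect().toSet)

    sc.stop()
  }
}
```

This mirrors what the new `DataFrameSuite` test pins down with `checkAnswer`: results are compared as unordered row collections, since the rewrite makes no guarantee about output order.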