[SPARK-7157][SQL] add sampleBy to DataFrame
This was previously committed but then reverted due to test failures (see #6769). Author: Xiangrui Meng <meng@databricks.com> Closes #7755 from rxin/SPARK-7157 and squashes the following commits: fbf9044 [Xiangrui Meng] fix python test 542bd37 [Xiangrui Meng] update test 604fe6d [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7157 f051afd [Xiangrui Meng] use udf instead of building expression f4e9425 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7157 8fb990b [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7157 103beb3 [Xiangrui Meng] add Java-friendly sampleBy 991f26f [Xiangrui Meng] fix seed 4a14834 [Xiangrui Meng] move sampleBy to stat 832f7cc [Xiangrui Meng] add sampleBy to DataFrame
This commit is contained in:
parent
ca71cc8c8b
commit
df32669514
|
@ -441,6 +441,42 @@ class DataFrame(object):
|
|||
rdd = self._jdf.sample(withReplacement, fraction, long(seed))
|
||||
return DataFrame(rdd, self.sql_ctx)
|
||||
|
||||
@since(1.5)
|
||||
def sampleBy(self, col, fractions, seed=None):
|
||||
"""
|
||||
Returns a stratified sample without replacement based on the
|
||||
fraction given on each stratum.
|
||||
|
||||
:param col: column that defines strata
|
||||
:param fractions:
|
||||
sampling fraction for each stratum. If a stratum is not
|
||||
specified, we treat its fraction as zero.
|
||||
:param seed: random seed
|
||||
:return: a new DataFrame that represents the stratified sample
|
||||
|
||||
>>> from pyspark.sql.functions import col
|
||||
>>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key"))
|
||||
>>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
|
||||
>>> sampled.groupBy("key").count().orderBy("key").show()
|
||||
+---+-----+
|
||||
|key|count|
|
||||
+---+-----+
|
||||
| 0| 3|
|
||||
| 1| 8|
|
||||
+---+-----+
|
||||
|
||||
"""
|
||||
if not isinstance(col, str):
|
||||
raise ValueError("col must be a string, but got %r" % type(col))
|
||||
if not isinstance(fractions, dict):
|
||||
raise ValueError("fractions must be a dict but got %r" % type(fractions))
|
||||
for k, v in fractions.items():
|
||||
if not isinstance(k, (float, int, long, basestring)):
|
||||
raise ValueError("key must be float, int, long, or string, but got %r" % type(k))
|
||||
fractions[k] = float(v)
|
||||
seed = seed if seed is not None else random.randint(0, sys.maxsize)
|
||||
return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)
|
||||
|
||||
@since(1.4)
|
||||
def randomSplit(self, weights, seed=None):
|
||||
"""Randomly splits this :class:`DataFrame` with the provided weights.
|
||||
|
@ -1314,6 +1350,11 @@ class DataFrameStatFunctions(object):
|
|||
|
||||
freqItems.__doc__ = DataFrame.freqItems.__doc__
|
||||
|
||||
def sampleBy(self, col, fractions, seed=None):
|
||||
return self.df.sampleBy(col, fractions, seed)
|
||||
|
||||
sampleBy.__doc__ = DataFrame.sampleBy.__doc__
|
||||
|
||||
|
||||
def _test():
|
||||
import doctest
|
||||
|
|
|
@ -17,6 +17,10 @@
|
|||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import java.{util => ju, lang => jl}
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import org.apache.spark.annotation.Experimental
|
||||
import org.apache.spark.sql.execution.stat._
|
||||
|
||||
|
@ -166,4 +170,42 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
|
|||
def freqItems(cols: Seq[String]): DataFrame = {
|
||||
FrequentItems.singlePassFreqItems(df, cols, 0.01)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a stratified sample without replacement based on the fraction given on each stratum.
|
||||
* @param col column that defines strata
|
||||
* @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
|
||||
* its fraction as zero.
|
||||
* @param seed random seed
|
||||
* @tparam T stratum type
|
||||
* @return a new [[DataFrame]] that represents the stratified sample
|
||||
*
|
||||
* @since 1.5.0
|
||||
*/
|
||||
def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = {
|
||||
require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
|
||||
s"Fractions must be in [0, 1], but got $fractions.")
|
||||
import org.apache.spark.sql.functions.{rand, udf}
|
||||
val c = Column(col)
|
||||
val r = rand(seed)
|
||||
val f = udf { (stratum: Any, x: Double) =>
|
||||
x < fractions.getOrElse(stratum.asInstanceOf[T], 0.0)
|
||||
}
|
||||
df.filter(f(c, r))
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a stratified sample without replacement based on the fraction given on each stratum.
|
||||
* @param col column that defines strata
|
||||
* @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
|
||||
* its fraction as zero.
|
||||
* @param seed random seed
|
||||
* @tparam T stratum type
|
||||
* @return a new [[DataFrame]] that represents the stratified sample
|
||||
*
|
||||
* @since 1.5.0
|
||||
*/
|
||||
def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
|
||||
sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -226,4 +226,13 @@ public class JavaDataFrameSuite {
|
|||
Double result = df.stat().cov("a", "b");
|
||||
Assert.assertTrue(Math.abs(result) < 1e-6);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSampleBy() {
|
||||
DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key"));
|
||||
DataFrame sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
|
||||
Row[] actual = sampled.groupBy("key").count().orderBy("key").collect();
|
||||
Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)};
|
||||
Assert.assertArrayEquals(expected, actual);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,9 +21,9 @@ import java.util.Random
|
|||
|
||||
import org.scalatest.Matchers._
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.functions.col
|
||||
|
||||
class DataFrameStatSuite extends SparkFunSuite {
|
||||
class DataFrameStatSuite extends QueryTest {
|
||||
|
||||
private val sqlCtx = org.apache.spark.sql.test.TestSQLContext
|
||||
import sqlCtx.implicits._
|
||||
|
@ -130,4 +130,12 @@ class DataFrameStatSuite extends SparkFunSuite {
|
|||
val items2 = singleColResults.collect().head
|
||||
items2.getSeq[Double](0) should contain (-1.0)
|
||||
}
|
||||
|
||||
test("sampleBy") {
|
||||
val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key"))
|
||||
val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
|
||||
checkAnswer(
|
||||
sampled.groupBy("key").count().orderBy("key"),
|
||||
Seq(Row(0, 5), Row(1, 8)))
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue