From 7fcbb2ccaf50d7cb1dc68ff0c271737a3a59253e Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Mon, 18 May 2015 21:43:12 -0700 Subject: [PATCH] [SPARK-7150] SparkContext.range() and SQLContext.range() This PR is based on #6081, thanks adrian-wang. Closes #6081 Author: Daoyuan Wang Author: Davies Liu Closes #6230 from davies/range and squashes the following commits: d3ce5fe [Davies Liu] add tests 789eda5 [Davies Liu] add range() in Python 4590208 [Davies Liu] Merge commit 'refs/pull/6081/head' of github.com:apache/spark into range cbf5200 [Daoyuan Wang] let's add python support in a separate PR f45e3b2 [Daoyuan Wang] remove redundant toLong 617da76 [Daoyuan Wang] fix safe marge for corner cases 867c417 [Daoyuan Wang] fix 13dbe84 [Daoyuan Wang] update bd998ba [Daoyuan Wang] update comments d3a0c1b [Daoyuan Wang] add range api() (cherry picked from commit c2437de1899e09894df4ec27adfaa7fac158fd3a) Signed-off-by: Reynold Xin --- .../scala/org/apache/spark/SparkContext.scala | 72 +++++++++++++++++++ python/pyspark/context.py | 16 +++++ python/pyspark/sql/context.py | 20 ++++++ python/pyspark/sql/tests.py | 5 ++ python/pyspark/tests.py | 5 ++ .../org/apache/spark/sql/SQLContext.scala | 31 ++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 40 +++++++++++ 7 files changed, 189 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index f78fbaf33f..3fe3dc5e30 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -697,6 +697,78 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]()) } + /** + * Creates a new RDD[Long] containing elements from `start` to `end`(exclusive), increased by + * `step` every element. + * + * @note if we need to cache this RDD, we should make sure each partition does not exceed limit. + * + * @param start the start value. + * @param end the end value. + * @param step the incremental step + * @param numSlices the partition number of the new RDD. + * @return + */ + def range( + start: Long, + end: Long, + step: Long = 1, + numSlices: Int = defaultParallelism): RDD[Long] = withScope { + assertNotStopped() + // when step is 0, range will run infinitely + require(step != 0, "step cannot be 0") + val numElements: BigInt = { + val safeStart = BigInt(start) + val safeEnd = BigInt(end) + if ((safeEnd - safeStart) % step == 0 || safeEnd > safeStart ^ step > 0) { + (safeEnd - safeStart) / step + } else { + // the remainder has the same sign with range, could add 1 more + (safeEnd - safeStart) / step + 1 + } + } + parallelize(0 until numSlices, numSlices).mapPartitionsWithIndex((i, _) => { + val partitionStart = (i * numElements) / numSlices * step + start + val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start + def getSafeMargin(bi: BigInt): Long = + if (bi.isValidLong) { + bi.toLong + } else if (bi > 0) { + Long.MaxValue + } else { + Long.MinValue + } + val safePartitionStart = getSafeMargin(partitionStart) + val safePartitionEnd = getSafeMargin(partitionEnd) + + new Iterator[Long] { + private[this] var number: Long = safePartitionStart + private[this] var overflow: Boolean = false + + override def hasNext = + if (!overflow) { + if (step > 0) { + number < safePartitionEnd + } else { + number > safePartitionEnd + } + } else false + + override def next() = { + val ret = number + number += step + if (number < ret ^ step < 0) { + // we have Long.MaxValue + Long.MaxValue < Long.MaxValue + // and Long.MinValue + Long.MinValue > Long.MinValue, so iff the step causes a step + // back, we are pretty sure that we have an overflow. + overflow = true + } + ret + } + } + }) + } + /** Distribute a local Scala collection to form an RDD. * * This method is identical to `parallelize`. diff --git a/python/pyspark/context.py b/python/pyspark/context.py index d25ee85523..1f2b40b29f 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -319,6 +319,22 @@ class SparkContext(object): with SparkContext._lock: SparkContext._active_spark_context = None + def range(self, start, end, step=1, numSlices=None): + """ + Create a new RDD of int containing elements from `start` to `end` + (exclusive), increased by `step` every element. + + :param start: the start value + :param end: the end value (exclusive) + :param step: the incremental step (default: 1) + :param numSlices: the number of partitions of the new RDD + :return: An RDD of int + + >>> sc.range(1, 7, 2).collect() + [1, 3, 5] + """ + return self.parallelize(xrange(start, end, step), numSlices) + def parallelize(self, c, numSlices=None): """ Distribute a local Python collection to form an RDD. Using xrange diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 0bde719124..9f26d13235 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -122,6 +122,26 @@ class SQLContext(object): """Returns a :class:`UDFRegistration` for UDF registration.""" return UDFRegistration(self) + def range(self, start, end, step=1, numPartitions=None): + """ + Create a :class:`DataFrame` with single LongType column named `id`, + containing elements in a range from `start` to `end` (exclusive) with + step value `step`. + + :param start: the start value + :param end: the end value (exclusive) + :param step: the incremental step (default: 1) + :param numPartitions: the number of partitions of the DataFrame + :return: A new DataFrame + + >>> sqlContext.range(1, 7, 2).collect() + [Row(id=1), Row(id=3), Row(id=5)] + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions)) + return DataFrame(jdf, self) + @ignore_unicode_prefix def registerFunction(self, name, f, returnType=StringType()): """Registers a lambda function as a UDF so it can be used in SQL statements. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index d37c5dbed7..84ae36f2fd 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -117,6 +117,11 @@ class SQLTests(ReusedPySparkTestCase): ReusedPySparkTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name, ignore_errors=True) + def test_range(self): + self.assertEqual(self.sqlCtx.range(1, 1).count(), 0) + self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1) + self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2) + def test_explode(self): from pyspark.sql.functions import explode d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})] diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 5e023f6c53..d8e319994c 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -444,6 +444,11 @@ class AddFileTests(PySparkTestCase): class RDDTests(ReusedPySparkTestCase): + def test_range(self): + self.assertEqual(self.sc.range(1, 1).count(), 0) + self.assertEqual(self.sc.range(1, 0, -1).count(), 1) + self.assertEqual(self.sc.range(0, 1 << 40, 1 << 39).count(), 2) + def test_id(self): rdd = self.sc.parallelize(range(10)) id = rdd.id() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index ac1a800219..316ef7d588 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -684,6 +684,37 @@ class SQLContext(@transient val sparkContext: SparkContext) catalog.unregisterTable(Seq(tableName)) } + /** + * :: Experimental :: + * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements + * in an range from `start` to `end`(exclusive) with step value 1. + * + * @since 1.4.0 + * @group dataframe + */ + @Experimental + def range(start: Long, end: Long): DataFrame = { + createDataFrame( + sparkContext.range(start, end).map(Row(_)), + StructType(StructField("id", LongType, nullable = false) :: Nil)) + } + + /** + * :: Experimental :: + * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements + * in an range from `start` to `end`(exclusive) with an step value, with partition number + * specified. + * + * @since 1.4.0 + * @group dataframe + */ + @Experimental + def range(start: Long, end: Long, step: Long, numPartitions: Int): DataFrame = { + createDataFrame( + sparkContext.range(start, end, step, numPartitions).map(Row(_)), + StructType(StructField("id", LongType, nullable = false) :: Nil)) + } + /** * Executes a SQL query using Spark, returning the result as a [[DataFrame]]. The dialect that is * used for SQL parsing can be configured with 'spark.sql.dialect'. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 054b23dba8..f05d059d44 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -532,4 +532,44 @@ class DataFrameSuite extends QueryTest { val p = df.logicalPlan.asInstanceOf[Project].child.asInstanceOf[Project] assert(!p.child.isInstanceOf[Project]) } + + test("SPARK-7150 range api") { + // numSlice is greater than length + val res1 = TestSQLContext.range(0, 10, 1, 15).select("id") + assert(res1.count == 10) + assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) + + val res2 = TestSQLContext.range(3, 15, 3, 2).select("id") + assert(res2.count == 4) + assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) + + val res3 = TestSQLContext.range(1, -2).select("id") + assert(res3.count == 0) + + // start is positive, end is negative, step is negative + val res4 = TestSQLContext.range(1, -2, -2, 6).select("id") + assert(res4.count == 2) + assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0))) + + // start, end, step are negative + val res5 = TestSQLContext.range(-3, -8, -2, 1).select("id") + assert(res5.count == 3) + assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15))) + + // start, end are negative, step is positive + val res6 = TestSQLContext.range(-8, -4, 2, 1).select("id") + assert(res6.count == 2) + assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14))) + + val res7 = TestSQLContext.range(-10, -9, -20, 1).select("id") + assert(res7.count == 0) + + val res8 = TestSQLContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") + assert(res8.count == 3) + assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3))) + + val res9 = TestSQLContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") + assert(res9.count == 2) + assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) + } }