[SPARK-7150] SparkContext.range() and SQLContext.range()
This PR is based on #6081, thanks adrian-wang.
Closes #6081
Author: Daoyuan Wang <daoyuan.wang@intel.com>
Author: Davies Liu <davies@databricks.com>
Closes #6230 from davies/range and squashes the following commits:
d3ce5fe [Davies Liu] add tests
789eda5 [Davies Liu] add range() in Python
4590208 [Davies Liu] Merge commit 'refs/pull/6081/head' of github.com:apache/spark into range
cbf5200 [Daoyuan Wang] let's add python support in a separate PR
f45e3b2 [Daoyuan Wang] remove redundant toLong
617da76 [Daoyuan Wang] fix safe margin for corner cases
867c417 [Daoyuan Wang] fix
13dbe84 [Daoyuan Wang] update
bd998ba [Daoyuan Wang] update comments
d3a0c1b [Daoyuan Wang] add range api()
(cherry picked from commit c2437de189)
Signed-off-by: Reynold Xin <rxin@databricks.com>
This commit is contained in:
parent
9d0b7fb714
commit
7fcbb2ccaf
|
@ -697,6 +697,78 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
|
||||||
new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
|
new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Builds an RDD[Long] holding the arithmetic progression that begins at `start`, advances by
 * `step` for each element, and stops before reaching `end` (`end` itself is never produced).
 *
 * @note if this RDD is going to be cached, make sure each partition does not exceed limit.
 *
 * @param start the first value of the progression.
 * @param end the exclusive bound of the progression.
 * @param step the difference between consecutive elements; must not be 0, may be negative.
 * @param numSlices the partition number of the new RDD.
 * @return an RDD[Long] built over `numSlices` partitions, each generating its own
 *         contiguous part of the progression.
 */
def range(
    start: Long,
    end: Long,
    step: Long = 1,
    numSlices: Int = defaultParallelism): RDD[Long] = withScope {
  assertNotStopped()
  // a step of 0 would never reach `end`, so the range would run infinitely
  require(step != 0, "step cannot be 0")

  // The element count is computed with BigInt because `end - start` may not fit in a Long.
  val span = BigInt(end) - BigInt(start)
  val dividesEvenly = span % step == 0
  val stepPointsAway = (end > start) != (step > 0)
  val numElements: BigInt =
    if (dividesEvenly || stepPointsAway) {
      span / step
    } else {
      // the remainder has the same sign as the range, so one more element fits
      span / step + 1
    }

  parallelize(0 until numSlices, numSlices).mapPartitionsWithIndex { (index, _) =>
    // Brings a BigInt bound back into the Long domain, saturating at the extremes.
    def clamp(bound: BigInt): Long =
      if (bound.isValidLong) {
        bound.toLong
      } else if (bound > 0) {
        Long.MaxValue
      } else {
        Long.MinValue
      }

    val firstValue = clamp((index * numElements) / numSlices * step + start)
    val limitValue = clamp(((index + 1) * numElements) / numSlices * step + start)

    new Iterator[Long] {
      private[this] var current: Long = firstValue
      private[this] var wrapped: Boolean = false

      override def hasNext: Boolean =
        !wrapped && (if (step > 0) current < limitValue else current > limitValue)

      override def next(): Long = {
        val produced = current
        current += step
        // Long.MaxValue + Long.MaxValue < Long.MaxValue and
        // Long.MinValue + Long.MinValue > Long.MinValue, so if adding `step` moved the
        // cursor against the direction of `step`, the addition must have overflowed.
        if ((current < produced) != (step < 0)) {
          wrapped = true
        }
        produced
      }
    }
  }
}
|
||||||
|
|
||||||
/** Distribute a local Scala collection to form an RDD.
|
/** Distribute a local Scala collection to form an RDD.
|
||||||
*
|
*
|
||||||
* This method is identical to `parallelize`.
|
* This method is identical to `parallelize`.
|
||||||
|
|
|
@ -319,6 +319,22 @@ class SparkContext(object):
|
||||||
with SparkContext._lock:
|
with SparkContext._lock:
|
||||||
SparkContext._active_spark_context = None
|
SparkContext._active_spark_context = None
|
||||||
|
|
||||||
|
def range(self, start, end=None, step=1, numSlices=None):
    """
    Create a new RDD of int containing elements from `start` to `end`
    (exclusive), increased by `step` every element.

    The call mirrors Python's built-in ``range``: when only one positional
    value is given it is treated as `end`, and `start` becomes 0.

    :param start: the start value (or the end value when `end` is omitted)
    :param end: the end value (exclusive); if None, `start` is used as the
        end value and the range begins at 0
    :param step: the incremental step (default: 1)
    :param numSlices: the number of partitions of the new RDD
    :return: An RDD of int

    >>> sc.range(5).collect()
    [0, 1, 2, 3, 4]
    >>> sc.range(2, 4).collect()
    [2, 3]
    >>> sc.range(1, 7, 2).collect()
    [1, 3, 5]
    """
    # Single-argument form: range(n) means 0 .. n, as with the built-in.
    if end is None:
        start, end = 0, start

    return self.parallelize(xrange(start, end, step), numSlices)
|
||||||
|
|
||||||
def parallelize(self, c, numSlices=None):
|
def parallelize(self, c, numSlices=None):
|
||||||
"""
|
"""
|
||||||
Distribute a local Python collection to form an RDD. Using xrange
|
Distribute a local Python collection to form an RDD. Using xrange
|
||||||
|
|
|
@ -122,6 +122,26 @@ class SQLContext(object):
|
||||||
"""Returns a :class:`UDFRegistration` for UDF registration."""
|
"""Returns a :class:`UDFRegistration` for UDF registration."""
|
||||||
return UDFRegistration(self)
|
return UDFRegistration(self)
|
||||||
|
|
||||||
|
def range(self, start, end=None, step=1, numPartitions=None):
    """
    Create a :class:`DataFrame` with single LongType column named `id`,
    containing elements in a range from `start` to `end` (exclusive) with
    step value `step`.

    When only one positional value is given it is treated as `end`, and
    the range begins at 0.

    :param start: the start value (or the end value when `end` is omitted)
    :param end: the end value (exclusive); if None, `start` is used as the
        end value and the range begins at 0
    :param step: the incremental step (default: 1)
    :param numPartitions: the number of partitions of the DataFrame; if None,
        the default parallelism of the underlying SparkContext is used
    :return: A new DataFrame

    >>> sqlContext.range(1, 7, 2).collect()
    [Row(id=1), Row(id=3), Row(id=5)]

    >>> sqlContext.range(3).collect()
    [Row(id=0), Row(id=1), Row(id=2)]
    """
    if numPartitions is None:
        numPartitions = self._sc.defaultParallelism

    # Single-argument form: range(n) means 0 .. n.
    if end is None:
        start, end = 0, start

    jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions))
    return DataFrame(jdf, self)
|
||||||
|
|
||||||
@ignore_unicode_prefix
|
@ignore_unicode_prefix
|
||||||
def registerFunction(self, name, f, returnType=StringType()):
|
def registerFunction(self, name, f, returnType=StringType()):
|
||||||
"""Registers a lambda function as a UDF so it can be used in SQL statements.
|
"""Registers a lambda function as a UDF so it can be used in SQL statements.
|
||||||
|
|
|
@ -117,6 +117,11 @@ class SQLTests(ReusedPySparkTestCase):
|
||||||
ReusedPySparkTestCase.tearDownClass()
|
ReusedPySparkTestCase.tearDownClass()
|
||||||
shutil.rmtree(cls.tempdir.name, ignore_errors=True)
|
shutil.rmtree(cls.tempdir.name, ignore_errors=True)
|
||||||
|
|
||||||
|
def test_range(self):
|
||||||
|
self.assertEqual(self.sqlCtx.range(1, 1).count(), 0)
|
||||||
|
self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1)
|
||||||
|
self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2)
|
||||||
|
|
||||||
def test_explode(self):
|
def test_explode(self):
|
||||||
from pyspark.sql.functions import explode
|
from pyspark.sql.functions import explode
|
||||||
d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
|
d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
|
||||||
|
|
|
@ -444,6 +444,11 @@ class AddFileTests(PySparkTestCase):
|
||||||
|
|
||||||
class RDDTests(ReusedPySparkTestCase):
|
class RDDTests(ReusedPySparkTestCase):
|
||||||
|
|
||||||
|
def test_range(self):
|
||||||
|
self.assertEqual(self.sc.range(1, 1).count(), 0)
|
||||||
|
self.assertEqual(self.sc.range(1, 0, -1).count(), 1)
|
||||||
|
self.assertEqual(self.sc.range(0, 1 << 40, 1 << 39).count(), 2)
|
||||||
|
|
||||||
def test_id(self):
|
def test_id(self):
|
||||||
rdd = self.sc.parallelize(range(10))
|
rdd = self.sc.parallelize(range(10))
|
||||||
id = rdd.id()
|
id = rdd.id()
|
||||||
|
|
|
@ -684,6 +684,37 @@ class SQLContext(@transient val sparkContext: SparkContext)
|
||||||
catalog.unregisterTable(Seq(tableName))
|
catalog.unregisterTable(Seq(tableName))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * :: Experimental ::
 * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing the
 * elements of the range from `start` to `end` (exclusive) with a step value of 1.
 *
 * @param start the first value of the range.
 * @param end the exclusive upper bound of the range.
 *
 * @since 1.4.0
 * @group dataframe
 */
@Experimental
def range(start: Long, end: Long): DataFrame = {
  val idRows = sparkContext.range(start, end).map(Row(_))
  val idSchema = StructType(StructField("id", LongType, nullable = false) :: Nil)
  createDataFrame(idRows, idSchema)
}
|
||||||
|
|
||||||
|
/**
 * :: Experimental ::
 * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing the
 * elements of the range from `start` to `end` (exclusive) with the given step value, using
 * the specified number of partitions.
 *
 * @param start the first value of the range.
 * @param end the exclusive bound of the range.
 * @param step the increment between consecutive values.
 * @param numPartitions the number of partitions, passed on to `sparkContext.range`.
 *
 * @since 1.4.0
 * @group dataframe
 */
@Experimental
def range(start: Long, end: Long, step: Long, numPartitions: Int): DataFrame = {
  val idRows = sparkContext.range(start, end, step, numPartitions).map(Row(_))
  val idSchema = StructType(StructField("id", LongType, nullable = false) :: Nil)
  createDataFrame(idRows, idSchema)
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Executes a SQL query using Spark, returning the result as a [[DataFrame]]. The dialect that is
|
* Executes a SQL query using Spark, returning the result as a [[DataFrame]]. The dialect that is
|
||||||
* used for SQL parsing can be configured with 'spark.sql.dialect'.
|
* used for SQL parsing can be configured with 'spark.sql.dialect'.
|
||||||
|
|
|
@ -532,4 +532,44 @@ class DataFrameSuite extends QueryTest {
|
||||||
val p = df.logicalPlan.asInstanceOf[Project].child.asInstanceOf[Project]
|
val p = df.logicalPlan.asInstanceOf[Project].child.asInstanceOf[Project]
|
||||||
assert(!p.child.isInstanceOf[Project])
|
assert(!p.child.isInstanceOf[Project])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("SPARK-7150 range api") {
  // more partitions requested than there are elements
  val oversliced = TestSQLContext.range(0, 10, 1, 15).select("id")
  assert(oversliced.count == 10)
  assert(oversliced.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))

  // step greater than 1, spread over two partitions
  val stepped = TestSQLContext.range(3, 15, 3, 2).select("id")
  assert(stepped.count == 4)
  assert(stepped.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))

  // two-argument form with end below start produces nothing
  val emptyDefaultStep = TestSQLContext.range(1, -2).select("id")
  assert(emptyDefaultStep.count == 0)

  // positive start, negative end, negative step
  val crossingZero = TestSQLContext.range(1, -2, -2, 6).select("id")
  assert(crossingZero.count == 2)
  assert(crossingZero.agg(sum("id")).as("sumid").collect() === Seq(Row(0)))

  // start, end and step all negative
  val allNegative = TestSQLContext.range(-3, -8, -2, 1).select("id")
  assert(allNegative.count == 3)
  assert(allNegative.agg(sum("id")).as("sumid").collect() === Seq(Row(-15)))

  // negative start and end, positive step
  val negativeAscending = TestSQLContext.range(-8, -4, 2, 1).select("id")
  assert(negativeAscending.count == 2)
  assert(negativeAscending.agg(sum("id")).as("sumid").collect() === Seq(Row(-14)))

  // step points away from end, so the range is empty
  val wrongDirection = TestSQLContext.range(-10, -9, -20, 1).select("id")
  assert(wrongDirection.count == 0)

  // full Long span ascending: the element count only fits in a BigInt
  val fullSpanUp =
    TestSQLContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
  assert(fullSpanUp.count == 3)
  assert(fullSpanUp.agg(sum("id")).as("sumid").collect() === Seq(Row(-3)))

  // full Long span descending with the largest possible negative step
  val fullSpanDown =
    TestSQLContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
  assert(fullSpanDown.count == 2)
  assert(fullSpanDown.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue