[SPARK-6911] [SQL] improve accessor for nested types
Support access columns by index in Python: ``` >>> df[df[0] > 3].collect() [Row(age=5, name=u'Bob')] ``` Access items in ArrayType or MapType ``` >>> df.select(df.l.getItem(0), df.d.getItem("key")).show() >>> df.select(df.l[0], df.d["key"]).show() ``` Access field in StructType ``` >>> df.select(df.r.getField("b")).show() >>> df.select(df.r.a).show() ``` Author: Davies Liu <davies@databricks.com> Closes #5513 from davies/access and squashes the following commits: e04d5a0 [Davies Liu] Update run-tests-jenkins 7ada9eb [Davies Liu] update timeout d125ac4 [Davies Liu] check column name, improve scala tests 6b62540 [Davies Liu] fix test db15b42 [Davies Liu] Merge branch 'master' of github.com:apache/spark into access 6c32e79 [Davies Liu] add scala tests 11f1df3 [Davies Liu] improve accessor for nested types
This commit is contained in:
parent
5fe4343352
commit
6183b5e2ca
|
@ -563,16 +563,23 @@ class DataFrame(object):
|
||||||
[Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
|
[Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
|
||||||
>>> df[ df.age > 3 ].collect()
|
>>> df[ df.age > 3 ].collect()
|
||||||
[Row(age=5, name=u'Bob')]
|
[Row(age=5, name=u'Bob')]
|
||||||
|
>>> df[df[0] > 3].collect()
|
||||||
|
[Row(age=5, name=u'Bob')]
|
||||||
"""
|
"""
|
||||||
if isinstance(item, basestring):
|
if isinstance(item, basestring):
|
||||||
|
if item not in self.columns:
|
||||||
|
raise IndexError("no such column: %s" % item)
|
||||||
jc = self._jdf.apply(item)
|
jc = self._jdf.apply(item)
|
||||||
return Column(jc)
|
return Column(jc)
|
||||||
elif isinstance(item, Column):
|
elif isinstance(item, Column):
|
||||||
return self.filter(item)
|
return self.filter(item)
|
||||||
elif isinstance(item, list):
|
elif isinstance(item, (list, tuple)):
|
||||||
return self.select(*item)
|
return self.select(*item)
|
||||||
|
elif isinstance(item, int):
|
||||||
|
jc = self._jdf.apply(self.columns[item])
|
||||||
|
return Column(jc)
|
||||||
else:
|
else:
|
||||||
raise IndexError("unexpected index: %s" % item)
|
raise TypeError("unexpected type: %s" % type(item))
|
||||||
|
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
"""Returns the :class:`Column` denoted by ``name``.
|
"""Returns the :class:`Column` denoted by ``name``.
|
||||||
|
@ -580,8 +587,8 @@ class DataFrame(object):
|
||||||
>>> df.select(df.age).collect()
|
>>> df.select(df.age).collect()
|
||||||
[Row(age=2), Row(age=5)]
|
[Row(age=2), Row(age=5)]
|
||||||
"""
|
"""
|
||||||
if name.startswith("__"):
|
if name not in self.columns:
|
||||||
raise AttributeError(name)
|
raise AttributeError("No such column: %s" % name)
|
||||||
jc = self._jdf.apply(name)
|
jc = self._jdf.apply(name)
|
||||||
return Column(jc)
|
return Column(jc)
|
||||||
|
|
||||||
|
@ -1093,7 +1100,39 @@ class Column(object):
|
||||||
# container operators
|
# container operators
|
||||||
__contains__ = _bin_op("contains")
|
__contains__ = _bin_op("contains")
|
||||||
__getitem__ = _bin_op("getItem")
|
__getitem__ = _bin_op("getItem")
|
||||||
getField = _bin_op("getField", "An expression that gets a field by name in a StructField.")
|
|
||||||
|
def getItem(self, key):
|
||||||
|
"""An expression that gets an item at position `ordinal` out of a list,
|
||||||
|
or gets an item by key out of a dict.
|
||||||
|
|
||||||
|
>>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"])
|
||||||
|
>>> df.select(df.l.getItem(0), df.d.getItem("key")).show()
|
||||||
|
l[0] d[key]
|
||||||
|
1 value
|
||||||
|
>>> df.select(df.l[0], df.d["key"]).show()
|
||||||
|
l[0] d[key]
|
||||||
|
1 value
|
||||||
|
"""
|
||||||
|
return self[key]
|
||||||
|
|
||||||
|
def getField(self, name):
|
||||||
|
"""An expression that gets a field by name in a StructField.
|
||||||
|
|
||||||
|
>>> from pyspark.sql import Row
|
||||||
|
>>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
|
||||||
|
>>> df.select(df.r.getField("b")).show()
|
||||||
|
r.b
|
||||||
|
b
|
||||||
|
>>> df.select(df.r.a).show()
|
||||||
|
r.a
|
||||||
|
1
|
||||||
|
"""
|
||||||
|
return Column(self._jc.getField(name))
|
||||||
|
|
||||||
|
def __getattr__(self, item):
|
||||||
|
if item.startswith("__"):
|
||||||
|
raise AttributeError(item)
|
||||||
|
return self.getField(item)
|
||||||
|
|
||||||
# string methods
|
# string methods
|
||||||
rlike = _bin_op("rlike")
|
rlike = _bin_op("rlike")
|
||||||
|
|
|
@ -426,6 +426,24 @@ class SQLTests(ReusedPySparkTestCase):
|
||||||
pydoc.render_doc(df.foo)
|
pydoc.render_doc(df.foo)
|
||||||
pydoc.render_doc(df.take(1))
|
pydoc.render_doc(df.take(1))
|
||||||
|
|
||||||
|
def test_access_column(self):
|
||||||
|
df = self.df
|
||||||
|
self.assertTrue(isinstance(df.key, Column))
|
||||||
|
self.assertTrue(isinstance(df['key'], Column))
|
||||||
|
self.assertTrue(isinstance(df[0], Column))
|
||||||
|
self.assertRaises(IndexError, lambda: df[2])
|
||||||
|
self.assertRaises(IndexError, lambda: df["bad_key"])
|
||||||
|
self.assertRaises(TypeError, lambda: df[{}])
|
||||||
|
|
||||||
|
def test_access_nested_types(self):
|
||||||
|
df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
|
||||||
|
self.assertEqual(1, df.select(df.l[0]).first()[0])
|
||||||
|
self.assertEqual(1, df.select(df.l.getItem(0)).first()[0])
|
||||||
|
self.assertEqual(1, df.select(df.r.a).first()[0])
|
||||||
|
self.assertEqual("b", df.select(df.r.getField("b")).first()[0])
|
||||||
|
self.assertEqual("v", df.select(df.d["k"]).first()[0])
|
||||||
|
self.assertEqual("v", df.select(df.d.getItem("k")).first()[0])
|
||||||
|
|
||||||
def test_infer_long_type(self):
|
def test_infer_long_type(self):
|
||||||
longrow = [Row(f1='a', f2=100000000000000)]
|
longrow = [Row(f1='a', f2=100000000000000)]
|
||||||
df = self.sc.parallelize(longrow).toDF()
|
df = self.sc.parallelize(longrow).toDF()
|
||||||
|
|
|
@ -515,14 +515,15 @@ class Column(protected[sql] val expr: Expression) extends Logging {
|
||||||
def rlike(literal: String): Column = RLike(expr, lit(literal).expr)
|
def rlike(literal: String): Column = RLike(expr, lit(literal).expr)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An expression that gets an item at position `ordinal` out of an array.
|
* An expression that gets an item at position `ordinal` out of an array,
|
||||||
|
* or gets a value by key `key` in a [[MapType]].
|
||||||
*
|
*
|
||||||
* @group expr_ops
|
* @group expr_ops
|
||||||
*/
|
*/
|
||||||
def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal))
|
def getItem(key: Any): Column = GetItem(expr, Literal(key))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An expression that gets a field by name in a [[StructField]].
|
* An expression that gets a field by name in a [[StructType]].
|
||||||
*
|
*
|
||||||
* @group expr_ops
|
* @group expr_ops
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -86,6 +86,12 @@ class DataFrameSuite extends QueryTest {
|
||||||
TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString)
|
TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("access complex data") {
|
||||||
|
assert(complexData.filter(complexData("a").getItem(0) === 2).count() == 1)
|
||||||
|
assert(complexData.filter(complexData("m").getItem("1") === 1).count() == 1)
|
||||||
|
assert(complexData.filter(complexData("s").getField("key") === 1).count() == 1)
|
||||||
|
}
|
||||||
|
|
||||||
test("table scan") {
|
test("table scan") {
|
||||||
checkAnswer(
|
checkAnswer(
|
||||||
testData,
|
testData,
|
||||||
|
|
|
@ -20,9 +20,8 @@ package org.apache.spark.sql
|
||||||
import java.sql.Timestamp
|
import java.sql.Timestamp
|
||||||
|
|
||||||
import org.apache.spark.sql.catalyst.plans.logical
|
import org.apache.spark.sql.catalyst.plans.logical
|
||||||
import org.apache.spark.sql.functions._
|
|
||||||
import org.apache.spark.sql.test._
|
|
||||||
import org.apache.spark.sql.test.TestSQLContext.implicits._
|
import org.apache.spark.sql.test.TestSQLContext.implicits._
|
||||||
|
import org.apache.spark.sql.test._
|
||||||
|
|
||||||
|
|
||||||
case class TestData(key: Int, value: String)
|
case class TestData(key: Int, value: String)
|
||||||
|
@ -199,11 +198,11 @@ object TestData {
|
||||||
Salary(1, 1000.0) :: Nil).toDF()
|
Salary(1, 1000.0) :: Nil).toDF()
|
||||||
salary.registerTempTable("salary")
|
salary.registerTempTable("salary")
|
||||||
|
|
||||||
case class ComplexData(m: Map[Int, String], s: TestData, a: Seq[Int], b: Boolean)
|
case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
|
||||||
val complexData =
|
val complexData =
|
||||||
TestSQLContext.sparkContext.parallelize(
|
TestSQLContext.sparkContext.parallelize(
|
||||||
ComplexData(Map(1 -> "1"), TestData(1, "1"), Seq(1), true)
|
ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1), true)
|
||||||
:: ComplexData(Map(2 -> "2"), TestData(2, "2"), Seq(2), false)
|
:: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2), false)
|
||||||
:: Nil).toDF()
|
:: Nil).toDF()
|
||||||
complexData.registerTempTable("complexData")
|
complexData.registerTempTable("complexData")
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue