[SPARK-3500] [SQL] use JavaSchemaRDD as SchemaRDD._jschema_rdd
Currently, `SchemaRDD._jschema_rdd` is a Scala `SchemaRDD`, so Scala API methods such as `coalesce()` and `repartition()` cannot easily be called from Python: there is no way to supply the implicit parameter `ord`. Since `_jrdd` is a `JavaRDD`, `_jschema_rdd` should likewise be a `JavaSchemaRDD`.

This patch changes `_jschema_rdd` to a `JavaSchemaRDD` and adds an assert for it in the constructor. If a method is missing from `JavaSchemaRDD`, it is called via `_jschema_rdd.baseSchemaRDD().xxx()`.

BTW, do we still need `JavaSQLContext`?

Author: Davies Liu <davies.liu@gmail.com>

Closes #2369 from davies/fix_schemardd and squashes the following commits:

abee159 [Davies Liu] use JavaSchemaRDD as SchemaRDD._jschema_rdd
parent 71af030b46
commit 885d1621bc
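After this patch the convention is: methods that JavaSchemaRDD mirrors (coalesce(), repartition(), getCheckpointFile(), ...) are called on the wrapper directly, and anything it lacks is reached through baseSchemaRDD(). A minimal sketch of the two call paths, using names from the diff below rather than the complete class:

    class SchemaRDD(RDD):
        def coalesce(self, numPartitions, shuffle=False):
            # Mirrored on JavaSchemaRDD: no implicit `ord` parameter to
            # satisfy, so Python can call it directly.
            rdd = self._jschema_rdd.coalesce(numPartitions, shuffle)
            return SchemaRDD(rdd, self.sql_ctx)

        def schema(self):
            # Not mirrored on JavaSchemaRDD: go through the underlying
            # Scala SchemaRDD via baseSchemaRDD().
            return _parse_datatype_string(
                self._jschema_rdd.baseSchemaRDD().schema().toString())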
python/pyspark/sql.py

@@ -1122,7 +1122,7 @@ class SQLContext(object):
         batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
         jrdd = self._pythonToJava(rdd._jrdd, batched)
         srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema))
-        return SchemaRDD(srdd, self)
+        return SchemaRDD(srdd.toJavaSchemaRDD(), self)
 
     def registerRDDAsTable(self, rdd, tableName):
         """Registers the given RDD as a temporary table in the catalog.
@@ -1134,8 +1134,8 @@ class SQLContext(object):
         >>> sqlCtx.registerRDDAsTable(srdd, "table1")
         """
         if (rdd.__class__ is SchemaRDD):
-            jschema_rdd = rdd._jschema_rdd
-            self._ssql_ctx.registerRDDAsTable(jschema_rdd, tableName)
+            srdd = rdd._jschema_rdd.baseSchemaRDD()
+            self._ssql_ctx.registerRDDAsTable(srdd, tableName)
         else:
             raise ValueError("Can only register SchemaRDD as table")
 
@@ -1151,7 +1151,7 @@ class SQLContext(object):
         >>> sorted(srdd.collect()) == sorted(srdd2.collect())
         True
         """
-        jschema_rdd = self._ssql_ctx.parquetFile(path)
+        jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD()
         return SchemaRDD(jschema_rdd, self)
 
     def jsonFile(self, path, schema=None):
@@ -1207,11 +1207,11 @@ class SQLContext(object):
         [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
         """
         if schema is None:
-            jschema_rdd = self._ssql_ctx.jsonFile(path)
+            srdd = self._ssql_ctx.jsonFile(path)
         else:
             scala_datatype = self._ssql_ctx.parseDataType(str(schema))
-            jschema_rdd = self._ssql_ctx.jsonFile(path, scala_datatype)
-        return SchemaRDD(jschema_rdd, self)
+            srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
+        return SchemaRDD(srdd.toJavaSchemaRDD(), self)
 
     def jsonRDD(self, rdd, schema=None):
         """Loads an RDD storing one JSON object per string as a L{SchemaRDD}.
@@ -1275,11 +1275,11 @@ class SQLContext(object):
         keyed._bypass_serializer = True
         jrdd = keyed._jrdd.map(self._jvm.BytesToString())
         if schema is None:
-            jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
+            srdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
         else:
             scala_datatype = self._ssql_ctx.parseDataType(str(schema))
-            jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
-        return SchemaRDD(jschema_rdd, self)
+            srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
+        return SchemaRDD(srdd.toJavaSchemaRDD(), self)
 
     def sql(self, sqlQuery):
         """Return a L{SchemaRDD} representing the result of the given query.
@@ -1290,7 +1290,7 @@ class SQLContext(object):
         >>> srdd2.collect()
         [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
         """
-        return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)
+        return SchemaRDD(self._ssql_ctx.sql(sqlQuery).toJavaSchemaRDD(), self)
 
     def table(self, tableName):
         """Returns the specified table as a L{SchemaRDD}.
@@ -1301,7 +1301,7 @@ class SQLContext(object):
         >>> sorted(srdd.collect()) == sorted(srdd2.collect())
         True
         """
-        return SchemaRDD(self._ssql_ctx.table(tableName), self)
+        return SchemaRDD(self._ssql_ctx.table(tableName).toJavaSchemaRDD(), self)
 
     def cacheTable(self, tableName):
         """Caches the specified table in-memory."""
@@ -1353,7 +1353,7 @@ class HiveContext(SQLContext):
         warnings.warn("hiveql() is deprecated as the sql function now parses using HiveQL by" +
                       "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'",
                       DeprecationWarning)
-        return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self)
+        return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery).toJavaSchemaRDD(), self)
 
     def hql(self, hqlQuery):
         """
@@ -1524,6 +1524,8 @@ class SchemaRDD(RDD):
     def __init__(self, jschema_rdd, sql_ctx):
         self.sql_ctx = sql_ctx
         self._sc = sql_ctx._sc
+        clsName = jschema_rdd.getClass().getName()
+        assert clsName.endswith("JavaSchemaRDD"), "jschema_rdd must be JavaSchemaRDD"
         self._jschema_rdd = jschema_rdd
         self._id = None
         self.is_cached = False
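Because py4j proxies JVM objects, Python's isinstance() cannot test against the JVM class, so the constructor above checks the class name instead. A hypothetical session illustrating what the new assert catches (assuming sqlCtx._ssql_ctx is the underlying Scala SQLContext, as in the rest of this file):

    scala_srdd = sqlCtx._ssql_ctx.sql("SELECT 1")  # Scala SchemaRDD
    java_srdd = scala_srdd.toJavaSchemaRDD()       # JavaSchemaRDD wrapper

    SchemaRDD(java_srdd, sqlCtx)    # ok: class name ends with "JavaSchemaRDD"
    SchemaRDD(scala_srdd, sqlCtx)   # AssertionError: jschema_rdd must be JavaSchemaRDD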
@@ -1540,7 +1542,7 @@ class SchemaRDD(RDD):
         L{pyspark.rdd.RDD} super class (map, filter, etc.).
         """
         if not hasattr(self, '_lazy_jrdd'):
-            self._lazy_jrdd = self._jschema_rdd.javaToPython()
+            self._lazy_jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython()
         return self._lazy_jrdd
 
     def id(self):
@@ -1598,7 +1600,7 @@ class SchemaRDD(RDD):
     def schema(self):
         """Returns the schema of this SchemaRDD (represented by
         a L{StructType})."""
-        return _parse_datatype_string(self._jschema_rdd.schema().toString())
+        return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString())
 
     def schemaString(self):
         """Returns the output schema in the tree format."""
@@ -1649,8 +1651,6 @@ class SchemaRDD(RDD):
         rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer)
 
         schema = self.schema()
-        import pickle
-        pickle.loads(pickle.dumps(schema))
 
         def applySchema(_, it):
             cls = _create_cls(schema)
@@ -1687,10 +1687,8 @@ class SchemaRDD(RDD):
 
     def getCheckpointFile(self):
         checkpointFile = self._jschema_rdd.getCheckpointFile()
-        if checkpointFile.isDefined():
+        if checkpointFile.isPresent():
             return checkpointFile.get()
-        else:
-            return None
 
     def coalesce(self, numPartitions, shuffle=False):
         rdd = self._jschema_rdd.coalesce(numPartitions, shuffle)
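The isDefined() to isPresent() switch above follows from the wrapper change: the Scala SchemaRDD returns a Scala Option from getCheckpointFile(), while JavaSchemaRDD, like JavaRDD, returns a Guava Optional, which exposes isPresent(). The resulting accessor, with the now-redundant else branch dropped in favor of Python's implicit None:

    def getCheckpointFile(self):
        checkpointFile = self._jschema_rdd.getCheckpointFile()
        if checkpointFile.isPresent():
            return checkpointFile.get()
        # no explicit else: falls through to an implicit `return None`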
python/pyspark/tests.py

@@ -607,6 +607,34 @@ class TestSQL(PySparkTestCase):
         [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
         self.assertEqual("", res[0])
 
+    def test_basic_functions(self):
+        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
+        srdd = self.sqlCtx.jsonRDD(rdd)
+        srdd.count()
+        srdd.collect()
+        srdd.schemaString()
+        srdd.schema()
+
+        # cache and checkpoint
+        self.assertFalse(srdd.is_cached)
+        srdd.persist()
+        srdd.unpersist()
+        srdd.cache()
+        self.assertTrue(srdd.is_cached)
+        self.assertFalse(srdd.isCheckpointed())
+        self.assertEqual(None, srdd.getCheckpointFile())
+
+        srdd = srdd.coalesce(2, True)
+        srdd = srdd.repartition(3)
+        srdd = srdd.distinct()
+        srdd.intersection(srdd)
+        self.assertEqual(2, srdd.count())
+
+        srdd.registerTempTable("temp")
+        srdd = self.sqlCtx.sql("select foo from temp")
+        srdd.count()
+        srdd.collect()
+
 
 class TestIO(PySparkTestCase):
 