#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import pydoc
import time
import unittest

from pyspark.sql import SparkSession, Row
from pyspark.sql.types import *
from pyspark.sql.utils import AnalysisException, IllegalArgumentException
from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils, have_pyarrow, have_pandas, \
    pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest


class DataFrameTests(ReusedSQLTestCase):
    """Tests for DataFrame creation, transformation, joins, caching and rendering.

    The tests use ``self.spark``, ``self.sc`` and ``self.df``, none of which is defined
    in this file (presumably they are supplied by ``ReusedSQLTestCase`` -- confirm there).
    """

    # range() must honour start/end/step, including empty and negative-step ranges.
    def test_range(self):
        self.assertEqual(self.spark.range(1, 1).count(), 0)
        self.assertEqual(self.spark.range(1, 0, -1).count(), 1)
        self.assertEqual(self.spark.range(0, 1 << 40, 1 << 39).count(), 2)
        self.assertEqual(self.spark.range(-2).count(), 0)
        self.assertEqual(self.spark.range(3).count(), 3)

    # Rows with duplicated column names are readable by position but not by name.
    def test_duplicated_column_names(self):
        df = self.spark.createDataFrame([(1, 2)], ["c", "c"])
        row = df.select('*').first()
        self.assertEqual(1, row[0])
        self.assertEqual(2, row[1])
        self.assertEqual("Row(c=1, c=2)", str(row))
        # Cannot access columns
        self.assertRaises(AnalysisException, lambda: df.select(df[0]).first())
        self.assertRaises(AnalysisException, lambda: df.select(df.c).first())
        self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first())

    # Every even-indexed row is (1, -2.0), so both values must be reported as frequent.
    def test_freqItems(self):
        vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)]
        df = self.sc.parallelize(vals).toDF()
        items = df.stat.freqItems(("a", "b"), 0.4).collect()[0]
        self.assertTrue(1 in items[0])
        self.assertTrue(-2.0 in items[1])

    def test_help_command(self):
        # Regression test for SPARK-5464
        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
        df = self.spark.read.json(rdd)
        # render_doc() reproduces the help() exception without printing output
        pydoc.render_doc(df)
        pydoc.render_doc(df.foo)
        pydoc.render_doc(df.take(1))

    # dropna() with the `how`, `thresh` and `subset` arguments, alone and combined.
    def test_dropna(self):
        schema = StructType([
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
            StructField("height", DoubleType(), True)])

        # shouldn't drop a non-null row
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', 50, 80.1)], schema).dropna().count(),
            1)

        # dropping rows with a single null value
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, 80.1)], schema).dropna().count(),
            0)
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, 80.1)], schema).dropna(how='any').count(),
            0)

        # if how = 'all', only drop rows if all values are null
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, 80.1)], schema).dropna(how='all').count(),
            1)
        self.assertEqual(self.spark.createDataFrame(
            [(None, None, None)], schema).dropna(how='all').count(),
            0)

        # how and subset
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', 50, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
            1)
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
            0)

        # threshold
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, 80.1)], schema).dropna(thresh=2).count(),
            1)
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, None)], schema).dropna(thresh=2).count(),
            0)

        # threshold and subset
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', 50, None)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
            1)
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, 180.9)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
            0)

        # thresh should take precedence over how
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', 50, None)], schema).dropna(
                how='any', thresh=2, subset=['name', 'age']).count(),
            1)

    # fillna() must only touch columns whose type is compatible with the fill value.
    def test_fillna(self):
        schema = StructType([
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
            StructField("height", DoubleType(), True),
            StructField("spy", BooleanType(), True)])

        # fillna shouldn't change non-null values
        row = self.spark.createDataFrame([(u'Alice', 10, 80.1, True)], schema).fillna(50).first()
        self.assertEqual(row.age, 10)

        # fillna with int
        row = self.spark.createDataFrame([(u'Alice', None, None, None)], schema).fillna(50).first()
        self.assertEqual(row.age, 50)
        self.assertEqual(row.height, 50.0)

        # fillna with double
        row = self.spark.createDataFrame(
            [(u'Alice', None, None, None)], schema).fillna(50.1).first()
        self.assertEqual(row.age, 50)
        self.assertEqual(row.height, 50.1)

        # fillna with bool
        row = self.spark.createDataFrame(
            [(u'Alice', None, None, None)], schema).fillna(True).first()
        self.assertEqual(row.age, None)
        self.assertEqual(row.spy, True)

        # fillna with string
        row = self.spark.createDataFrame([(None, None, None, None)], schema).fillna("hello").first()
        self.assertEqual(row.name, u"hello")
        self.assertEqual(row.age, None)

        # fillna with subset specified for numeric cols
        row = self.spark.createDataFrame(
            [(None, None, None, None)], schema).fillna(50, subset=['name', 'age']).first()
        self.assertEqual(row.name, None)
        self.assertEqual(row.age, 50)
        self.assertEqual(row.height, None)
        self.assertEqual(row.spy, None)

        # fillna with subset specified for string cols
        row = self.spark.createDataFrame(
            [(None, None, None, None)], schema).fillna("haha", subset=['name', 'age']).first()
        self.assertEqual(row.name, "haha")
        self.assertEqual(row.age, None)
        self.assertEqual(row.height, None)
        self.assertEqual(row.spy, None)

        # fillna with subset specified for bool cols
        row = self.spark.createDataFrame(
            [(None, None, None, None)], schema).fillna(True, subset=['name', 'spy']).first()
        self.assertEqual(row.name, None)
        self.assertEqual(row.age, None)
        self.assertEqual(row.height, None)
        self.assertEqual(row.spy, True)

        # fillna with dictionary for boolean types
        row = self.spark.createDataFrame([Row(a=None), Row(a=True)]).fillna({"a": True}).first()
        self.assertEqual(row.a, True)

    # df2 holds the rows of df1 already ordered by (name, age); each repartitioned
    # frame is compared against it.
    def test_repartitionByRange_dataframe(self):
        schema = StructType([
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
            StructField("height", DoubleType(), True)])

        df1 = self.spark.createDataFrame(
            [(u'Bob', 27, 66.0), (u'Alice', 10, 10.0), (u'Bob', 10, 66.0)], schema)
        df2 = self.spark.createDataFrame(
            [(u'Alice', 10, 10.0), (u'Bob', 10, 66.0), (u'Bob', 27, 66.0)], schema)

        # test repartitionByRange(numPartitions, *cols)
        df3 = df1.repartitionByRange(2, "name", "age")
        self.assertEqual(df3.rdd.getNumPartitions(), 2)
        self.assertEqual(df3.rdd.first(), df2.rdd.first())
        self.assertEqual(df3.rdd.take(3), df2.rdd.take(3))

        # test repartitionByRange(numPartitions, *cols)
        df4 = df1.repartitionByRange(3, "name", "age")
        self.assertEqual(df4.rdd.getNumPartitions(), 3)
        self.assertEqual(df4.rdd.first(), df2.rdd.first())
        self.assertEqual(df4.rdd.take(3), df2.rdd.take(3))

        # test repartitionByRange(*cols)
        df5 = df1.repartitionByRange("name", "age")
        self.assertEqual(df5.rdd.first(), df2.rdd.first())
        self.assertEqual(df5.rdd.take(3), df2.rdd.take(3))

    # replace() with scalars, lists, tuples, dicts and subsets, plus its argument validation.
    def test_replace(self):
        schema = StructType([
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
            StructField("height", DoubleType(), True)])

        # replace with int
        row = self.spark.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first()
        self.assertEqual(row.age, 20)
        self.assertEqual(row.height, 20.0)

        # replace with double
        row = self.spark.createDataFrame(
            [(u'Alice', 80, 80.0)], schema).replace(80.0, 82.1).first()
        self.assertEqual(row.age, 82)
        self.assertEqual(row.height, 82.1)

        # replace with string
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace(u'Alice', u'Ann').first()
        self.assertEqual(row.name, u"Ann")
        self.assertEqual(row.age, 10)

        # replace with subset specified by a string of a column name w/ actual change
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='age').first()
        self.assertEqual(row.age, 20)

        # replace with subset specified by a string of a column name w/o actual change
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='height').first()
        self.assertEqual(row.age, 10)

        # replace with subset specified with one column replaced, another column not in subset
        # stays unchanged.
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 10.0)], schema).replace(10, 20, subset=['name', 'age']).first()
        self.assertEqual(row.name, u'Alice')
        self.assertEqual(row.age, 20)
        self.assertEqual(row.height, 10.0)

        # replace with subset specified but no column will be replaced
        row = self.spark.createDataFrame(
            [(u'Alice', 10, None)], schema).replace(10, 20, subset=['name', 'height']).first()
        self.assertEqual(row.name, u'Alice')
        self.assertEqual(row.age, 10)
        self.assertEqual(row.height, None)

        # replace with lists
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace([u'Alice'], [u'Ann']).first()
        self.assertTupleEqual(row, (u'Ann', 10, 80.1))

        # replace with dict
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace({10: 11}).first()
        self.assertTupleEqual(row, (u'Alice', 11, 80.1))

        # test backward compatibility with dummy value
        dummy_value = 1
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace({'Alice': 'Bob'}, dummy_value).first()
        self.assertTupleEqual(row, (u'Bob', 10, 80.1))

        # test dict with mixed numerics
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace({10: -10, 80.1: 90.5}).first()
        self.assertTupleEqual(row, (u'Alice', -10, 90.5))

        # replace with tuples
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace((u'Alice', ), (u'Bob', )).first()
        self.assertTupleEqual(row, (u'Bob', 10, 80.1))

        # replace multiple columns
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.0)], schema).replace((10, 80.0), (20, 90)).first()
        self.assertTupleEqual(row, (u'Alice', 20, 90.0))

        # test for mixed numerics
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.0)], schema).replace((10, 80), (20, 90.5)).first()
        self.assertTupleEqual(row, (u'Alice', 20, 90.5))

        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.0)], schema).replace({10: 20, 80: 90.5}).first()
        self.assertTupleEqual(row, (u'Alice', 20, 90.5))

        # replace with boolean
        row = (self
               .spark.createDataFrame([(u'Alice', 10, 80.0)], schema)
               .selectExpr("name = 'Bob'", 'age <= 15')
               .replace(False, True).first())
        self.assertTupleEqual(row, (True, True))

        # replace string with None and then drop None rows
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.0)], schema).replace(u'Alice', None).dropna()
        self.assertEqual(row.count(), 0)

        # replace with number and None
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.0)], schema).replace([10, 80], [20, None]).first()
        self.assertTupleEqual(row, (u'Alice', 20, None))

        # should fail if subset is not list, tuple or None
        with self.assertRaises(ValueError):
            self.spark.createDataFrame(
                [(u'Alice', 10, 80.1)], schema).replace({10: 11}, subset=1).first()

        # should fail if to_replace and value have different length
        with self.assertRaises(ValueError):
            self.spark.createDataFrame(
                [(u'Alice', 10, 80.1)], schema).replace(["Alice", "Bob"], ["Eve"]).first()

        # should fail when an unexpected type is received
        with self.assertRaises(ValueError):
            from datetime import datetime
            self.spark.createDataFrame(
                [(u'Alice', 10, 80.1)], schema).replace(datetime.now(), datetime.now()).first()

        # should fail if provided mixed type replacements
        with self.assertRaises(ValueError):
            self.spark.createDataFrame(
                [(u'Alice', 10, 80.1)], schema).replace(["Alice", 10], ["Eve", 20]).first()

        with self.assertRaises(ValueError):
            self.spark.createDataFrame(
                [(u'Alice', 10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first()

        with self.assertRaisesRegexp(
                TypeError,
                'value argument is required when to_replace is not a dictionary.'):
            self.spark.createDataFrame(
                [(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first()

    # withColumn() with an existing column name must keep that column's values.
    def test_with_column_with_existing_name(self):
        keys = self.df.withColumn("key", self.df.key).select("key").collect()
        self.assertEqual([r.key for r in keys], list(range(100)))

    # regression test for SPARK-10417
    def test_column_iterator(self):

        def foo():
            for x in self.df.key:
                break

        self.assertRaises(TypeError, foo)

    # hint() must return a DataFrame for every accepted parameter shape, and a
    # "broadcast" hint must show up as exactly one BroadcastHashJoin in the executed plan.
    def test_generic_hints(self):
        from pyspark.sql import DataFrame

        df1 = self.spark.range(10e10).toDF("id")
        df2 = self.spark.range(10e10).toDF("id")

        self.assertIsInstance(df1.hint("broadcast"), DataFrame)
        self.assertIsInstance(df1.hint("broadcast", []), DataFrame)

        # Dummy rules
        self.assertIsInstance(df1.hint("broadcast", "foo", "bar"), DataFrame)
        self.assertIsInstance(df1.hint("broadcast", ["foo", "bar"]), DataFrame)

        plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan()
        self.assertEqual(1, plan.toString().count("BroadcastHashJoin"))

    # add tests for SPARK-23647 (test more types for hint)
    def test_extended_hint_types(self):
        from pyspark.sql import DataFrame
        df = self.spark.range(10e10).toDF("id")
        such_a_nice_list = ["itworks1", "itworks2", "itworks3"]
        hinted_df = df.hint("my awesome hint", 1.2345, "what", such_a_nice_list)
        logical_plan = hinted_df._jdf.queryExecution().logical()

        # Each hint parameter must appear in the logical plan's string form.
        self.assertEqual(1, logical_plan.toString().count("1.2345"))
        self.assertEqual(1, logical_plan.toString().count("what"))
        self.assertEqual(3, logical_plan.toString().count("itworks"))

    # sample() must reject missing or wrongly typed arguments and a negative fraction.
    def test_sample(self):
        self.assertRaisesRegexp(
            TypeError,
            "should be a bool, float and number",
            lambda: self.spark.range(1).sample())

        self.assertRaises(
            TypeError,
            lambda: self.spark.range(1).sample("a"))

        self.assertRaises(
            TypeError,
            lambda: self.spark.range(1).sample(seed="abc"))

        self.assertRaises(
            IllegalArgumentException,
            lambda: self.spark.range(1).sample(-1.0))

    # RDD.toDF() accepting a DDL-style schema string or a bare DataType.
    def test_toDF_with_schema_string(self):
        data = [Row(key=i, value=str(i)) for i in range(100)]
        rdd = self.sc.parallelize(data, 5)

        df = rdd.toDF("key: int, value: string")
        self.assertEqual(df.schema.simpleString(), "struct<key:int,value:string>")
        self.assertEqual(df.collect(), data)

        # different but compatible field types can be used.
        df = rdd.toDF("key: string, value: string")
        self.assertEqual(df.schema.simpleString(), "struct<key:string,value:string>")
        self.assertEqual(df.collect(), [Row(key=str(i), value=str(i)) for i in range(100)])

        # field names can differ.
        df = rdd.toDF(" a: int, b: string ")
        self.assertEqual(df.schema.simpleString(), "struct<a:int,b:string>")
        self.assertEqual(df.collect(), data)

        # number of fields must match.
        self.assertRaisesRegexp(Exception, "Length of object",
                                lambda: rdd.toDF("key: int").collect())

        # field types mismatch will cause exception at runtime.
        self.assertRaisesRegexp(Exception, "FloatType can not accept",
                                lambda: rdd.toDF("key: float, value: string").collect())

        # flat schema values will be wrapped into row.
        df = rdd.map(lambda row: row.key).toDF("int")
        self.assertEqual(df.schema.simpleString(), "struct<value:int>")
        self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])

        # users can use DataType directly instead of data type string.
        df = rdd.map(lambda row: row.key).toDF(IntegerType())
        self.assertEqual(df.schema.simpleString(), "struct<value:int>")
        self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])

    # A join without `on` is only allowed when spark.sql.crossJoin.enabled is true.
    def test_join_without_on(self):
        df1 = self.spark.range(1).toDF("a")
        df2 = self.spark.range(1).toDF("b")

        with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
            self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect())

        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
            actual = df1.join(df2, how="inner").collect()
            expected = [Row(a=0, b=0)]
            self.assertEqual(actual, expected)

    # Regression test for invalid join methods when on is None, Spark-14761
    def test_invalid_join_method(self):
        df1 = self.spark.createDataFrame([("Alice", 5), ("Bob", 8)], ["name", "age"])
        df2 = self.spark.createDataFrame([("Alice", 80), ("Bob", 90)], ["name", "height"])
        self.assertRaises(IllegalArgumentException, lambda: df1.join(df2, how="invalid-join-type"))

    # Cartesian products require cross join syntax
    def test_require_cross(self):

        df1 = self.spark.createDataFrame([(1, "1")], ("key", "value"))
        df2 = self.spark.createDataFrame([(1, "1")], ("key", "value"))

        # joins without conditions require cross join syntax
        self.assertRaises(AnalysisException, lambda: df1.join(df2).collect())

        # works with crossJoin
        self.assertEqual(1, df1.crossJoin(df2).count())

    # Catalog cacheTable / uncacheTable / clearCache / isCached on temp views,
    # including the error raised for an unknown table name.
    def test_cache(self):
        spark = self.spark
        with self.tempView("tab1", "tab2"):
            spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab1")
            spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab2")
            self.assertFalse(spark.catalog.isCached("tab1"))
            self.assertFalse(spark.catalog.isCached("tab2"))
            spark.catalog.cacheTable("tab1")
            self.assertTrue(spark.catalog.isCached("tab1"))
            self.assertFalse(spark.catalog.isCached("tab2"))
            spark.catalog.cacheTable("tab2")
            spark.catalog.uncacheTable("tab1")
            self.assertFalse(spark.catalog.isCached("tab1"))
            self.assertTrue(spark.catalog.isCached("tab2"))
            spark.catalog.clearCache()
            self.assertFalse(spark.catalog.isCached("tab1"))
            self.assertFalse(spark.catalog.isCached("tab2"))
            self.assertRaisesRegexp(
                AnalysisException,
                "does_not_exist",
                lambda: spark.catalog.isCached("does_not_exist"))
            self.assertRaisesRegexp(
                AnalysisException,
                "does_not_exist",
                lambda: spark.catalog.cacheTable("does_not_exist"))
            self.assertRaisesRegexp(
                AnalysisException,
                "does_not_exist",
                lambda: spark.catalog.uncacheTable("does_not_exist"))

    # Helper (not a test): builds a small typed DataFrame, including one row with a
    # null date and timestamp, and returns the result of toPandas().
    def _to_pandas(self):
        from datetime import datetime, date
        schema = StructType().add("a", IntegerType()).add("b", StringType())\
                             .add("c", BooleanType()).add("d", FloatType())\
                             .add("dt", DateType()).add("ts", TimestampType())
        data = [
            (1, "foo", True, 3.0, date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
            (2, "foo", True, 5.0, None, None),
            (3, "bar", False, -1.0, date(2012, 3, 3), datetime(2012, 3, 3, 3, 3, 3)),
            (4, "bar", False, 6.0, date(2100, 4, 4), datetime(2100, 4, 4, 4, 4, 4)),
        ]
        df = self.spark.createDataFrame(data, schema)
        return df.toPandas()

    # Checks the pandas dtype produced for each Spark column type.
    @unittest.skipIf(not have_pandas, pandas_requirement_message)
    def test_to_pandas(self):
        import numpy as np
        pdf = self._to_pandas()
        types = pdf.dtypes
        self.assertEqual(types[0], np.int32)
        # np.object_ / np.bool_ replace the np.object / np.bool aliases, which newer
        # NumPy releases no longer provide; they compare equal to the same dtypes.
        self.assertEqual(types[1], np.object_)
        self.assertEqual(types[2], np.bool_)
        self.assertEqual(types[3], np.float32)
        self.assertEqual(types[4], np.object_)  # datetime.date
        self.assertEqual(types[5], 'datetime64[ns]')

    # Only runs when pandas is missing: toPandas() must raise a descriptive ImportError.
    @unittest.skipIf(have_pandas, "Required Pandas was found.")
    def test_to_pandas_required_pandas_not_found(self):
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
                self._to_pandas()

    # Integer columns containing nulls must come back as float64, not int32.
    @unittest.skipIf(not have_pandas, pandas_requirement_message)
    def test_to_pandas_avoid_astype(self):
        import numpy as np
        schema = StructType().add("a", IntegerType()).add("b", StringType())\
                             .add("c", IntegerType())
        data = [(1, "foo", 16777220), (None, "bar", None)]
        df = self.spark.createDataFrame(data, schema)
        types = df.toPandas().dtypes
        self.assertEqual(types[0], np.float64)  # doesn't convert to np.int32 due to NaN value.
        self.assertEqual(types[1], np.object_)
        self.assertEqual(types[2], np.float64)

    # array.array('l') values at the 64-bit signed extremes must survive the round trip.
    def test_create_dataframe_from_array_of_long(self):
        import array
        data = [Row(longarray=array.array('l', [-9223372036854775808, 0, 9223372036854775807]))]
        df = self.spark.createDataFrame(data)
        self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807]))

    # Date and timestamp columns of a pandas DataFrame must map to DateType / TimestampType.
    @unittest.skipIf(not have_pandas, pandas_requirement_message)
    def test_create_dataframe_from_pandas_with_timestamp(self):
        import pandas as pd
        from datetime import datetime
        pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
                            "d": [pd.Timestamp.now().date()]}, columns=["d", "ts"])
        # test types are inferred correctly without specifying schema
        df = self.spark.createDataFrame(pdf)
        self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
        self.assertTrue(isinstance(df.schema['d'].dataType, DateType))
        # test with schema will accept pdf as input
        df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp")
        self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
        self.assertTrue(isinstance(df.schema['d'].dataType, DateType))

    # Only runs when pandas is missing: the ImportError may come either from the
    # `import pandas` statement below or from createDataFrame(), hence the two patterns.
    @unittest.skipIf(have_pandas, "Required Pandas was found.")
    def test_create_dataframe_required_pandas_not_found(self):
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(
                    ImportError,
                    "(Pandas >= .* must be installed|No module named '?pandas'?)"):
                import pandas as pd
                from datetime import datetime
                pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
                                    "d": [pd.Timestamp.now().date()]})
                self.spark.createDataFrame(pdf)

    # Regression test for SPARK-23360
    @unittest.skipIf(not have_pandas, pandas_requirement_message)
    def test_create_dateframe_from_pandas_with_dst(self):
        import pandas as pd
        from datetime import datetime

        pdf = pd.DataFrame({'time': [datetime(2015, 10, 31, 22, 30)]})

        df = self.spark.createDataFrame(pdf)
        self.assertPandasEqual(pdf, df.toPandas())

        # Remember the process time zone so it can be restored in `finally`.
        orig_env_tz = os.environ.get('TZ', None)
        try:
            tz = 'America/Los_Angeles'
            os.environ['TZ'] = tz
            time.tzset()
            with self.sql_conf({'spark.sql.session.timeZone': tz}):
                df = self.spark.createDataFrame(pdf)
                self.assertPandasEqual(pdf, df.toPandas())
        finally:
            del os.environ['TZ']
            if orig_env_tz is not None:
                os.environ['TZ'] = orig_env_tz
            time.tzset()

    # __repr__ and _repr_html_ under the spark.sql.repl.eagerEval.* settings.
    def test_repr_behaviors(self):
        import re
        # Strips the leading indentation and '|' margin marker from each line of the
        # expected literals below.
        pattern = re.compile(r'^ *\|', re.MULTILINE)
        df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value"))

        # test when eager evaluation is enabled and _repr_html_ will not be called
        with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
            expected1 = """+-----+-----+
                ||  key|value|
                |+-----+-----+
                ||    1|    1|
                ||22222|22222|
                |+-----+-----+
                |"""
            self.assertEqual(re.sub(pattern, '', expected1), df.__repr__())
            with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
                expected2 = """+---+-----+
                ||key|value|
                |+---+-----+
                ||  1|    1|
                ||222|  222|
                |+---+-----+
                |"""
                self.assertEqual(re.sub(pattern, '', expected2), df.__repr__())
                with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
                    expected3 = """+---+-----+
                    ||key|value|
                    |+---+-----+
                    ||  1|    1|
                    |+---+-----+
                    |only showing top 1 row
                    |"""
                    self.assertEqual(re.sub(pattern, '', expected3), df.__repr__())

        # test when eager evaluation is enabled and _repr_html_ will be called
        with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
            expected1 = """<table border='1'>
                |<tr><th>key</th><th>value</th></tr>
                |<tr><td>1</td><td>1</td></tr>
                |<tr><td>22222</td><td>22222</td></tr>
                |</table>
                |"""
            self.assertEqual(re.sub(pattern, '', expected1), df._repr_html_())
            with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
                expected2 = """<table border='1'>
                |<tr><th>key</th><th>value</th></tr>
                |<tr><td>1</td><td>1</td></tr>
                |<tr><td>222</td><td>222</td></tr>
                |</table>
                |"""
                self.assertEqual(re.sub(pattern, '', expected2), df._repr_html_())
                with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
                    expected3 = """<table border='1'>
                    |<tr><th>key</th><th>value</th></tr>
                    |<tr><td>1</td><td>1</td></tr>
                    |</table>
                    |only showing top 1 row
                    |"""
                    self.assertEqual(re.sub(pattern, '', expected3), df._repr_html_())

        # test when eager evaluation is disabled and _repr_html_ will be called
        with self.sql_conf({"spark.sql.repl.eagerEval.enabled": False}):
            expected = "DataFrame[key: bigint, value: string]"
            self.assertEqual(None, df._repr_html_())
            self.assertEqual(expected, df.__repr__())
            with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
                self.assertEqual(None, df._repr_html_())
                self.assertEqual(expected, df.__repr__())
                with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
                    self.assertEqual(None, df._repr_html_())
                    self.assertEqual(expected, df.__repr__())


class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
    """Tests that the registered query execution listener is notified by actions."""
    # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is
    # static and immutable. This can't be set or unset, for example, via `spark.conf`.

    @classmethod
    def setUpClass(cls):
        import glob
        from pyspark.find_spark_home import _find_spark_home

        SPARK_HOME = _find_spark_home()
        filename_pattern = (
            "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
            "TestQueryExecutionListener.class")
        # The listener is a compiled test class; only start a session when it was built.
        cls.has_listener = bool(glob.glob(os.path.join(SPARK_HOME, filename_pattern)))

        if cls.has_listener:
            # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration.
            cls.spark = SparkSession.builder \
                .master("local[4]") \
                .appName(cls.__name__) \
                .config(
                    "spark.sql.queryExecutionListeners",
                    "org.apache.spark.sql.TestQueryExecutionListener") \
                .getOrCreate()

    def setUp(self):
        if not self.has_listener:
            # skipTest() raises SkipTest itself, so no explicit `raise` is needed.
            self.skipTest(
                "'org.apache.spark.sql.TestQueryExecutionListener' is not "
                "available. Will skip the related tests.")

    @classmethod
    def tearDownClass(cls):
        # `spark` is only set when the listener class was found in setUpClass.
        if hasattr(cls, "spark"):
            cls.spark.stop()

    def tearDown(self):
        # Reset the recorded callback state so each test starts from "not called".
        self.spark._jvm.OnSuccessCall.clear()

    def test_query_execution_listener_on_collect(self):
        self.assertFalse(
            self.spark._jvm.OnSuccessCall.isCalled(),
            "The callback from the query execution listener should not be called before 'collect'")
        self.spark.sql("SELECT * FROM range(1)").collect()
        self.assertTrue(
            self.spark._jvm.OnSuccessCall.isCalled(),
            "The callback from the query execution listener should be called after 'collect'")

    @unittest.skipIf(
        not have_pandas or not have_pyarrow,
        pandas_requirement_message or pyarrow_requirement_message)
    def test_query_execution_listener_on_collect_with_arrow(self):
        with self.sql_conf({"spark.sql.execution.arrow.enabled": True}):
            self.assertFalse(
                self.spark._jvm.OnSuccessCall.isCalled(),
                "The callback from the query execution listener should not be "
                "called before 'toPandas'")
            self.spark.sql("SELECT * FROM range(1)").toPandas()
            self.assertTrue(
                self.spark._jvm.OnSuccessCall.isCalled(),
                "The callback from the query execution listener should be called after 'toPandas'")


if __name__ == "__main__":
    from pyspark.sql.tests.test_dataframe import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)