#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.sql import Row
from pyspark.testing.sqlutils import ReusedSQLTestCase


class GroupTests(ReusedSQLTestCase):
    """Tests for aggregate functions applied to a DataFrame's ``groupBy()`` result."""

    def test_aggregator(self):
        """Check dict-style, mean, first/last and distinct-count aggregates on ``self.df``.

        NOTE(review): the expected values presumably reflect a fixture of 100 rows
        with integer keys 0..99 and stringified values; ``self.df`` is not defined
        in this file -- confirm against the base test case.
        """
        frame = self.df
        grouped = frame.groupBy()  # no grouping columns are passed

        # Dict form of agg(): column name -> aggregate function name.
        dict_agg_row = grouped.agg({'key': 'max', 'value': 'count'}).collect()[0]
        self.assertEqual([99, 100], sorted(dict_agg_row))

        mean_rows = grouped.mean().collect()
        self.assertEqual([Row(**{"AVG(key#0)": 49.5})], mean_rows)

        from pyspark.sql import functions

        first_last_row = grouped.agg(
            functions.first(frame.key), functions.last(frame.value)
        ).first()
        self.assertEqual((0, u'99'), tuple(first_last_row))

        # The approximate distinct count only has to exceed a lower bound.
        approx_distinct_keys = grouped.agg(
            functions.approx_count_distinct(frame.key)
        ).first()[0]
        self.assertTrue(95 < approx_distinct_keys)

        distinct_values = grouped.agg(functions.countDistinct(frame.value)).first()[0]
        self.assertEqual(100, distinct_values)


if __name__ == "__main__":
    import unittest
    # Pull the test classes from pyspark.sql.tests.test_group into this namespace.
    from pyspark.sql.tests.test_group import *  # noqa: F401

    # Use the XML report runner when xmlrunner is installed; otherwise pass
    # testRunner=None to unittest.main.
    try:
        import xmlrunner  # type: ignore[import]
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)