#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import numpy as np

from pyspark.ml import Estimator, Model, Transformer, UnaryTransformer
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.ml.wrapper import _java2py
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import DoubleType
from pyspark.testing.utils import ReusedPySparkTestCase as PySparkTestCase


def check_params(test_self, py_stage, check_params_exist=True):
    """
    Checks common requirements for Params.params:
      - set of params exist in Java and Python and are ordered by names
      - param parent has the same UID as the object's UID
      - default param value from Java matches value in Python
      - optionally check if all params from Java also exist in Python
    """
    stage_desc = "%s %s" % (type(py_stage), py_stage)
    # Pure-Python stages have no Java twin; nothing to compare against.
    if not hasattr(py_stage, "_to_java"):
        return
    java_stage = py_stage._to_java()
    if java_stage is None:
        return

    test_self.assertEqual(py_stage.uid, java_stage.uid(), msg=stage_desc)

    if check_params_exist:
        py_names = [param.name for param in py_stage.params]
        java_names = [jparam.name() for jparam in list(java_stage.params())]
        # Python params are kept sorted by name, so compare against the
        # sorted Java list.
        test_self.assertEqual(
            py_names, sorted(java_names),
            "Param list in Python does not match Java for %s:\nJava = %s\nPython = %s"
            % (stage_desc, java_names, py_names))

    for param in py_stage.params:
        test_self.assertEqual(param.parent, py_stage.uid)
        java_param = java_stage.getParam(param.name)
        py_has_default = py_stage.hasDefault(param)
        java_has_default = java_stage.hasDefault(java_param)
        test_self.assertEqual(
            py_has_default, java_has_default,
            "Default value mismatch of param %s for Params %s"
            % (param.name, str(py_stage)))
        if not py_has_default:
            continue
        if param.name == "seed":
            continue  # Random seeds between Spark and PySpark are different
        java_default = _java2py(
            test_self.sc, java_stage.clear(java_param).getOrDefault(java_param))
        py_stage.clear(param)
        py_default = py_stage.getOrDefault(param)
        # equality test for NaN is always False, so map NaNs onto a
        # comparable sentinel string first
        if isinstance(java_default, float) and np.isnan(java_default):
            java_default = "NaN"
            py_default = "NaN" if np.isnan(py_default) else "not NaN"
        test_self.assertEqual(
            java_default, py_default,
            "Java default %s != python default %s of param %s for Params %s"
            % (str(java_default), str(py_default), param.name, str(py_stage)))


class SparkSessionTestCase(PySparkTestCase):
    """Test case that exposes a SparkSession built on the shared SparkContext."""

    @classmethod
    def setUpClass(cls):
        PySparkTestCase.setUpClass()
        cls.spark = SparkSession(cls.sc)

    @classmethod
    def tearDownClass(cls):
        PySparkTestCase.tearDownClass()
        cls.spark.stop()


class MockDataset(DataFrame):
    """Minimal DataFrame stand-in that only carries a transform counter."""

    def __init__(self):
        # Deliberately does not call DataFrame.__init__; the mock
        # transformers below only ever touch ``index``.
        self.index = 0


class HasFake(Params):
    """Params mixin declaring a single dummy param named "fake"."""

    def __init__(self):
        super(HasFake, self).__init__()
        self.fake = Param(self, "fake", "fake param")

    def getFake(self):
        return self.getOrDefault(self.fake)


class MockTransformer(Transformer, HasFake):
    """Transformer mock that records and bumps the dataset's counter."""

    def __init__(self):
        super(MockTransformer, self).__init__()
        self.dataset_index = None

    def _transform(self, dataset):
        self.dataset_index = dataset.index
        dataset.index += 1
        return dataset


class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable):
    """UnaryTransformer mock that shifts double values by a configurable amount."""

    shift = Param(
        Params._dummy(),
        "shift",
        "The amount by which to shift data in a DataFrame",
        typeConverter=TypeConverters.toFloat)

    def __init__(self, shiftVal=1):
        super(MockUnaryTransformer, self).__init__()
        self._setDefault(shift=1)
        self._set(shift=shiftVal)

    def getShift(self):
        return self.getOrDefault(self.shift)

    def setShift(self, shift):
        self._set(shift=shift)

    def createTransformFunc(self):
        # Capture the current shift once so the returned closure is stable
        # even if the param is changed afterwards.
        amount = self.getShift()
        return lambda value: value + amount

    def outputDataType(self):
        return DoubleType()

    def validateInputType(self, inputType):
        if inputType != DoubleType():
            raise TypeError("Bad input type: {}. ".format(inputType) +
                            "Requires Double.")


class MockEstimator(Estimator, HasFake):
    """Estimator mock that records the dataset counter and yields a MockModel."""

    def __init__(self):
        super(MockEstimator, self).__init__()
        self.dataset_index = None

    def _fit(self, dataset):
        self.dataset_index = dataset.index
        model = MockModel()
        self._copyValues(model)
        return model


class MockModel(MockTransformer, Model, HasFake):
    pass