2018-11-18 03:02:15 -05:00
|
|
|
#
|
|
|
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
|
|
|
# contributor license agreements. See the NOTICE file distributed with
|
|
|
|
# this work for additional information regarding copyright ownership.
|
|
|
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
|
|
|
# (the "License"); you may not use this file except in compliance with
|
|
|
|
# the License. You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
#
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
2020-12-03 19:35:50 -05:00
|
|
|
from pyspark import keyword_only
|
2018-11-18 03:02:15 -05:00
|
|
|
from pyspark.ml import Estimator, Model, Transformer, UnaryTransformer
|
2020-12-03 19:35:50 -05:00
|
|
|
from pyspark.ml.evaluation import Evaluator
|
2018-11-18 03:02:15 -05:00
|
|
|
from pyspark.ml.param import Param, Params, TypeConverters
|
2020-12-03 19:35:50 -05:00
|
|
|
from pyspark.ml.param.shared import HasMaxIter, HasRegParam
|
|
|
|
from pyspark.ml.classification import Classifier, ClassificationModel
|
2018-11-18 03:02:15 -05:00
|
|
|
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
|
2020-09-24 01:15:36 -04:00
|
|
|
from pyspark.ml.wrapper import _java2py # type: ignore
|
2018-11-18 03:02:15 -05:00
|
|
|
from pyspark.sql import DataFrame, SparkSession
|
|
|
|
from pyspark.sql.types import DoubleType
|
|
|
|
from pyspark.testing.utils import ReusedPySparkTestCase as PySparkTestCase
|
|
|
|
|
|
|
|
|
|
|
|
def check_params(test_self, py_stage, check_params_exist=True):
    """
    Checks common requirements for :py:attr:`pyspark.ml.param.Params.params`:

      - set of params exist in Java and Python and are ordered by names
      - param parent has the same UID as the object's UID
      - default param value from Java matches value in Python
      - optionally check if all params from Java also exist in Python

    Parameters
    ----------
    test_self : test case
        Object providing ``assertEqual`` and a ``sc`` attribute (passed to ``_java2py``).
    py_stage : object
        Python stage to check. Nothing is checked when it has no ``_to_java``
        attribute or when ``_to_java()`` returns ``None``.
    check_params_exist : bool, optional
        When true, also assert that the Python param names equal the sorted Java
        param names.
    """
    py_stage_str = "%s %s" % (type(py_stage), py_stage)
    if not hasattr(py_stage, "_to_java"):
        return
    java_stage = py_stage._to_java()
    if java_stage is None:
        return
    test_self.assertEqual(py_stage.uid, java_stage.uid(), msg=py_stage_str)
    if check_params_exist:
        param_names = [p.name for p in py_stage.params]
        java_param_names = [jp.name() for jp in java_stage.params()]
        test_self.assertEqual(
            param_names, sorted(java_param_names),
            "Param list in Python does not match Java for %s:\nJava = %s\nPython = %s"
            % (py_stage_str, java_param_names, param_names))
    for p in py_stage.params:
        test_self.assertEqual(p.parent, py_stage.uid)
        java_param = java_stage.getParam(p.name)
        py_has_default = py_stage.hasDefault(p)
        java_has_default = java_stage.hasDefault(java_param)
        test_self.assertEqual(py_has_default, java_has_default,
                              "Default value mismatch of param %s for Params %s"
                              % (p.name, str(py_stage)))
        if not py_has_default:
            continue
        if p.name == "seed":
            continue  # Random seeds between Spark and PySpark are different
        # Clear any explicitly set value on both sides so getOrDefault yields the default.
        java_default = _java2py(test_self.sc,
                                java_stage.clear(java_param).getOrDefault(java_param))
        py_stage.clear(p)
        py_default = py_stage.getOrDefault(p)
        # equality test for NaN is always False
        if isinstance(java_default, float) and np.isnan(java_default):
            java_default = "NaN"
            # Guard the Python side the same way as the Java side: np.isnan raises
            # TypeError on non-numeric values, which would hide the real mismatch.
            py_is_nan = isinstance(py_default, (float, np.floating)) and np.isnan(py_default)
            py_default = "NaN" if py_is_nan else "not NaN"
        test_self.assertEqual(
            java_default, py_default,
            "Java default %s != python default %s of param %s for Params %s"
            % (str(java_default), str(py_default), p.name, str(py_stage)))
|
|
|
|
|
|
|
|
|
|
|
|
class SparkSessionTestCase(PySparkTestCase):
    """Test case base that exposes a ``SparkSession`` as ``cls.spark`` next to ``cls.sc``."""

    @classmethod
    def setUpClass(cls):
        """Run the parent class setup, then wrap ``cls.sc`` in a ``SparkSession``."""
        PySparkTestCase.setUpClass()
        session = SparkSession(cls.sc)
        cls.spark = session

    @classmethod
    def tearDownClass(cls):
        """Run the parent class teardown, then stop the session made in ``setUpClass``."""
        PySparkTestCase.tearDownClass()
        session = cls.spark
        session.stop()
|
|
|
|
|
|
|
|
|
|
|
|
class MockDataset(DataFrame):
    """``DataFrame`` subclass that only carries an integer ``index`` attribute.

    ``DataFrame.__init__`` is intentionally not called here; the instance is
    initialised with nothing but ``index``.
    """

    def __init__(self):
        # Counter read by MockEstimator._fit and read/incremented by
        # MockTransformer._transform.
        self.index = 0
|
|
|
|
|
|
|
|
|
|
|
|
class HasFake(Params):
    """``Params`` mixin that declares a single param named ``fake``."""

    def __init__(self):
        super().__init__()
        # Instance-level param whose parent is this object.
        self.fake = Param(self, "fake", "fake param")

    def getFake(self):
        """Return the value of ``fake`` as resolved by ``getOrDefault``."""
        fake_param = self.fake
        return self.getOrDefault(fake_param)
|
|
|
|
|
|
|
|
|
|
|
|
class MockTransformer(Transformer, HasFake):
    """Transformer that records the ``index`` of the dataset it saw, then bumps it."""

    def __init__(self):
        super().__init__()
        # Value of ``dataset.index`` observed by the latest ``_transform`` call.
        self.dataset_index = None

    def _transform(self, dataset):
        """Store ``dataset.index``, increment it in place and return the same dataset."""
        observed = dataset.index
        self.dataset_index = observed
        dataset.index += 1
        return dataset
|
|
|
|
|
|
|
|
|
|
|
|
class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable):
    """Unary transformer that adds the ``shift`` param value to each input value.

    Input columns must be ``DoubleType``; the output type is ``DoubleType``.
    """

    shift = Param(Params._dummy(),  # type: ignore
                  "shift", "The amount by which to shift " +
                  "data in a DataFrame",
                  typeConverter=TypeConverters.toFloat)

    def __init__(self, shiftVal=1):
        """Register a default ``shift`` of 1 and explicitly set it to ``shiftVal``."""
        super(MockUnaryTransformer, self).__init__()
        self._setDefault(shift=1)
        self._set(shift=shiftVal)

    def getShift(self):
        """Return the value of ``shift`` or its default."""
        return self.getOrDefault(self.shift)

    def setShift(self, shift):
        """Set ``shift`` and return the result of ``_set`` so calls can be chained.

        This matches the other setters in this module, which also return
        ``self._set(...)``.
        """
        return self._set(shift=shift)

    def createTransformFunc(self):
        """Return a function adding the current ``shift`` value to its argument."""
        # Read the param once so the returned closure does not reference ``self``.
        shiftVal = self.getShift()
        return lambda x: x + shiftVal

    def outputDataType(self):
        """Return the output column type, always ``DoubleType()``."""
        return DoubleType()

    def validateInputType(self, inputType):
        """Raise ``TypeError`` unless ``inputType`` equals ``DoubleType()``."""
        if inputType != DoubleType():
            raise TypeError("Bad input type: {}. ".format(inputType) +
                            "Requires Double.")
|
|
|
|
|
|
|
|
|
|
|
|
class MockEstimator(Estimator, HasFake):
    """Estimator that records the ``index`` of the dataset it was fitted on."""

    def __init__(self):
        super().__init__()
        # Value of ``dataset.index`` observed by the latest ``_fit`` call.
        self.dataset_index = None

    def _fit(self, dataset):
        """Remember ``dataset.index`` and return a ``MockModel`` carrying this estimator's param values."""
        self.dataset_index = dataset.index
        fitted = MockModel()
        self._copyValues(fitted)
        return fitted
|
|
|
|
|
|
|
|
|
|
|
|
class MockModel(MockTransformer, Model, HasFake):
    """Model produced by ``MockEstimator._fit``; adds nothing beyond its base classes."""
|
2020-12-03 19:35:50 -05:00
|
|
|
|
|
|
|
|
|
|
|
class _DummyLogisticRegressionParams(HasMaxIter, HasRegParam):
    """Setters for ``maxIter`` and ``regParam`` shared by the dummy estimator and model."""

    def setMaxIter(self, value):
        """Set ``maxIter`` to ``value`` and return the result of ``_set``."""
        updated = self._set(maxIter=value)
        return updated

    def setRegParam(self, value):
        """Set ``regParam`` to ``value`` and return the result of ``_set``."""
        updated = self._set(regParam=value)
        return updated
|
|
|
|
|
|
|
|
|
|
|
|
# This is a dummy LogisticRegression used in tests for Python backend estimator/model
class DummyLogisticRegression(Classifier, _DummyLogisticRegressionParams,
                              DefaultParamsReadable, DefaultParamsWritable):
    """Python-only stand-in classifier whose ``_fit`` does no training."""

    @keyword_only
    def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxIter=100, regParam=0.0, rawPredictionCol="rawPrediction"):
        """Create the estimator and forward the captured keyword arguments to ``setParams``."""
        super().__init__()
        # ``_input_kwargs`` is read right after construction; it is expected to be
        # filled in by the ``keyword_only`` decorator.
        captured = self._input_kwargs
        self.setParams(**captured)

    @keyword_only
    def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxIter=100, regParam=0.0, rawPredictionCol="rawPrediction"):
        """Set the params found in ``_input_kwargs`` and return ``self``."""
        captured = self._input_kwargs
        self._set(**captured)
        return self

    def _fit(self, dataset):
        """Ignore ``dataset`` and return a fresh dummy model with this estimator's values copied."""
        # Do nothing but create a dummy model
        model = DummyLogisticRegressionModel()
        return self._copyValues(model)
|
|
|
|
|
|
|
|
|
|
|
|
class DummyLogisticRegressionModel(ClassificationModel, _DummyLogisticRegressionParams,
                                   DefaultParamsReadable, DefaultParamsWritable):
    """Python-only stand-in classification model that emits constant predictions."""

    def __init__(self):
        super().__init__()

    def _transform(self, dataset):
        """Append constant raw-prediction and prediction columns when their names are set.

        The raw prediction is the vector ``[-100.0, 100.0]`` and the prediction is
        ``1.0`` for every row; a column is only added if its configured name is truthy.
        """
        # A dummy transform impl which always predicts label 1
        from pyspark.ml.functions import array_to_vector
        from pyspark.sql.functions import array, lit

        raw_col = self.getRawPredictionCol()
        if raw_col:
            constant_raw = array_to_vector(array(lit(-100.0), lit(100.0)))
            dataset = dataset.withColumn(raw_col, constant_raw)

        pred_col = self.getPredictionCol()
        if pred_col:
            constant_label = lit(1.0)
            dataset = dataset.withColumn(pred_col, constant_label)

        return dataset

    @property
    def numClasses(self):
        """Number of classes; fixed at 2 for this dummy model."""
        # a dummy implementation for test.
        return 2

    @property
    def intercept(self):
        """Intercept; fixed at 0.0 for this dummy model."""
        # a dummy implementation for test.
        return 0.0

    # This class is only used in tests. The following methods/properties are not used in tests.

    @property
    def coefficients(self):
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def predictRaw(self, value):
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def numFeatures(self):
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def predict(self, value):
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError()
|
|
|
|
|
|
|
|
|
|
|
|
class DummyEvaluator(Evaluator, DefaultParamsReadable, DefaultParamsWritable):
    """Evaluator stand-in for tests that reports the same metric for any dataset."""

    def _evaluate(self, dataset):
        """Return the constant metric ``1.0``; ``dataset`` is not inspected."""
        # a dummy implementation for test.
        metric = 1.0
        return metric
|