[SPARK-15008][ML][PYSPARK] Add integration test for OneVsRest
## What changes were proposed in this pull request? 1. Add `_transfer_param_map_to/from_java` for OneVsRest; 2. Add `_compare_params` in ml/tests.py to help compare params; 3. Add `test_onevsrest` as the integration test for OneVsRest. ## How was this patch tested? Tested with Python unit tests. Author: yinxusen <yinxusen@gmail.com> Closes #12875 from yinxusen/SPARK-15008.
This commit is contained in:
parent
a3550e3747
commit
130b8d07b8
|
@ -747,12 +747,32 @@ class PersistenceTest(SparkSessionTestCase):
|
|||
except OSError:
|
||||
pass
|
||||
|
||||
def _compare_params(self, m1, m2, param):
    """
    Check that two ML Params instances agree on the given param: same
    value and same parent. The param must be a parameter of m1.
    """
    matching = m2.getParam(param.name)
    # Guard first against a key-not-found error: the param may live in
    # neither the paramMap nor the defaultParamMap of m1.
    if not m1.isDefined(param):
        # If m1 leaves the param undefined, m2 must too. See SPARK-14931.
        self.assertFalse(m2.isDefined(matching))
        return
    value1 = m1.getOrDefault(param)
    value2 = m2.getOrDefault(matching)
    if isinstance(value1, Params):
        # Nested Params values (e.g. a classifier param) are compared
        # recursively as whole pipelines.
        self._compare_pipelines(value1, value2)
    else:
        # General (non-Params) values compare directly.
        self.assertEqual(value1, value2)
    # The param parents must line up as well.
    self.assertEqual(param.parent, matching.parent)
|
||||
|
||||
def _compare_pipelines(self, m1, m2):
|
||||
"""
|
||||
Compare 2 ML types, asserting that they are equivalent.
|
||||
This currently supports:
|
||||
- basic types
|
||||
- Pipeline, PipelineModel
|
||||
- OneVsRest, OneVsRestModel
|
||||
This checks:
|
||||
- uid
|
||||
- type
|
||||
|
@ -763,8 +783,7 @@ class PersistenceTest(SparkSessionTestCase):
|
|||
if isinstance(m1, JavaParams):
|
||||
self.assertEqual(len(m1.params), len(m2.params))
|
||||
for p in m1.params:
|
||||
self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p))
|
||||
self.assertEqual(p.parent, m2.getParam(p.name).parent)
|
||||
self._compare_params(m1, m2, p)
|
||||
elif isinstance(m1, Pipeline):
|
||||
self.assertEqual(len(m1.getStages()), len(m2.getStages()))
|
||||
for s1, s2 in zip(m1.getStages(), m2.getStages()):
|
||||
|
@ -773,6 +792,13 @@ class PersistenceTest(SparkSessionTestCase):
|
|||
self.assertEqual(len(m1.stages), len(m2.stages))
|
||||
for s1, s2 in zip(m1.stages, m2.stages):
|
||||
self._compare_pipelines(s1, s2)
|
||||
elif isinstance(m1, OneVsRest) or isinstance(m1, OneVsRestModel):
|
||||
for p in m1.params:
|
||||
self._compare_params(m1, m2, p)
|
||||
if isinstance(m1, OneVsRestModel):
|
||||
self.assertEqual(len(m1.models), len(m2.models))
|
||||
for x, y in zip(m1.models, m2.models):
|
||||
self._compare_pipelines(x, y)
|
||||
else:
|
||||
raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1))
|
||||
|
||||
|
@ -833,6 +859,24 @@ class PersistenceTest(SparkSessionTestCase):
|
|||
except OSError:
|
||||
pass
|
||||
|
||||
def test_onevsrest(self):
    """
    Round-trip an OneVsRest estimator and its fitted model through
    save/load and assert the loaded copies are equivalent.
    """
    temp_path = tempfile.mkdtemp()
    # Three classes, repeated to give the optimizer something to fit.
    rows = [(0.0, Vectors.dense(1.0, 0.8)),
            (1.0, Vectors.sparse(2, [], [])),
            (2.0, Vectors.dense(0.5, 0.5))] * 10
    df = self.spark.createDataFrame(rows, ["label", "features"])
    base_classifier = LogisticRegression(maxIter=5, regParam=0.01)
    ovr = OneVsRest(classifier=base_classifier)
    model = ovr.fit(df)
    # Estimator round trip.
    estimator_path = temp_path + "/ovr"
    ovr.save(estimator_path)
    self._compare_pipelines(ovr, OneVsRest.load(estimator_path))
    # Model round trip.
    model_path = temp_path + "/ovrModel"
    model.save(model_path)
    self._compare_pipelines(model, OneVsRestModel.load(model_path))
|
||||
|
||||
def test_decisiontree_classifier(self):
|
||||
dt = DecisionTreeClassifier(maxDepth=1)
|
||||
path = tempfile.mkdtemp()
|
||||
|
@ -1054,27 +1098,6 @@ class OneVsRestTests(SparkSessionTestCase):
|
|||
output = model.transform(df)
|
||||
self.assertEqual(output.columns, ["label", "features", "prediction"])
|
||||
|
||||
def test_save_load(self):
    """
    Persist an OneVsRest estimator and its fitted model, then verify the
    key params and the per-class sub-models survive the round trip.
    """
    temp_path = tempfile.mkdtemp()
    data = [(0.0, Vectors.dense(1.0, 0.8)),
            (1.0, Vectors.sparse(2, [], [])),
            (2.0, Vectors.dense(0.5, 0.5))]
    df = self.spark.createDataFrame(data, ["label", "features"])
    classifier = LogisticRegression(maxIter=5, regParam=0.01)
    ovr = OneVsRest(classifier=classifier)
    model = ovr.fit(df)
    # Estimator: save, reload, and compare the user-visible params.
    estimator_path = temp_path + "/ovr"
    ovr.save(estimator_path)
    restored = OneVsRest.load(estimator_path)
    self.assertEqual(restored.getFeaturesCol(), ovr.getFeaturesCol())
    self.assertEqual(restored.getLabelCol(), ovr.getLabelCol())
    self.assertEqual(restored.getClassifier().uid, ovr.getClassifier().uid)
    # Model: save, reload, and check each binary sub-model by uid.
    model_path = temp_path + "/ovrModel"
    model.save(model_path)
    restored_model = OneVsRestModel.load(model_path)
    for original, loaded in zip(model.models, restored_model.models):
        self.assertEqual(original.uid, loaded.uid)
|
||||
|
||||
|
||||
class HashingTFTest(SparkSessionTestCase):
|
||||
|
||||
|
|
Loading…
Reference in a new issue