[SPARK-15008][ML][PYSPARK] Add integration test for OneVsRest

## What changes were proposed in this pull request?

1. Add `_transfer_param_map_to_java` and `_transfer_param_map_from_java` for OneVsRest.

2. Add `_compare_params` in ml/tests.py to help compare params.

3. Add `test_onevsrest` as the integration test for OneVsRest.

## How was this patch tested?

Python unit test.

Author: yinxusen <yinxusen@gmail.com>

Closes #12875 from yinxusen/SPARK-15008.
This commit is contained in:
yinxusen 2016-05-27 13:18:29 -07:00 committed by Joseph K. Bradley
parent a3550e3747
commit 130b8d07b8

View file

@ -747,12 +747,32 @@ class PersistenceTest(SparkSessionTestCase):
except OSError:
pass
def _compare_params(self, m1, m2, param):
    """
    Check that two ML Params instances agree on a single param.

    Asserts that ``m1`` and ``m2`` hold the same value for ``param`` and that
    the param has the same parent on both sides. ``param`` must belong to
    ``m1``; the matching param of ``m2`` is looked up by name.
    """
    if not m1.isDefined(param):
        # The param is in neither paramMap nor defaultParamMap of m1, so
        # reading it would raise a key-not-found error. In that case m2 must
        # leave it undefined as well. See SPARK-14931.
        self.assertFalse(m2.isDefined(m2.getParam(param.name)))
        return

    expected = m1.getOrDefault(param)
    actual = m2.getOrDefault(m2.getParam(param.name))
    if not isinstance(expected, Params):
        # General-typed param values are compared directly.
        self.assertEqual(expected, actual)
    else:
        # A param value that is itself a Params instance is compared recursively.
        self._compare_pipelines(expected, actual)
    # The param must report the same parent on both instances.
    self.assertEqual(param.parent, m2.getParam(param.name).parent)
def _compare_pipelines(self, m1, m2):
"""
Compare 2 ML types, asserting that they are equivalent.
This currently supports:
- basic types
- Pipeline, PipelineModel
- OneVsRest, OneVsRestModel
This checks:
- uid
- type
@ -763,8 +783,7 @@ class PersistenceTest(SparkSessionTestCase):
if isinstance(m1, JavaParams):
self.assertEqual(len(m1.params), len(m2.params))
for p in m1.params:
self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p))
self.assertEqual(p.parent, m2.getParam(p.name).parent)
self._compare_params(m1, m2, p)
elif isinstance(m1, Pipeline):
self.assertEqual(len(m1.getStages()), len(m2.getStages()))
for s1, s2 in zip(m1.getStages(), m2.getStages()):
@ -773,6 +792,13 @@ class PersistenceTest(SparkSessionTestCase):
self.assertEqual(len(m1.stages), len(m2.stages))
for s1, s2 in zip(m1.stages, m2.stages):
self._compare_pipelines(s1, s2)
elif isinstance(m1, OneVsRest) or isinstance(m1, OneVsRestModel):
for p in m1.params:
self._compare_params(m1, m2, p)
if isinstance(m1, OneVsRestModel):
self.assertEqual(len(m1.models), len(m2.models))
for x, y in zip(m1.models, m2.models):
self._compare_pipelines(x, y)
else:
raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1))
@ -833,6 +859,24 @@ class PersistenceTest(SparkSessionTestCase):
except OSError:
pass
def test_onevsrest(self):
    """
    Round-trip a OneVsRest estimator and its fitted model through save/load.

    Fits OneVsRest (wrapping LogisticRegression) on a small three-label
    DataFrame, saves and reloads both the estimator and the fitted model, and
    checks each reloaded object against the original with
    ``_compare_pipelines``. The temporary directory is removed afterwards.
    """
    import shutil

    temp_path = tempfile.mkdtemp()
    try:
        # The three labelled rows are repeated 10 times to build the training data.
        df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
                                         (1.0, Vectors.sparse(2, [], [])),
                                         (2.0, Vectors.dense(0.5, 0.5))] * 10,
                                        ["label", "features"])
        lr = LogisticRegression(maxIter=5, regParam=0.01)
        ovr = OneVsRest(classifier=lr)
        model = ovr.fit(df)

        # Estimator round trip.
        ovrPath = temp_path + "/ovr"
        ovr.save(ovrPath)
        loadedOvr = OneVsRest.load(ovrPath)
        self._compare_pipelines(ovr, loadedOvr)

        # Fitted model round trip.
        modelPath = temp_path + "/ovrModel"
        model.save(modelPath)
        loadedModel = OneVsRestModel.load(modelPath)
        self._compare_pipelines(model, loadedModel)
    finally:
        # mkdtemp() leaves deletion to the caller; clean up even when an
        # assertion fails so repeated runs do not accumulate directories.
        try:
            shutil.rmtree(temp_path)
        except OSError:
            pass
def test_decisiontree_classifier(self):
dt = DecisionTreeClassifier(maxDepth=1)
path = tempfile.mkdtemp()
@ -1054,27 +1098,6 @@ class OneVsRestTests(SparkSessionTestCase):
output = model.transform(df)
self.assertEqual(output.columns, ["label", "features", "prediction"])
def test_save_load(self):
    """
    Save and reload a OneVsRest estimator and its fitted model.

    Checks that the reloaded estimator keeps its features column, label
    column and classifier uid, and that the reloaded model holds the same
    number of sub-models with matching uids. The temporary directory is
    removed afterwards.
    """
    import shutil

    temp_path = tempfile.mkdtemp()
    try:
        df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
                                         (1.0, Vectors.sparse(2, [], [])),
                                         (2.0, Vectors.dense(0.5, 0.5))],
                                        ["label", "features"])
        lr = LogisticRegression(maxIter=5, regParam=0.01)
        ovr = OneVsRest(classifier=lr)
        model = ovr.fit(df)

        # Estimator round trip.
        ovrPath = temp_path + "/ovr"
        ovr.save(ovrPath)
        loadedOvr = OneVsRest.load(ovrPath)
        self.assertEqual(loadedOvr.getFeaturesCol(), ovr.getFeaturesCol())
        self.assertEqual(loadedOvr.getLabelCol(), ovr.getLabelCol())
        self.assertEqual(loadedOvr.getClassifier().uid, ovr.getClassifier().uid)

        # Fitted model round trip.
        modelPath = temp_path + "/ovrModel"
        model.save(modelPath)
        loadedModel = OneVsRestModel.load(modelPath)
        # zip() stops at the shorter sequence, so compare the lengths first;
        # otherwise missing sub-models in the reloaded model would go unnoticed.
        self.assertEqual(len(model.models), len(loadedModel.models))
        for m, n in zip(model.models, loadedModel.models):
            self.assertEqual(m.uid, n.uid)
    finally:
        # mkdtemp() leaves deletion to the caller; clean up even when an
        # assertion fails so repeated runs do not accumulate directories.
        try:
            shutil.rmtree(temp_path)
        except OSError:
            pass
class HashingTFTest(SparkSessionTestCase):