Revert "[SPARK-12006][ML][PYTHON] Fix GMM failure if initialModel is not None"
This reverts commit fcd013cf70
.
Author: Yin Huai <yhuai@databricks.com>
Closes #10632 from yhuai/pythonStyle.
This commit is contained in:
parent
b673852037
commit
e5cde7ab11
|
@ -346,7 +346,7 @@ class GaussianMixture(object):
|
|||
if initialModel.k != k:
|
||||
raise Exception("Mismatched cluster count, initialModel.k = %s, however k = %s"
|
||||
% (initialModel.k, k))
|
||||
initialModelWeights = list(initialModel.weights)
|
||||
initialModelWeights = initialModel.weights
|
||||
initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)]
|
||||
initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)]
|
||||
java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector),
|
||||
|
|
|
@ -475,18 +475,6 @@ class ListTests(MLlibTestCase):
|
|||
for c1, c2 in zip(clusters1.weights, clusters2.weights):
|
||||
self.assertEqual(round(c1, 7), round(c2, 7))
|
||||
|
||||
def test_gmm_with_initial_model(self):
|
||||
from pyspark.mllib.clustering import GaussianMixture
|
||||
data = self.sc.parallelize([
|
||||
(-10, -5), (-9, -4), (10, 5), (9, 4)
|
||||
])
|
||||
|
||||
gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001,
|
||||
maxIterations=10, seed=63)
|
||||
gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001,
|
||||
maxIterations=10, seed=63, initialModel=gmm1)
|
||||
self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0)
|
||||
|
||||
def test_classification(self):
|
||||
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
|
||||
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\
|
||||
|
|
Loading…
Reference in a new issue