diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py
index 9da983667b..f99161c881 100644
--- a/python/pyspark/ml/linalg/__init__.py
+++ b/python/pyspark/ml/linalg/__init__.py
@@ -980,14 +980,19 @@ class DenseMatrix(Matrix):
         return self.values[i + j * self.numRows]
 
     def __eq__(self, other):
-        if (not isinstance(other, DenseMatrix) or
-                self.numRows != other.numRows or
-                self.numCols != other.numCols):
+        # Only another Matrix can compare equal. Checking the type first keeps
+        # ``dm == None`` (or any non-matrix) returning False instead of raising
+        # AttributeError on the numRows/numCols lookups below.
+        if not isinstance(other, Matrix):
+            return False
+        if (self.numRows != other.numRows or self.numCols != other.numCols):
             return False
+        if isinstance(other, SparseMatrix):
+            return np.all(self.toArray() == other.toArray())
 
         self_values = np.ravel(self.toArray(), order='F')
         other_values = np.ravel(other.toArray(), order='F')
-        return all(self_values == other_values)
+        return np.all(self_values == other_values)
 
 
 class SparseMatrix(Matrix):
diff --git a/python/pyspark/ml/tests/test_linalg.py b/python/pyspark/ml/tests/test_linalg.py
index 995bc35e4c..0c25e2b818 100644
--- a/python/pyspark/ml/tests/test_linalg.py
+++ b/python/pyspark/ml/tests/test_linalg.py
@@ -112,11 +112,19 @@ class VectorTests(MLlibTestCase):
         v4 = SparseVector(6, [(1, 1.0), (3, 5.5)])
         v5 = DenseVector([0.0, 1.0, 0.0, 2.5])
         v6 = SparseVector(4, [(1, 1.0), (3, 2.5)])
+        dm1 = DenseMatrix(2, 2, [2, 0, 0, 0])
+        sm1 = SparseMatrix(2, 2, [0, 2, 3], [0], [2])
         self.assertEqual(v1, v2)
         self.assertEqual(v1, v3)
         self.assertFalse(v2 == v4)
         self.assertFalse(v1 == v5)
         self.assertFalse(v1 == v6)
+        # this is done as Dense and Sparse matrices can be semantically
+        # equal while still implementing a different __eq__ method
+        self.assertEqual(dm1, sm1)
+        self.assertEqual(sm1, dm1)
+        # a non-matrix operand must compare unequal instead of raising
+        self.assertFalse(dm1 == "not a matrix")
 
     def test_equals(self):
         indices = [1, 2, 4]
diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py
index 94a3e2af4d..df411d7990 100644
--- a/python/pyspark/mllib/linalg/__init__.py
+++ b/python/pyspark/mllib/linalg/__init__.py
@@ -1135,14 +1135,19 @@ class DenseMatrix(Matrix):
         return self.values[i + j * self.numRows]
 
     def __eq__(self, other):
-        if (not isinstance(other, DenseMatrix) or
-                self.numRows != other.numRows or
-                self.numCols != other.numCols):
+        # Only another Matrix can compare equal. Checking the type first keeps
+        # ``dm == None`` (or any non-matrix) returning False instead of raising
+        # AttributeError on the numRows/numCols lookups below.
+        if not isinstance(other, Matrix):
+            return False
+        if (self.numRows != other.numRows or self.numCols != other.numCols):
             return False
+        if isinstance(other, SparseMatrix):
+            return np.all(self.toArray() == other.toArray())
 
         self_values = np.ravel(self.toArray(), order='F')
         other_values = np.ravel(other.toArray(), order='F')
-        return all(self_values == other_values)
+        return np.all(self_values == other_values)
 
 
 class SparseMatrix(Matrix):
diff --git a/python/pyspark/mllib/tests/test_linalg.py b/python/pyspark/mllib/tests/test_linalg.py
index f26e28d174..703aed2fe1 100644
--- a/python/pyspark/mllib/tests/test_linalg.py
+++ b/python/pyspark/mllib/tests/test_linalg.py
@@ -115,11 +115,19 @@ class VectorTests(MLlibTestCase):
         v4 = SparseVector(6, [(1, 1.0), (3, 5.5)])
         v5 = DenseVector([0.0, 1.0, 0.0, 2.5])
         v6 = SparseVector(4, [(1, 1.0), (3, 2.5)])
+        dm1 = DenseMatrix(2, 2, [2, 0, 0, 0])
+        sm1 = SparseMatrix(2, 2, [0, 2, 3], [0], [2])
         self.assertEqual(v1, v2)
         self.assertEqual(v1, v3)
         self.assertFalse(v2 == v4)
         self.assertFalse(v1 == v5)
         self.assertFalse(v1 == v6)
+        # this is done as Dense and Sparse matrices can be semantically
+        # equal while still implementing a different __eq__ method
+        self.assertEqual(dm1, sm1)
+        self.assertEqual(sm1, dm1)
+        # a non-matrix operand must compare unequal instead of raising
+        self.assertFalse(dm1 == "not a matrix")
 
     def test_equals(self):
         indices = [1, 2, 4]