[SPARK-9792] Make DenseMatrix equality semantical

Before this change, comparing a `DenseMatrix` with a `SparseMatrix` gave asymmetric results:

```
A = SparseMatrix(2, 2, [0, 2, 3], [0], [2])
B = DenseMatrix(2, 2, [2, 0, 0, 0])

B == A  # False
A == B  # True
```

The second comparison is `True` because `SparseMatrix` already checks for
semantic equality. This commit changes `DenseMatrix` so that its equality
check is semantic as well.

## What changes were proposed in this pull request?

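`DenseMatrix.__eq__` now compares against a `SparseMatrix` by value (via `toArray()`), so that `dense == sparse` gives the same result as `sparse == dense`.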

## How was this patch tested?

Unit tests were added, and the change was also tested manually. Note that the code falls back to the old behavior when `other` is not a `SparseMatrix`.

Closes #17968 from gglanzani/SPARK-9792.

Authored-by: Giovanni Lanzani <giovanni@lanzani.nl>
Signed-off-by: Holden Karau <holden@pigscanfly.ca>
This commit is contained in:
Giovanni Lanzani 2019-04-01 09:30:33 -07:00 committed by Holden Karau
parent 5888b15d9c
commit 92530c7db1
4 changed files with 20 additions and 8 deletions

View file

@ -980,14 +980,14 @@ class DenseMatrix(Matrix):
return self.values[i + j * self.numRows] return self.values[i + j * self.numRows]
def __eq__(self, other): def __eq__(self, other):
if (not isinstance(other, DenseMatrix) or if (self.numRows != other.numRows or self.numCols != other.numCols):
self.numRows != other.numRows or
self.numCols != other.numCols):
return False return False
if isinstance(other, SparseMatrix):
return np.all(self.toArray() == other.toArray())
self_values = np.ravel(self.toArray(), order='F') self_values = np.ravel(self.toArray(), order='F')
other_values = np.ravel(other.toArray(), order='F') other_values = np.ravel(other.toArray(), order='F')
return all(self_values == other_values) return np.all(self_values == other_values)
class SparseMatrix(Matrix): class SparseMatrix(Matrix):

View file

@ -112,11 +112,17 @@ class VectorTests(MLlibTestCase):
v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) v4 = SparseVector(6, [(1, 1.0), (3, 5.5)])
v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) v5 = DenseVector([0.0, 1.0, 0.0, 2.5])
v6 = SparseVector(4, [(1, 1.0), (3, 2.5)]) v6 = SparseVector(4, [(1, 1.0), (3, 2.5)])
dm1 = DenseMatrix(2, 2, [2, 0, 0, 0])
sm1 = SparseMatrix(2, 2, [0, 2, 3], [0], [2])
self.assertEqual(v1, v2) self.assertEqual(v1, v2)
self.assertEqual(v1, v3) self.assertEqual(v1, v3)
self.assertFalse(v2 == v4) self.assertFalse(v2 == v4)
self.assertFalse(v1 == v5) self.assertFalse(v1 == v5)
self.assertFalse(v1 == v6) self.assertFalse(v1 == v6)
# this is done as Dense and Sparse matrices can be semantically
# equal while still implementing a different __eq__ method
self.assertEqual(dm1, sm1)
self.assertEqual(sm1, dm1)
def test_equals(self): def test_equals(self):
indices = [1, 2, 4] indices = [1, 2, 4]

View file

@ -1135,14 +1135,14 @@ class DenseMatrix(Matrix):
return self.values[i + j * self.numRows] return self.values[i + j * self.numRows]
def __eq__(self, other): def __eq__(self, other):
if (not isinstance(other, DenseMatrix) or if (self.numRows != other.numRows or self.numCols != other.numCols):
self.numRows != other.numRows or
self.numCols != other.numCols):
return False return False
if isinstance(other, SparseMatrix):
return np.all(self.toArray() == other.toArray())
self_values = np.ravel(self.toArray(), order='F') self_values = np.ravel(self.toArray(), order='F')
other_values = np.ravel(other.toArray(), order='F') other_values = np.ravel(other.toArray(), order='F')
return all(self_values == other_values) return np.all(self_values == other_values)
class SparseMatrix(Matrix): class SparseMatrix(Matrix):

View file

@ -115,11 +115,17 @@ class VectorTests(MLlibTestCase):
v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) v4 = SparseVector(6, [(1, 1.0), (3, 5.5)])
v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) v5 = DenseVector([0.0, 1.0, 0.0, 2.5])
v6 = SparseVector(4, [(1, 1.0), (3, 2.5)]) v6 = SparseVector(4, [(1, 1.0), (3, 2.5)])
dm1 = DenseMatrix(2, 2, [2, 0, 0, 0])
sm1 = SparseMatrix(2, 2, [0, 2, 3], [0], [2])
self.assertEqual(v1, v2) self.assertEqual(v1, v2)
self.assertEqual(v1, v3) self.assertEqual(v1, v3)
self.assertFalse(v2 == v4) self.assertFalse(v2 == v4)
self.assertFalse(v1 == v5) self.assertFalse(v1 == v5)
self.assertFalse(v1 == v6) self.assertFalse(v1 == v6)
# this is done as Dense and Sparse matrices can be semantically
# equal while still implementing a different __eq__ method
self.assertEqual(dm1, sm1)
self.assertEqual(sm1, dm1)
def test_equals(self): def test_equals(self):
indices = [1, 2, 4] indices = [1, 2, 4]