[SPARK-9792] Make DenseMatrix equality semantic
Before, you could have this code:

```
A = SparseMatrix(2, 2, [0, 2, 3], [0], [2])
B = DenseMatrix(2, 2, [2, 0, 0, 0])
B == A  # False
A == B  # True
```

The second would be `True` as `SparseMatrix` already checks for semantic equality. This commit changes `DenseMatrix` so that its equality is semantic as well.

## What changes were proposed in this pull request?

Better semantic equality for DenseMatrix.

## How was this patch tested?

Unit tests were added, plus manual testing. Note that the code falls back to the old behavior when `other` is not a SparseMatrix.

Closes #17968 from gglanzani/SPARK-9792.

Authored-by: Giovanni Lanzani <giovanni@lanzani.nl>
Signed-off-by: Holden Karau <holden@pigscanfly.ca>
This commit is contained in:
parent
5888b15d9c
commit
92530c7db1
|
@ -980,14 +980,14 @@ class DenseMatrix(Matrix):
|
|||
return self.values[i + j * self.numRows]
|
||||
|
||||
def __eq__(self, other):
|
||||
if (not isinstance(other, DenseMatrix) or
|
||||
self.numRows != other.numRows or
|
||||
self.numCols != other.numCols):
|
||||
if (self.numRows != other.numRows or self.numCols != other.numCols):
|
||||
return False
|
||||
if isinstance(other, SparseMatrix):
|
||||
return np.all(self.toArray() == other.toArray())
|
||||
|
||||
self_values = np.ravel(self.toArray(), order='F')
|
||||
other_values = np.ravel(other.toArray(), order='F')
|
||||
return all(self_values == other_values)
|
||||
return np.all(self_values == other_values)
|
||||
|
||||
|
||||
class SparseMatrix(Matrix):
|
||||
|
|
|
@ -112,11 +112,17 @@ class VectorTests(MLlibTestCase):
|
|||
v4 = SparseVector(6, [(1, 1.0), (3, 5.5)])
|
||||
v5 = DenseVector([0.0, 1.0, 0.0, 2.5])
|
||||
v6 = SparseVector(4, [(1, 1.0), (3, 2.5)])
|
||||
dm1 = DenseMatrix(2, 2, [2, 0, 0, 0])
|
||||
sm1 = SparseMatrix(2, 2, [0, 2, 3], [0], [2])
|
||||
self.assertEqual(v1, v2)
|
||||
self.assertEqual(v1, v3)
|
||||
self.assertFalse(v2 == v4)
|
||||
self.assertFalse(v1 == v5)
|
||||
self.assertFalse(v1 == v6)
|
||||
# this is done as Dense and Sparse matrices can be semantically
|
||||
# equal while still implementing a different __eq__ method
|
||||
self.assertEqual(dm1, sm1)
|
||||
self.assertEqual(sm1, dm1)
|
||||
|
||||
def test_equals(self):
|
||||
indices = [1, 2, 4]
|
||||
|
|
|
@ -1135,14 +1135,14 @@ class DenseMatrix(Matrix):
|
|||
return self.values[i + j * self.numRows]
|
||||
|
||||
def __eq__(self, other):
|
||||
if (not isinstance(other, DenseMatrix) or
|
||||
self.numRows != other.numRows or
|
||||
self.numCols != other.numCols):
|
||||
if (self.numRows != other.numRows or self.numCols != other.numCols):
|
||||
return False
|
||||
if isinstance(other, SparseMatrix):
|
||||
return np.all(self.toArray() == other.toArray())
|
||||
|
||||
self_values = np.ravel(self.toArray(), order='F')
|
||||
other_values = np.ravel(other.toArray(), order='F')
|
||||
return all(self_values == other_values)
|
||||
return np.all(self_values == other_values)
|
||||
|
||||
|
||||
class SparseMatrix(Matrix):
|
||||
|
|
|
@ -115,11 +115,17 @@ class VectorTests(MLlibTestCase):
|
|||
v4 = SparseVector(6, [(1, 1.0), (3, 5.5)])
|
||||
v5 = DenseVector([0.0, 1.0, 0.0, 2.5])
|
||||
v6 = SparseVector(4, [(1, 1.0), (3, 2.5)])
|
||||
dm1 = DenseMatrix(2, 2, [2, 0, 0, 0])
|
||||
sm1 = SparseMatrix(2, 2, [0, 2, 3], [0], [2])
|
||||
self.assertEqual(v1, v2)
|
||||
self.assertEqual(v1, v3)
|
||||
self.assertFalse(v2 == v4)
|
||||
self.assertFalse(v1 == v5)
|
||||
self.assertFalse(v1 == v6)
|
||||
# this is done as Dense and Sparse matrices can be semantically
|
||||
# equal while still implementing a different __eq__ method
|
||||
self.assertEqual(dm1, sm1)
|
||||
self.assertEqual(sm1, dm1)
|
||||
|
||||
def test_equals(self):
|
||||
indices = [1, 2, 4]
|
||||
|
|
Loading…
Reference in a new issue