[SPARK-9792] Make DenseMatrix equality semantic
Before this change, the following code behaved asymmetrically:

```
A = SparseMatrix(2, 2, [0, 2, 3], [0], [2])
B = DenseMatrix(2, 2, [2, 0, 0, 0])
B == A  # False
A == B  # True
```

The second comparison returned `True` because `SparseMatrix` already checks for semantic equality. This commit changes `DenseMatrix` so that its equality check is semantic as well.

## What changes were proposed in this pull request?

Better semantic equality for `DenseMatrix`.

## How was this patch tested?

Unit tests were added, plus manual testing. Note that the code falls back to the old behavior when `other` is not a `SparseMatrix`.

Closes #17968 from gglanzani/SPARK-9792.

Authored-by: Giovanni Lanzani <giovanni@lanzani.nl>
Signed-off-by: Holden Karau <holden@pigscanfly.ca>
This commit is contained in:
parent
5888b15d9c
commit
92530c7db1
|
@ -980,14 +980,14 @@ class DenseMatrix(Matrix):
|
||||||
return self.values[i + j * self.numRows]
|
return self.values[i + j * self.numRows]
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if (not isinstance(other, DenseMatrix) or
|
if (self.numRows != other.numRows or self.numCols != other.numCols):
|
||||||
self.numRows != other.numRows or
|
|
||||||
self.numCols != other.numCols):
|
|
||||||
return False
|
return False
|
||||||
|
if isinstance(other, SparseMatrix):
|
||||||
|
return np.all(self.toArray() == other.toArray())
|
||||||
|
|
||||||
self_values = np.ravel(self.toArray(), order='F')
|
self_values = np.ravel(self.toArray(), order='F')
|
||||||
other_values = np.ravel(other.toArray(), order='F')
|
other_values = np.ravel(other.toArray(), order='F')
|
||||||
return all(self_values == other_values)
|
return np.all(self_values == other_values)
|
||||||
|
|
||||||
|
|
||||||
class SparseMatrix(Matrix):
|
class SparseMatrix(Matrix):
|
||||||
|
|
|
@ -112,11 +112,17 @@ class VectorTests(MLlibTestCase):
|
||||||
v4 = SparseVector(6, [(1, 1.0), (3, 5.5)])
|
v4 = SparseVector(6, [(1, 1.0), (3, 5.5)])
|
||||||
v5 = DenseVector([0.0, 1.0, 0.0, 2.5])
|
v5 = DenseVector([0.0, 1.0, 0.0, 2.5])
|
||||||
v6 = SparseVector(4, [(1, 1.0), (3, 2.5)])
|
v6 = SparseVector(4, [(1, 1.0), (3, 2.5)])
|
||||||
|
dm1 = DenseMatrix(2, 2, [2, 0, 0, 0])
|
||||||
|
sm1 = SparseMatrix(2, 2, [0, 2, 3], [0], [2])
|
||||||
self.assertEqual(v1, v2)
|
self.assertEqual(v1, v2)
|
||||||
self.assertEqual(v1, v3)
|
self.assertEqual(v1, v3)
|
||||||
self.assertFalse(v2 == v4)
|
self.assertFalse(v2 == v4)
|
||||||
self.assertFalse(v1 == v5)
|
self.assertFalse(v1 == v5)
|
||||||
self.assertFalse(v1 == v6)
|
self.assertFalse(v1 == v6)
|
||||||
|
# this is done as Dense and Sparse matrices can be semantically
|
||||||
|
# equal while still implementing a different __eq__ method
|
||||||
|
self.assertEqual(dm1, sm1)
|
||||||
|
self.assertEqual(sm1, dm1)
|
||||||
|
|
||||||
def test_equals(self):
|
def test_equals(self):
|
||||||
indices = [1, 2, 4]
|
indices = [1, 2, 4]
|
||||||
|
|
|
@ -1135,14 +1135,14 @@ class DenseMatrix(Matrix):
|
||||||
return self.values[i + j * self.numRows]
|
return self.values[i + j * self.numRows]
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if (not isinstance(other, DenseMatrix) or
|
if (self.numRows != other.numRows or self.numCols != other.numCols):
|
||||||
self.numRows != other.numRows or
|
|
||||||
self.numCols != other.numCols):
|
|
||||||
return False
|
return False
|
||||||
|
if isinstance(other, SparseMatrix):
|
||||||
|
return np.all(self.toArray() == other.toArray())
|
||||||
|
|
||||||
self_values = np.ravel(self.toArray(), order='F')
|
self_values = np.ravel(self.toArray(), order='F')
|
||||||
other_values = np.ravel(other.toArray(), order='F')
|
other_values = np.ravel(other.toArray(), order='F')
|
||||||
return all(self_values == other_values)
|
return np.all(self_values == other_values)
|
||||||
|
|
||||||
|
|
||||||
class SparseMatrix(Matrix):
|
class SparseMatrix(Matrix):
|
||||||
|
|
|
@ -115,11 +115,17 @@ class VectorTests(MLlibTestCase):
|
||||||
v4 = SparseVector(6, [(1, 1.0), (3, 5.5)])
|
v4 = SparseVector(6, [(1, 1.0), (3, 5.5)])
|
||||||
v5 = DenseVector([0.0, 1.0, 0.0, 2.5])
|
v5 = DenseVector([0.0, 1.0, 0.0, 2.5])
|
||||||
v6 = SparseVector(4, [(1, 1.0), (3, 2.5)])
|
v6 = SparseVector(4, [(1, 1.0), (3, 2.5)])
|
||||||
|
dm1 = DenseMatrix(2, 2, [2, 0, 0, 0])
|
||||||
|
sm1 = SparseMatrix(2, 2, [0, 2, 3], [0], [2])
|
||||||
self.assertEqual(v1, v2)
|
self.assertEqual(v1, v2)
|
||||||
self.assertEqual(v1, v3)
|
self.assertEqual(v1, v3)
|
||||||
self.assertFalse(v2 == v4)
|
self.assertFalse(v2 == v4)
|
||||||
self.assertFalse(v1 == v5)
|
self.assertFalse(v1 == v5)
|
||||||
self.assertFalse(v1 == v6)
|
self.assertFalse(v1 == v6)
|
||||||
|
# this is done as Dense and Sparse matrices can be semantically
|
||||||
|
# equal while still implementing a different __eq__ method
|
||||||
|
self.assertEqual(dm1, sm1)
|
||||||
|
self.assertEqual(sm1, dm1)
|
||||||
|
|
||||||
def test_equals(self):
|
def test_equals(self):
|
||||||
indices = [1, 2, 4]
|
indices = [1, 2, 4]
|
||||||
|
|
Loading…
Reference in a new issue