[SPARK-14739][PYSPARK] Fix Vectors parser bugs
## What changes were proposed in this pull request? The PySpark deserialization has a bug that shows while deserializing all zero sparse vectors. This fix filters out empty string tokens before casting, hence properly stringified SparseVectors successfully get parsed. ## How was this patch tested? Standard unit-tests similar to other methods. Author: Arash Parsa <arash@ip-192-168-50-106.ec2.internal> Author: Arash Parsa <arashpa@gmail.com> Author: Vishnu Prasad <vishnu667@gmail.com> Author: Vishnu Prasad S <vishnu667@gmail.com> Closes #12516 from arashpa/SPARK-14739.
This commit is contained in:
parent
8bd05c9db2
commit
2b8906c437
|
@ -293,7 +293,7 @@ class DenseVector(Vector):
|
|||
s = s[start + 1: end]
|
||||
|
||||
try:
|
||||
values = [float(val) for val in s.split(',')]
|
||||
values = [float(val) for val in s.split(',') if val]
|
||||
except ValueError:
|
||||
raise ValueError("Unable to parse values from %s" % s)
|
||||
return DenseVector(values)
|
||||
|
@ -586,7 +586,7 @@ class SparseVector(Vector):
|
|||
new_s = s[ind_start + 1: ind_end]
|
||||
ind_list = new_s.split(',')
|
||||
try:
|
||||
indices = [int(ind) for ind in ind_list]
|
||||
indices = [int(ind) for ind in ind_list if ind]
|
||||
except ValueError:
|
||||
raise ValueError("Unable to parse indices from %s." % new_s)
|
||||
s = s[ind_end + 1:].strip()
|
||||
|
@ -599,7 +599,7 @@ class SparseVector(Vector):
|
|||
raise ValueError("Values array should end with ']'.")
|
||||
val_list = s[val_start + 1: val_end].split(',')
|
||||
try:
|
||||
values = [float(val) for val in val_list]
|
||||
values = [float(val) for val in val_list if val]
|
||||
except ValueError:
|
||||
raise ValueError("Unable to parse values from %s." % s)
|
||||
return SparseVector(size, indices, values)
|
||||
|
|
|
@ -393,14 +393,20 @@ class VectorTests(MLlibTestCase):
|
|||
self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9]))
|
||||
|
||||
def test_parse_vector(self):
|
||||
a = DenseVector([])
|
||||
self.assertEqual(str(a), '[]')
|
||||
self.assertEqual(Vectors.parse(str(a)), a)
|
||||
a = DenseVector([3, 4, 6, 7])
|
||||
self.assertTrue(str(a), '[3.0,4.0,6.0,7.0]')
|
||||
self.assertTrue(Vectors.parse(str(a)), a)
|
||||
self.assertEqual(str(a), '[3.0,4.0,6.0,7.0]')
|
||||
self.assertEqual(Vectors.parse(str(a)), a)
|
||||
a = SparseVector(4, [], [])
|
||||
self.assertEqual(str(a), '(4,[],[])')
|
||||
self.assertEqual(SparseVector.parse(str(a)), a)
|
||||
a = SparseVector(4, [0, 2], [3, 4])
|
||||
self.assertTrue(str(a), '(4,[0,2],[3.0,4.0])')
|
||||
self.assertTrue(Vectors.parse(str(a)), a)
|
||||
self.assertEqual(str(a), '(4,[0,2],[3.0,4.0])')
|
||||
self.assertEqual(Vectors.parse(str(a)), a)
|
||||
a = SparseVector(10, [0, 1], [4, 5])
|
||||
self.assertTrue(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a)
|
||||
self.assertEqual(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a)
|
||||
|
||||
def test_norms(self):
|
||||
a = DenseVector([0, 2, 3, -1])
|
||||
|
|
Loading…
Reference in a new issue