[SPARK-5186] [MLLIB] Vector.equals and Vector.hashCode are very inefficient
JIRA Issue: https://issues.apache.org/jira/browse/SPARK-5186 Currently SparseVector uses the equals inherited from Vector, which creates a full-size array even for a sparse vector. This pull request contains a specialized equals that improves both time and space usage. 1. The implementation stays consistent with the original; in particular, it keeps equality comparison between SparseVector and DenseVector working. Author: Yuhao Yang <hhbyyh@gmail.com> Author: Yuhao Yang <yuhao@yuhaodevbox.sh.intel.com> Closes #3997 from hhbyyh/master and squashes the following commits: 0d9d130 [Yuhao Yang] function name change and ut update 93f0d46 [Yuhao Yang] unify sparse vs dense vectors 985e160 [Yuhao Yang] improve locality for equals bdf8789 [Yuhao Yang] improve equals and rewrite hashCode for Vector a6952c3 [Yuhao Yang] fix scala style for comments 50abef3 [Yuhao Yang] fix ut for sparse vector with explicit 0 f41b135 [Yuhao Yang] iterative equals for sparse vector 5741144 [Yuhao Yang] Specialized equals for SparseVector
This commit is contained in:
parent
d181c2a1fc
commit
2f82c841fa
|
@ -50,13 +50,35 @@ sealed trait Vector extends Serializable {
|
|||
|
||||
/**
 * Compares this vector with `other` for element-wise equality.
 *
 * Vectors of different sizes are never equal. Any pairing that involves a
 * [[SparseVector]] is delegated to `Vectors.equals`, which compares the stored
 * (index, value) entries directly instead of materializing a full-size array.
 * All remaining pairings fall back to comparing the dense arrays.
 *
 * @param other the object to compare against
 * @return true if `other` is a [[Vector]] of the same size with equal contents
 */
override def equals(other: Any): Boolean = {
  other match {
    case v2: Vector =>
      if (this.size != v2.size) {
        false
      } else {
        (this, v2) match {
          case (s1: SparseVector, s2: SparseVector) =>
            Vectors.equals(s1.indices, s1.values, s2.indices, s2.values)
          case (s1: SparseVector, d1: DenseVector) =>
            // a dense vector's indices are simply 0 until size
            Vectors.equals(s1.indices, s1.values, 0 until d1.size, d1.values)
          case (d1: DenseVector, s1: SparseVector) =>
            Vectors.equals(0 until d1.size, d1.values, s1.indices, s1.values)
          case (_, _) => util.Arrays.equals(this.toArray, v2.toArray)
        }
      }
    case _ => false
  }
}
|
||||
|
||||
/**
 * Returns a hash code built from the vector size and its non-zero entries.
 *
 * Zero-valued entries are skipped, so explicitly stored zeros in a sparse vector do
 * not contribute to the hash.
 */
override def hashCode(): Int = {
  var result: Int = size + 31
  this.foreachActive { case (index, value) =>
    // ignore explicit 0 for comparison between sparse and dense
    if (value != 0) {
      result = 31 * result + index
      // per-element hashing follows java.util.Arrays.hashCode(double[])
      val bits = java.lang.Double.doubleToLongBits(value)
      result = 31 * result + (bits ^ (bits >>> 32)).toInt
    }
  }
  result
}
|
||||
|
||||
/**
|
||||
* Converts the instance to a breeze vector.
|
||||
|
@ -392,6 +414,33 @@ object Vectors {
|
|||
}
|
||||
squaredDistance
|
||||
}
|
||||
|
||||
/**
 * Check equality between sparse/dense vectors given as parallel index/value sequences.
 *
 * Entries whose value is zero are skipped on both sides, so explicitly stored zeros
 * do not affect the result. The remaining entries must match pairwise in both index
 * and value. NaN is treated as equal to NaN, which keeps the comparison reflexive and
 * in line with `java.util.Arrays.equals` on double arrays.
 *
 * @param v1Indices indices of the first vector's stored entries
 * @param v1Values values of the first vector, parallel to `v1Indices`
 * @param v2Indices indices of the second vector's stored entries
 * @param v2Values values of the second vector, parallel to `v2Indices`
 * @return true if both sides hold the same non-zero entries
 */
private[mllib] def equals(
    v1Indices: IndexedSeq[Int],
    v1Values: Array[Double],
    v2Indices: IndexedSeq[Int],
    v2Values: Array[Double]): Boolean = {
  val v1Size = v1Values.length
  val v2Size = v2Values.length
  var k1 = 0
  var k2 = 0
  var allEqual = true
  while (allEqual) {
    // advance each cursor to its next non-zero entry
    while (k1 < v1Size && v1Values(k1) == 0) k1 += 1
    while (k2 < v2Size && v2Values(k2) == 0) k2 += 1

    if (k1 >= v1Size || k2 >= v2Size) {
      // equal only if both sides ran out of non-zero entries together
      return k1 >= v1Size && k2 >= v2Size
    }
    // zeros were skipped above, so Double.compare differs from == only for NaN
    allEqual = v1Indices(k1) == v2Indices(k2) &&
      java.lang.Double.compare(v1Values(k1), v2Values(k2)) == 0
    k1 += 1
    k2 += 1
  }
  allEqual
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -89,6 +89,24 @@ class VectorsSuite extends FunSuite {
|
|||
}
|
||||
}
|
||||
|
||||
test("vectors equals with explicit 0") {
  // The same logical vector in three encodings: dense, sparse, and sparse with stored zeros.
  val dense = Vectors.dense(Array(0, 0.9, 0, 0.8, 0))
  val sparse = Vectors.sparse(5, Array(1, 3), Array(0.9, 0.8))
  val sparseWithZeros = Vectors.sparse(5, Array(0, 1, 2, 3, 4), Array(0, 0.9, 0, 0.8, 0))
  val equivalents = Seq(dense, sparse, sparseWithZeros)

  // Every pairing, including each vector with itself, must be equal and hash alike.
  equivalents.foreach { left =>
    equivalents.foreach { right =>
      assert(left === right)
      assert(left.## === right.##)
    }
  }

  // Same size, but the value stored at index 3 differs.
  val different = Vectors.sparse(5, Array(0, 1, 3), Array(0, 0.9, 0.2))
  equivalents.foreach { candidate =>
    assert(candidate != different)
    assert(candidate.## != different.##)
  }
}
|
||||
|
||||
test("indexing dense vectors") {
|
||||
val vec = Vectors.dense(1.0, 2.0, 3.0, 4.0)
|
||||
assert(vec(0) === 1.0)
|
||||
|
|
Loading…
Reference in a new issue