[SPARK-7140] [MLLIB] only scan the first 16 entries in Vector.hashCode
The Python SerDe calls `Object.hashCode`, which is very expensive for Vectors. It is not necessary to scan the whole vector, especially for large ones. In this PR, we only scan the first 16 nonzeros. srowen Author: Xiangrui Meng <meng@databricks.com> Closes #5697 from mengxr/SPARK-7140 and squashes the following commits: 2abc86d [Xiangrui Meng] typo 8fb7d74 [Xiangrui Meng] update impl 1ebad60 [Xiangrui Meng] only scan the first 16 nonzeros in Vector.hashCode
This commit is contained in:
parent
6a827d5d1e
commit
b14cd23649
|
@ -52,7 +52,7 @@ sealed trait Vector extends Serializable {
|
||||||
|
|
||||||
override def equals(other: Any): Boolean = {
|
override def equals(other: Any): Boolean = {
|
||||||
other match {
|
other match {
|
||||||
case v2: Vector => {
|
case v2: Vector =>
|
||||||
if (this.size != v2.size) return false
|
if (this.size != v2.size) return false
|
||||||
(this, v2) match {
|
(this, v2) match {
|
||||||
case (s1: SparseVector, s2: SparseVector) =>
|
case (s1: SparseVector, s2: SparseVector) =>
|
||||||
|
@ -63,20 +63,28 @@ sealed trait Vector extends Serializable {
|
||||||
Vectors.equals(0 until d1.size, d1.values, s1.indices, s1.values)
|
Vectors.equals(0 until d1.size, d1.values, s1.indices, s1.values)
|
||||||
case (_, _) => util.Arrays.equals(this.toArray, v2.toArray)
|
case (_, _) => util.Arrays.equals(this.toArray, v2.toArray)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
case _ => false
|
case _ => false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a hash code value for the vector. The hash code is based on its size and its nonzeros
|
||||||
|
* in the first 16 entries, using a hash algorithm similar to [[java.util.Arrays.hashCode]].
|
||||||
|
*/
|
||||||
override def hashCode(): Int = {
|
override def hashCode(): Int = {
|
||||||
var result: Int = size + 31
|
// This is a reference implementation. It calls return in foreachActive, which is slow.
|
||||||
this.foreachActive { case (index, value) =>
|
// Subclasses should override it with optimized implementation.
|
||||||
// ignore explict 0 for comparison between sparse and dense
|
var result: Int = 31 + size
|
||||||
if (value != 0) {
|
this.foreachActive { (index, value) =>
|
||||||
result = 31 * result + index
|
if (index < 16) {
|
||||||
// refer to {@link java.util.Arrays.equals} for hash algorithm
|
// ignore explicit 0 for comparison between sparse and dense
|
||||||
val bits = java.lang.Double.doubleToLongBits(value)
|
if (value != 0) {
|
||||||
result = 31 * result + (bits ^ (bits >>> 32)).toInt
|
result = 31 * result + index
|
||||||
|
val bits = java.lang.Double.doubleToLongBits(value)
|
||||||
|
result = 31 * result + (bits ^ (bits >>> 32)).toInt
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return result
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
result
|
result
|
||||||
|
@ -317,7 +325,7 @@ object Vectors {
|
||||||
case SparseVector(n, ids, vs) => vs
|
case SparseVector(n, ids, vs) => vs
|
||||||
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
|
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
|
||||||
}
|
}
|
||||||
val size = values.size
|
val size = values.length
|
||||||
|
|
||||||
if (p == 1) {
|
if (p == 1) {
|
||||||
var sum = 0.0
|
var sum = 0.0
|
||||||
|
@ -371,8 +379,8 @@ object Vectors {
|
||||||
val v1Indices = v1.indices
|
val v1Indices = v1.indices
|
||||||
val v2Values = v2.values
|
val v2Values = v2.values
|
||||||
val v2Indices = v2.indices
|
val v2Indices = v2.indices
|
||||||
val nnzv1 = v1Indices.size
|
val nnzv1 = v1Indices.length
|
||||||
val nnzv2 = v2Indices.size
|
val nnzv2 = v2Indices.length
|
||||||
|
|
||||||
var kv1 = 0
|
var kv1 = 0
|
||||||
var kv2 = 0
|
var kv2 = 0
|
||||||
|
@ -401,7 +409,7 @@ object Vectors {
|
||||||
|
|
||||||
case (DenseVector(vv1), DenseVector(vv2)) =>
|
case (DenseVector(vv1), DenseVector(vv2)) =>
|
||||||
var kv = 0
|
var kv = 0
|
||||||
val sz = vv1.size
|
val sz = vv1.length
|
||||||
while (kv < sz) {
|
while (kv < sz) {
|
||||||
val score = vv1(kv) - vv2(kv)
|
val score = vv1(kv) - vv2(kv)
|
||||||
squaredDistance += score * score
|
squaredDistance += score * score
|
||||||
|
@ -422,7 +430,7 @@ object Vectors {
|
||||||
var kv2 = 0
|
var kv2 = 0
|
||||||
val indices = v1.indices
|
val indices = v1.indices
|
||||||
var squaredDistance = 0.0
|
var squaredDistance = 0.0
|
||||||
val nnzv1 = indices.size
|
val nnzv1 = indices.length
|
||||||
val nnzv2 = v2.size
|
val nnzv2 = v2.size
|
||||||
var iv1 = if (nnzv1 > 0) indices(kv1) else -1
|
var iv1 = if (nnzv1 > 0) indices(kv1) else -1
|
||||||
|
|
||||||
|
@ -451,8 +459,8 @@ object Vectors {
|
||||||
v1Values: Array[Double],
|
v1Values: Array[Double],
|
||||||
v2Indices: IndexedSeq[Int],
|
v2Indices: IndexedSeq[Int],
|
||||||
v2Values: Array[Double]): Boolean = {
|
v2Values: Array[Double]): Boolean = {
|
||||||
val v1Size = v1Values.size
|
val v1Size = v1Values.length
|
||||||
val v2Size = v2Values.size
|
val v2Size = v2Values.length
|
||||||
var k1 = 0
|
var k1 = 0
|
||||||
var k2 = 0
|
var k2 = 0
|
||||||
var allEqual = true
|
var allEqual = true
|
||||||
|
@ -493,7 +501,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
|
||||||
|
|
||||||
private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
|
private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
|
||||||
var i = 0
|
var i = 0
|
||||||
val localValuesSize = values.size
|
val localValuesSize = values.length
|
||||||
val localValues = values
|
val localValues = values
|
||||||
|
|
||||||
while (i < localValuesSize) {
|
while (i < localValuesSize) {
|
||||||
|
@ -501,6 +509,22 @@ class DenseVector(val values: Array[Double]) extends Vector {
|
||||||
i += 1
|
i += 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def hashCode(): Int = {
|
||||||
|
var result: Int = 31 + size
|
||||||
|
var i = 0
|
||||||
|
val end = math.min(values.length, 16)
|
||||||
|
while (i < end) {
|
||||||
|
val v = values(i)
|
||||||
|
if (v != 0.0) {
|
||||||
|
result = 31 * result + i
|
||||||
|
val bits = java.lang.Double.doubleToLongBits(values(i))
|
||||||
|
result = 31 * result + (bits ^ (bits >>> 32)).toInt
|
||||||
|
}
|
||||||
|
i += 1
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
object DenseVector {
|
object DenseVector {
|
||||||
|
@ -522,8 +546,8 @@ class SparseVector(
|
||||||
val values: Array[Double]) extends Vector {
|
val values: Array[Double]) extends Vector {
|
||||||
|
|
||||||
require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
|
require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
|
||||||
s" indices match the dimension of the values. You provided ${indices.size} indices and " +
|
s" indices match the dimension of the values. You provided ${indices.length} indices and " +
|
||||||
s" ${values.size} values.")
|
s" ${values.length} values.")
|
||||||
|
|
||||||
override def toString: String =
|
override def toString: String =
|
||||||
s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})"
|
s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})"
|
||||||
|
@ -547,7 +571,7 @@ class SparseVector(
|
||||||
|
|
||||||
private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
|
private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
|
||||||
var i = 0
|
var i = 0
|
||||||
val localValuesSize = values.size
|
val localValuesSize = values.length
|
||||||
val localIndices = indices
|
val localIndices = indices
|
||||||
val localValues = values
|
val localValues = values
|
||||||
|
|
||||||
|
@ -556,6 +580,28 @@ class SparseVector(
|
||||||
i += 1
|
i += 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def hashCode(): Int = {
|
||||||
|
var result: Int = 31 + size
|
||||||
|
val end = values.length
|
||||||
|
var continue = true
|
||||||
|
var k = 0
|
||||||
|
while ((k < end) & continue) {
|
||||||
|
val i = indices(k)
|
||||||
|
if (i < 16) {
|
||||||
|
val v = values(k)
|
||||||
|
if (v != 0.0) {
|
||||||
|
result = 31 * result + i
|
||||||
|
val bits = java.lang.Double.doubleToLongBits(v)
|
||||||
|
result = 31 * result + (bits ^ (bits >>> 32)).toInt
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
continue = false
|
||||||
|
}
|
||||||
|
k += 1
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
object SparseVector {
|
object SparseVector {
|
||||||
|
|
Loading…
Reference in a new issue