[SPARK-24754][ML] Minhash integer overflow
## What changes were proposed in this pull request?

Use Long arithmetic when calculating MinHash values, to avoid bias caused by Int overflow.

## How was this patch tested?

Existing tests.

Author: Sean Owen <srowen@gmail.com>

Closes #21750 from srowen/SPARK-24754.
This commit is contained in:
parent
e1de34113e
commit
8aceb961c3
|
@ -66,7 +66,7 @@ class MinHashLSHModel private[ml](
|
|||
val elemsList = elems.toSparse.indices.toList
|
||||
val hashValues = randCoefficients.map { case (a, b) =>
|
||||
elemsList.map { elem: Int =>
|
||||
((1 + elem) * a + b) % MinHashLSH.HASH_PRIME
|
||||
((1L + elem) * a + b) % MinHashLSH.HASH_PRIME
|
||||
}.min.toDouble
|
||||
}
|
||||
// TODO: Output vectors of dimension numHashFunctions in SPARK-18450
|
||||
|
|
|
@ -1294,14 +1294,14 @@ class MinHashLSH(JavaEstimator, LSHParams, HasInputCol, HasOutputCol, HasSeed,
|
|||
>>> mh = MinHashLSH(inputCol="features", outputCol="hashes", seed=12345)
|
||||
>>> model = mh.fit(df)
|
||||
>>> model.transform(df).head()
|
||||
Row(id=0, features=SparseVector(6, {0: 1.0, 1: 1.0, 2: 1.0}), hashes=[DenseVector([-1638925...
|
||||
Row(id=0, features=SparseVector(6, {0: 1.0, 1: 1.0, 2: 1.0}), hashes=[DenseVector([6179668...
|
||||
>>> data2 = [(3, Vectors.sparse(6, [1, 3, 5], [1.0, 1.0, 1.0]),),
|
||||
... (4, Vectors.sparse(6, [2, 3, 5], [1.0, 1.0, 1.0]),),
|
||||
... (5, Vectors.sparse(6, [1, 2, 4], [1.0, 1.0, 1.0]),)]
|
||||
>>> df2 = spark.createDataFrame(data2, ["id", "features"])
|
||||
>>> key = Vectors.sparse(6, [1, 2], [1.0, 1.0])
|
||||
>>> model.approxNearestNeighbors(df2, key, 1).collect()
|
||||
[Row(id=5, features=SparseVector(6, {1: 1.0, 2: 1.0, 4: 1.0}), hashes=[DenseVector([-163892...
|
||||
[Row(id=5, features=SparseVector(6, {1: 1.0, 2: 1.0, 4: 1.0}), hashes=[DenseVector([6179668...
|
||||
>>> model.approxSimilarityJoin(df, df2, 0.6, distCol="JaccardDistance").select(
|
||||
... col("datasetA.id").alias("idA"),
|
||||
... col("datasetB.id").alias("idB"),
|
||||
|
@ -1309,8 +1309,8 @@ class MinHashLSH(JavaEstimator, LSHParams, HasInputCol, HasOutputCol, HasSeed,
|
|||
+---+---+---------------+
|
||||
|idA|idB|JaccardDistance|
|
||||
+---+---+---------------+
|
||||
| 1| 4| 0.5|
|
||||
| 0| 5| 0.5|
|
||||
| 1| 4| 0.5|
|
||||
+---+---+---------------+
|
||||
...
|
||||
>>> mhPath = temp_path + "/mh"
|
||||
|
|
Loading…
Reference in a new issue