[SPARK-12933][SQL] Initial implementation of Count-Min sketch
This PR adds an initial implementation of count min sketch, contained in a new module spark-sketch under `common/sketch`. The implementation is based on the [`CountMinSketch` class in stream-lib][1].
As required by the [design doc][2], spark-sketch should have no external dependency.
Two classes, `Murmur3_x86_32` and `Platform`, are copied to spark-sketch from spark-unsafe for hashing facilities. They'll also be used in the upcoming bloom filter implementation.
The following features will be added in future follow-up PRs:
- Serialization support
- DataFrame API integration
[1]: aac6b4d23a/src/main/java/com/clearspring/analytics/stream/frequency/CountMinSketch.java
[2]: https://issues.apache.org/jira/secure/attachment/12782378/BloomFilterandCount-MinSketchinSpark2.0.pdf
Author: Cheng Lian <lian@databricks.com>
Closes #10851 from liancheng/count-min-sketch.
This commit is contained in:
parent
5af5a02160
commit
1c690ddafa
42
common/sketch/pom.xml
Normal file
42
common/sketch/pom.xml
Normal file
|
@ -0,0 +1,42 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!--
|
||||
~ Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
~ contributor license agreements. See the NOTICE file distributed with
|
||||
~ this work for additional information regarding copyright ownership.
|
||||
~ The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
~ (the "License"); you may not use this file except in compliance with
|
||||
~ the License. You may obtain a copy of the License at
|
||||
~
|
||||
~ http://www.apache.org/licenses/LICENSE-2.0
|
||||
~
|
||||
~ Unless required by applicable law or agreed to in writing, software
|
||||
~ distributed under the License is distributed on an "AS IS" BASIS,
|
||||
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
~ See the License for the specific language governing permissions and
|
||||
~ limitations under the License.
|
||||
-->
|
||||
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<parent>
|
||||
<groupId>org.apache.spark</groupId>
|
||||
<artifactId>spark-parent_2.10</artifactId>
|
||||
<version>2.0.0-SNAPSHOT</version>
|
||||
<relativePath>../../pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
<groupId>org.apache.spark</groupId>
|
||||
<artifactId>spark-sketch_2.10</artifactId>
|
||||
<packaging>jar</packaging>
|
||||
<name>Spark Project Sketch</name>
|
||||
<url>http://spark.apache.org/</url>
|
||||
<properties>
|
||||
<sbt.project.name>sketch</sbt.project.name>
|
||||
</properties>
|
||||
|
||||
<build>
|
||||
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
|
||||
<testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
|
||||
</build>
|
||||
</project>
|
|
@ -0,0 +1,132 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.util.sketch;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
|
||||
/**
|
||||
* A Count-Min sketch is a probabilistic data structure used for summarizing streams of data in
|
||||
* sub-linear space. Currently, supported data types include:
|
||||
* <ul>
|
||||
* <li>{@link Byte}</li>
|
||||
* <li>{@link Short}</li>
|
||||
* <li>{@link Integer}</li>
|
||||
* <li>{@link Long}</li>
|
||||
* <li>{@link String}</li>
|
||||
* </ul>
|
||||
* Each {@link CountMinSketch} is initialized with a random seed, and a pair
|
||||
* of parameters:
|
||||
* <ol>
|
||||
* <li>relative error (or {@code eps}), and
|
||||
* <li>confidence (or {@code delta})
|
||||
* </ol>
|
||||
* Suppose you want to estimate the number of times an element {@code x} has appeared in a data
|
||||
* stream so far. With probability {@code delta}, the estimate of this frequency is within the
|
||||
* range {@code true frequency <= estimate <= true frequency + eps * N}, where {@code N} is the
|
||||
* total count of items that have appeared in the data stream so far.
|
||||
*
|
||||
* Under the cover, a {@link CountMinSketch} is essentially a two-dimensional {@code long} array
|
||||
* with depth {@code d} and width {@code w}, where
|
||||
* <ul>
|
||||
* <li>{@code d = ceil(-log(1 - confidence) / log(2))}</li>
|
||||
* <li>{@code w = ceil(2 / eps)}</li>
|
||||
* </ul>
|
||||
*
|
||||
* See http://www.eecs.harvard.edu/~michaelm/CS222/countmin.pdf for technical details,
|
||||
* including proofs of the estimates and error bounds used in this implementation.
|
||||
*
|
||||
* This implementation is largely based on the {@code CountMinSketch} class from stream-lib.
|
||||
*/
|
||||
abstract public class CountMinSketch {
|
||||
/**
|
||||
* Returns the relative error (or {@code eps}) of this {@link CountMinSketch}.
|
||||
*/
|
||||
public abstract double relativeError();
|
||||
|
||||
/**
|
||||
* Returns the confidence (or {@code delta}) of this {@link CountMinSketch}.
|
||||
*/
|
||||
public abstract double confidence();
|
||||
|
||||
/**
|
||||
* Depth of this {@link CountMinSketch}.
|
||||
*/
|
||||
public abstract int depth();
|
||||
|
||||
/**
|
||||
* Width of this {@link CountMinSketch}.
|
||||
*/
|
||||
public abstract int width();
|
||||
|
||||
/**
|
||||
* Total count of items added to this {@link CountMinSketch} so far.
|
||||
*/
|
||||
public abstract long totalCount();
|
||||
|
||||
/**
|
||||
* Adds 1 to {@code item}.
|
||||
*/
|
||||
public abstract void add(Object item);
|
||||
|
||||
/**
|
||||
* Adds {@code count} to {@code item}.
|
||||
*/
|
||||
public abstract void add(Object item, long count);
|
||||
|
||||
/**
|
||||
* Returns the estimated frequency of {@code item}.
|
||||
*/
|
||||
public abstract long estimateCount(Object item);
|
||||
|
||||
/**
|
||||
* Merges another {@link CountMinSketch} with this one in place.
|
||||
*
|
||||
* Note that only Count-Min sketches with the same {@code depth}, {@code width}, and random seed
|
||||
* can be merged.
|
||||
*/
|
||||
public abstract CountMinSketch mergeInPlace(CountMinSketch other);
|
||||
|
||||
/**
|
||||
* Writes out this {@link CountMinSketch} to an output stream in binary format.
|
||||
*/
|
||||
public abstract void writeTo(OutputStream out);
|
||||
|
||||
/**
|
||||
* Reads in a {@link CountMinSketch} from an input stream.
|
||||
*/
|
||||
public static CountMinSketch readFrom(InputStream in) {
|
||||
throw new UnsupportedOperationException("Not implemented yet");
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link CountMinSketch} with given {@code depth}, {@code width}, and random
|
||||
* {@code seed}.
|
||||
*/
|
||||
public static CountMinSketch create(int depth, int width, int seed) {
|
||||
return new CountMinSketchImpl(depth, width, seed);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link CountMinSketch} with given relative error ({@code eps}), {@code confidence},
|
||||
* and random {@code seed}.
|
||||
*/
|
||||
public static CountMinSketch create(double eps, double confidence, int seed) {
|
||||
return new CountMinSketchImpl(eps, confidence, seed);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,268 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.util.sketch;
|
||||
|
||||
import java.io.OutputStream;
|
||||
import java.io.UnsupportedEncodingException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Random;
|
||||
|
||||
class CountMinSketchImpl extends CountMinSketch {
|
||||
public static final long PRIME_MODULUS = (1L << 31) - 1;
|
||||
|
||||
private int depth;
|
||||
private int width;
|
||||
private long[][] table;
|
||||
private long[] hashA;
|
||||
private long totalCount;
|
||||
private double eps;
|
||||
private double confidence;
|
||||
|
||||
public CountMinSketchImpl(int depth, int width, int seed) {
|
||||
this.depth = depth;
|
||||
this.width = width;
|
||||
this.eps = 2.0 / width;
|
||||
this.confidence = 1 - 1 / Math.pow(2, depth);
|
||||
initTablesWith(depth, width, seed);
|
||||
}
|
||||
|
||||
public CountMinSketchImpl(double eps, double confidence, int seed) {
|
||||
// 2/w = eps ; w = 2/eps
|
||||
// 1/2^depth <= 1-confidence ; depth >= -log2 (1-confidence)
|
||||
this.eps = eps;
|
||||
this.confidence = confidence;
|
||||
this.width = (int) Math.ceil(2 / eps);
|
||||
this.depth = (int) Math.ceil(-Math.log(1 - confidence) / Math.log(2));
|
||||
initTablesWith(depth, width, seed);
|
||||
}
|
||||
|
||||
private void initTablesWith(int depth, int width, int seed) {
|
||||
this.table = new long[depth][width];
|
||||
this.hashA = new long[depth];
|
||||
Random r = new Random(seed);
|
||||
// We're using a linear hash functions
|
||||
// of the form (a*x+b) mod p.
|
||||
// a,b are chosen independently for each hash function.
|
||||
// However we can set b = 0 as all it does is shift the results
|
||||
// without compromising their uniformity or independence with
|
||||
// the other hashes.
|
||||
for (int i = 0; i < depth; ++i) {
|
||||
hashA[i] = r.nextInt(Integer.MAX_VALUE);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public double relativeError() {
|
||||
return eps;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double confidence() {
|
||||
return confidence;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int depth() {
|
||||
return depth;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int width() {
|
||||
return width;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long totalCount() {
|
||||
return totalCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void add(Object item) {
|
||||
add(item, 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void add(Object item, long count) {
|
||||
if (item instanceof String) {
|
||||
addString((String) item, count);
|
||||
} else {
|
||||
long longValue;
|
||||
|
||||
if (item instanceof Long) {
|
||||
longValue = (Long) item;
|
||||
} else if (item instanceof Integer) {
|
||||
longValue = ((Integer) item).longValue();
|
||||
} else if (item instanceof Short) {
|
||||
longValue = ((Short) item).longValue();
|
||||
} else if (item instanceof Byte) {
|
||||
longValue = ((Byte) item).longValue();
|
||||
} else {
|
||||
throw new IllegalArgumentException(
|
||||
"Support for " + item.getClass().getName() + " not implemented"
|
||||
);
|
||||
}
|
||||
|
||||
addLong(longValue, count);
|
||||
}
|
||||
}
|
||||
|
||||
private void addString(String item, long count) {
|
||||
if (count < 0) {
|
||||
throw new IllegalArgumentException("Negative increments not implemented");
|
||||
}
|
||||
|
||||
int[] buckets = getHashBuckets(item, depth, width);
|
||||
|
||||
for (int i = 0; i < depth; ++i) {
|
||||
table[i][buckets[i]] += count;
|
||||
}
|
||||
|
||||
totalCount += count;
|
||||
}
|
||||
|
||||
private void addLong(long item, long count) {
|
||||
if (count < 0) {
|
||||
throw new IllegalArgumentException("Negative increments not implemented");
|
||||
}
|
||||
|
||||
for (int i = 0; i < depth; ++i) {
|
||||
table[i][hash(item, i)] += count;
|
||||
}
|
||||
|
||||
totalCount += count;
|
||||
}
|
||||
|
||||
private int hash(long item, int count) {
|
||||
long hash = hashA[count] * item;
|
||||
// A super fast way of computing x mod 2^p-1
|
||||
// See http://www.cs.princeton.edu/courses/archive/fall09/cos521/Handouts/universalclasses.pdf
|
||||
// page 149, right after Proposition 7.
|
||||
hash += hash >> 32;
|
||||
hash &= PRIME_MODULUS;
|
||||
// Doing "%" after (int) conversion is ~2x faster than %'ing longs.
|
||||
return ((int) hash) % width;
|
||||
}
|
||||
|
||||
private static int[] getHashBuckets(String key, int hashCount, int max) {
|
||||
byte[] b;
|
||||
try {
|
||||
b = key.getBytes("UTF-8");
|
||||
} catch (UnsupportedEncodingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return getHashBuckets(b, hashCount, max);
|
||||
}
|
||||
|
||||
private static int[] getHashBuckets(byte[] b, int hashCount, int max) {
|
||||
int[] result = new int[hashCount];
|
||||
int hash1 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, 0);
|
||||
int hash2 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, hash1);
|
||||
for (int i = 0; i < hashCount; i++) {
|
||||
result[i] = Math.abs((hash1 + i * hash2) % max);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long estimateCount(Object item) {
|
||||
if (item instanceof String) {
|
||||
return estimateCountForStringItem((String) item);
|
||||
} else {
|
||||
long longValue;
|
||||
|
||||
if (item instanceof Long) {
|
||||
longValue = (Long) item;
|
||||
} else if (item instanceof Integer) {
|
||||
longValue = ((Integer) item).longValue();
|
||||
} else if (item instanceof Short) {
|
||||
longValue = ((Short) item).longValue();
|
||||
} else if (item instanceof Byte) {
|
||||
longValue = ((Byte) item).longValue();
|
||||
} else {
|
||||
throw new IllegalArgumentException(
|
||||
"Support for " + item.getClass().getName() + " not implemented"
|
||||
);
|
||||
}
|
||||
|
||||
return estimateCountForLongItem(longValue);
|
||||
}
|
||||
}
|
||||
|
||||
private long estimateCountForLongItem(long item) {
|
||||
long res = Long.MAX_VALUE;
|
||||
for (int i = 0; i < depth; ++i) {
|
||||
res = Math.min(res, table[i][hash(item, i)]);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
private long estimateCountForStringItem(String item) {
|
||||
long res = Long.MAX_VALUE;
|
||||
int[] buckets = getHashBuckets(item, depth, width);
|
||||
for (int i = 0; i < depth; ++i) {
|
||||
res = Math.min(res, table[i][buckets[i]]);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@Override
|
||||
public CountMinSketch mergeInPlace(CountMinSketch other) {
|
||||
if (other == null) {
|
||||
throw new CMSMergeException("Cannot merge null estimator");
|
||||
}
|
||||
|
||||
if (!(other instanceof CountMinSketchImpl)) {
|
||||
throw new CMSMergeException("Cannot merge estimator of class " + other.getClass().getName());
|
||||
}
|
||||
|
||||
CountMinSketchImpl that = (CountMinSketchImpl) other;
|
||||
|
||||
if (this.depth != that.depth) {
|
||||
throw new CMSMergeException("Cannot merge estimators of different depth");
|
||||
}
|
||||
|
||||
if (this.width != that.width) {
|
||||
throw new CMSMergeException("Cannot merge estimators of different width");
|
||||
}
|
||||
|
||||
if (!Arrays.equals(this.hashA, that.hashA)) {
|
||||
throw new CMSMergeException("Cannot merge estimators of different seed");
|
||||
}
|
||||
|
||||
for (int i = 0; i < this.table.length; ++i) {
|
||||
for (int j = 0; j < this.table[i].length; ++j) {
|
||||
this.table[i][j] = this.table[i][j] + that.table[i][j];
|
||||
}
|
||||
}
|
||||
|
||||
this.totalCount += that.totalCount;
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(OutputStream out) {
|
||||
throw new UnsupportedOperationException("Not implemented yet");
|
||||
}
|
||||
|
||||
protected static class CMSMergeException extends RuntimeException {
|
||||
public CMSMergeException(String message) {
|
||||
super(message);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,126 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.util.sketch;
|
||||
|
||||
/**
|
||||
* 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction.
|
||||
*/
|
||||
// This class is duplicated from `org.apache.spark.unsafe.hash.Murmur3_x86_32` to make sure
|
||||
// spark-sketch has no external dependencies.
|
||||
final class Murmur3_x86_32 {
|
||||
private static final int C1 = 0xcc9e2d51;
|
||||
private static final int C2 = 0x1b873593;
|
||||
|
||||
private final int seed;
|
||||
|
||||
public Murmur3_x86_32(int seed) {
|
||||
this.seed = seed;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "Murmur3_32(seed=" + seed + ")";
|
||||
}
|
||||
|
||||
public int hashInt(int input) {
|
||||
return hashInt(input, seed);
|
||||
}
|
||||
|
||||
public static int hashInt(int input, int seed) {
|
||||
int k1 = mixK1(input);
|
||||
int h1 = mixH1(seed, k1);
|
||||
|
||||
return fmix(h1, 4);
|
||||
}
|
||||
|
||||
public int hashUnsafeWords(Object base, long offset, int lengthInBytes) {
|
||||
return hashUnsafeWords(base, offset, lengthInBytes, seed);
|
||||
}
|
||||
|
||||
public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) {
|
||||
// This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method.
|
||||
assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)";
|
||||
int h1 = hashBytesByInt(base, offset, lengthInBytes, seed);
|
||||
return fmix(h1, lengthInBytes);
|
||||
}
|
||||
|
||||
public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) {
|
||||
assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
|
||||
int lengthAligned = lengthInBytes - lengthInBytes % 4;
|
||||
int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
|
||||
for (int i = lengthAligned; i < lengthInBytes; i++) {
|
||||
int halfWord = Platform.getByte(base, offset + i);
|
||||
int k1 = mixK1(halfWord);
|
||||
h1 = mixH1(h1, k1);
|
||||
}
|
||||
return fmix(h1, lengthInBytes);
|
||||
}
|
||||
|
||||
private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) {
|
||||
assert (lengthInBytes % 4 == 0);
|
||||
int h1 = seed;
|
||||
for (int i = 0; i < lengthInBytes; i += 4) {
|
||||
int halfWord = Platform.getInt(base, offset + i);
|
||||
int k1 = mixK1(halfWord);
|
||||
h1 = mixH1(h1, k1);
|
||||
}
|
||||
return h1;
|
||||
}
|
||||
|
||||
public int hashLong(long input) {
|
||||
return hashLong(input, seed);
|
||||
}
|
||||
|
||||
public static int hashLong(long input, int seed) {
|
||||
int low = (int) input;
|
||||
int high = (int) (input >>> 32);
|
||||
|
||||
int k1 = mixK1(low);
|
||||
int h1 = mixH1(seed, k1);
|
||||
|
||||
k1 = mixK1(high);
|
||||
h1 = mixH1(h1, k1);
|
||||
|
||||
return fmix(h1, 8);
|
||||
}
|
||||
|
||||
private static int mixK1(int k1) {
|
||||
k1 *= C1;
|
||||
k1 = Integer.rotateLeft(k1, 15);
|
||||
k1 *= C2;
|
||||
return k1;
|
||||
}
|
||||
|
||||
private static int mixH1(int h1, int k1) {
|
||||
h1 ^= k1;
|
||||
h1 = Integer.rotateLeft(h1, 13);
|
||||
h1 = h1 * 5 + 0xe6546b64;
|
||||
return h1;
|
||||
}
|
||||
|
||||
// Finalization mix - force all bits of a hash block to avalanche
|
||||
private static int fmix(int h1, int length) {
|
||||
h1 ^= length;
|
||||
h1 ^= h1 >>> 16;
|
||||
h1 *= 0x85ebca6b;
|
||||
h1 ^= h1 >>> 13;
|
||||
h1 *= 0xc2b2ae35;
|
||||
h1 ^= h1 >>> 16;
|
||||
return h1;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,172 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.util.sketch;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
|
||||
import sun.misc.Unsafe;
|
||||
|
||||
// This class is duplicated from `org.apache.spark.unsafe.Platform` to make sure spark-sketch has no
|
||||
// external dependencies.
|
||||
final class Platform {

  /**
   * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to
   * allow safepoint polling during a large copy.
   */
  private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L;

  private static final Unsafe _UNSAFE;

  public static final int BYTE_ARRAY_OFFSET;

  public static final int INT_ARRAY_OFFSET;

  public static final int LONG_ARRAY_OFFSET;

  public static final int DOUBLE_ARRAY_OFFSET;

  static {
    // Best effort: if the Unsafe singleton cannot be obtained reflectively for any reason, the
    // handle stays null and all array base offsets fall back to 0.
    Unsafe instance = null;
    try {
      Field singleton = Unsafe.class.getDeclaredField("theUnsafe");
      singleton.setAccessible(true);
      instance = (Unsafe) singleton.get(null);
    } catch (Throwable cause) {
      instance = null;
    }
    _UNSAFE = instance;

    boolean available = _UNSAFE != null;
    BYTE_ARRAY_OFFSET = available ? _UNSAFE.arrayBaseOffset(byte[].class) : 0;
    INT_ARRAY_OFFSET = available ? _UNSAFE.arrayBaseOffset(int[].class) : 0;
    LONG_ARRAY_OFFSET = available ? _UNSAFE.arrayBaseOffset(long[].class) : 0;
    DOUBLE_ARRAY_OFFSET = available ? _UNSAFE.arrayBaseOffset(double[].class) : 0;
  }

  // ----- Primitive accessors: each one delegates directly to the same-named Unsafe method -----

  public static boolean getBoolean(Object object, long offset) {
    return _UNSAFE.getBoolean(object, offset);
  }

  public static void putBoolean(Object object, long offset, boolean value) {
    _UNSAFE.putBoolean(object, offset, value);
  }

  public static byte getByte(Object object, long offset) {
    return _UNSAFE.getByte(object, offset);
  }

  public static void putByte(Object object, long offset, byte value) {
    _UNSAFE.putByte(object, offset, value);
  }

  public static short getShort(Object object, long offset) {
    return _UNSAFE.getShort(object, offset);
  }

  public static void putShort(Object object, long offset, short value) {
    _UNSAFE.putShort(object, offset, value);
  }

  public static int getInt(Object object, long offset) {
    return _UNSAFE.getInt(object, offset);
  }

  public static void putInt(Object object, long offset, int value) {
    _UNSAFE.putInt(object, offset, value);
  }

  public static long getLong(Object object, long offset) {
    return _UNSAFE.getLong(object, offset);
  }

  public static void putLong(Object object, long offset, long value) {
    _UNSAFE.putLong(object, offset, value);
  }

  public static float getFloat(Object object, long offset) {
    return _UNSAFE.getFloat(object, offset);
  }

  public static void putFloat(Object object, long offset, float value) {
    _UNSAFE.putFloat(object, offset, value);
  }

  public static double getDouble(Object object, long offset) {
    return _UNSAFE.getDouble(object, offset);
  }

  public static void putDouble(Object object, long offset, double value) {
    _UNSAFE.putDouble(object, offset, value);
  }

  public static Object getObjectVolatile(Object object, long offset) {
    return _UNSAFE.getObjectVolatile(object, offset);
  }

  public static void putObjectVolatile(Object object, long offset, Object value) {
    _UNSAFE.putObjectVolatile(object, offset, value);
  }

  // ----- Raw memory management -----

  public static long allocateMemory(long size) {
    return _UNSAFE.allocateMemory(size);
  }

  public static void freeMemory(long address) {
    _UNSAFE.freeMemory(address);
  }

  /**
   * Copies {@code length} bytes from {@code src} at {@code srcOffset} to {@code dst} at
   * {@code dstOffset}, in chunks of at most {@code UNSAFE_COPY_THRESHOLD} bytes.
   *
   * Chunks are copied front-to-back when the destination offset is below the source offset, and
   * back-to-front otherwise; this is necessary in case src and dst overlap.
   */
  public static void copyMemory(
      Object src, long srcOffset, Object dst, long dstOffset, long length) {
    long remaining = length;
    if (dstOffset < srcOffset) {
      // Forward copy: cursors start at the beginning and advance after each chunk.
      long from = srcOffset;
      long to = dstOffset;
      while (remaining > 0) {
        long chunk = Math.min(remaining, UNSAFE_COPY_THRESHOLD);
        _UNSAFE.copyMemory(src, from, dst, to, chunk);
        from += chunk;
        to += chunk;
        remaining -= chunk;
      }
    } else {
      // Backward copy: cursors start one past the end and retreat before each chunk.
      long from = srcOffset + length;
      long to = dstOffset + length;
      while (remaining > 0) {
        long chunk = Math.min(remaining, UNSAFE_COPY_THRESHOLD);
        from -= chunk;
        to -= chunk;
        _UNSAFE.copyMemory(src, from, dst, to, chunk);
        remaining -= chunk;
      }
    }
  }

  /**
   * Raises an exception bypassing compiler checks for checked exceptions.
   */
  public static void throwException(Throwable t) {
    _UNSAFE.throwException(t);
  }
}
|
|
@ -0,0 +1,112 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.util.sketch
|
||||
|
||||
import scala.reflect.ClassTag
|
||||
import scala.util.Random
|
||||
|
||||
import org.scalatest.FunSuite // scalastyle:ignore funsuite
|
||||
|
||||
/**
 * Tests [[CountMinSketch]] accuracy and in-place merging for every supported item type.
 */
class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
  private val epsOfTotalCount = 0.0001

  private val confidence = 0.99

  // Seed of the sketches' hash functions.
  private val seed = 42

  // Fixed seed of the input generators, so that every run sees the same data and a failure can
  // be reproduced.
  private val inputSeed = 31

  /**
   * Registers a test that adds a random sample of generated items to a sketch and checks that
   * the fraction of items whose estimate stays within the error bound exceeds `confidence`.
   *
   * @param typeName name of the item type, used in the test name
   * @param itemGenerator produces one item from the given random source
   */
  def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
    test(s"accuracy - $typeName") {
      val r = new Random(inputSeed)

      val numAllItems = 1000000
      val allItems = Array.fill(numAllItems)(itemGenerator(r))

      val numSamples = numAllItems / 10
      val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems))

      val exactFreq = {
        val sampledItems = sampledItemIndices.map(allItems)
        sampledItems.groupBy(identity).mapValues(_.length.toLong)
      }

      val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
      sampledItemIndices.foreach(i => sketch.add(allItems(i)))

      val probCorrect = {
        // NOTE(review): the over-estimate is divided by numAllItems, although only numSamples
        // items were added to the sketch -- confirm whether sketch.totalCount() was intended.
        val numErrors = allItems.count { item =>
          val exact = exactFreq.getOrElse(item, 0L)
          val ratio = (sketch.estimateCount(item) - exact).toDouble / numAllItems
          ratio > epsOfTotalCount
        }

        1D - numErrors.toDouble / numAllItems
      }

      assert(
        probCorrect > confidence,
        s"Confidence not reached: required $confidence, reached $probCorrect"
      )
    }
  }

  /**
   * Registers a test that builds several sketches over disjoint batches of generated items,
   * merges them in place, and checks that the merged sketch returns the same estimates as one
   * sketch fed with all batches.
   *
   * @param typeName name of the item type, used in the test name
   * @param itemGenerator produces one item from the given random source
   */
  def testMergeInPlace[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
    test(s"mergeInPlace - $typeName") {
      val r = new Random(inputSeed)
      val numToMerge = 5
      val numItemsPerSketch = 100000
      val perSketchItems = Array.fill(numToMerge, numItemsPerSketch) {
        itemGenerator(r)
      }

      val sketches = perSketchItems.map { items =>
        val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
        items.foreach(sketch.add)
        sketch
      }

      val mergedSketch = sketches.reduce(_ mergeInPlace _)

      val expectedSketch = {
        val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
        perSketchItems.foreach(_.foreach(sketch.add))
        sketch
      }

      perSketchItems.foreach {
        _.foreach { item =>
          assert(mergedSketch.estimateCount(item) === expectedSketch.estimateCount(item))
        }
      }
    }
  }

  /** Registers both the accuracy and the merge test for one item type. */
  def testItemType[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
    testAccuracy[T](typeName)(itemGenerator)
    testMergeInPlace[T](typeName)(itemGenerator)
  }

  testItemType[Byte]("Byte") { _.nextInt().toByte }

  testItemType[Short]("Short") { _.nextInt().toShort }

  testItemType[Int]("Int") { _.nextInt() }

  testItemType[Long]("Long") { _.nextLong() }

  testItemType[String]("String") { r => r.nextString(r.nextInt(20)) }
}
|
|
@ -113,6 +113,18 @@ hive_thriftserver = Module(
|
|||
)
|
||||
|
||||
|
||||
sketch = Module(
|
||||
name="sketch",
|
||||
dependencies=[],
|
||||
source_file_regexes=[
|
||||
"common/sketch/",
|
||||
],
|
||||
sbt_test_goals=[
|
||||
"sketch/test"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
graphx = Module(
|
||||
name="graphx",
|
||||
dependencies=[],
|
||||
|
|
1
pom.xml
1
pom.xml
|
@ -86,6 +86,7 @@
|
|||
</mailingLists>
|
||||
|
||||
<modules>
|
||||
<module>common/sketch</module>
|
||||
<module>tags</module>
|
||||
<module>core</module>
|
||||
<module>graphx</module>
|
||||
|
|
|
@ -34,13 +34,24 @@ object BuildCommons {
|
|||
|
||||
private val buildLocation = file(".").getAbsoluteFile.getParentFile
|
||||
|
||||
val allProjects@Seq(catalyst, core, graphx, hive, hiveThriftServer, mllib, repl,
|
||||
sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingAkka, streamingKafka,
|
||||
streamingMqtt, streamingTwitter, streamingZeromq, launcher, unsafe, testTags) =
|
||||
Seq("catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl",
|
||||
"sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink",
|
||||
"streaming-flume", "streaming-akka", "streaming-kafka", "streaming-mqtt", "streaming-twitter",
|
||||
"streaming-zeromq", "launcher", "unsafe", "test-tags").map(ProjectRef(buildLocation, _))
|
||||
val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer) = Seq(
|
||||
"catalyst", "sql", "hive", "hive-thriftserver"
|
||||
).map(ProjectRef(buildLocation, _))
|
||||
|
||||
val streamingProjects@Seq(
|
||||
streaming, streamingFlumeSink, streamingFlume, streamingAkka, streamingKafka, streamingMqtt,
|
||||
streamingTwitter, streamingZeromq
|
||||
) = Seq(
|
||||
"streaming", "streaming-flume-sink", "streaming-flume", "streaming-akka", "streaming-kafka",
|
||||
"streaming-mqtt", "streaming-twitter", "streaming-zeromq"
|
||||
).map(ProjectRef(buildLocation, _))
|
||||
|
||||
val allProjects@Seq(
|
||||
core, graphx, mllib, repl, networkCommon, networkShuffle, launcher, unsafe, testTags, sketch, _*
|
||||
) = Seq(
|
||||
"core", "graphx", "mllib", "repl", "network-common", "network-shuffle", "launcher", "unsafe",
|
||||
"test-tags", "sketch"
|
||||
).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects
|
||||
|
||||
val optionallyEnabledProjects@Seq(yarn, java8Tests, sparkGangliaLgpl,
|
||||
streamingKinesisAsl, dockerIntegrationTests) =
|
||||
|
@ -232,11 +243,15 @@ object SparkBuild extends PomBuild {
|
|||
/* Enable tests settings for all projects except examples, assembly and tools */
|
||||
(allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings))
|
||||
|
||||
// TODO: remove streamingAkka from this list after 2.0.0
|
||||
allProjects.filterNot(x => Seq(spark, hive, hiveThriftServer, catalyst, repl,
|
||||
networkCommon, networkShuffle, networkYarn, unsafe, streamingAkka, testTags).contains(x)).foreach {
|
||||
x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)
|
||||
}
|
||||
// TODO: remove streamingAkka and sketch from this list after 2.0.0
|
||||
allProjects.filterNot { x =>
|
||||
Seq(
|
||||
spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn,
|
||||
unsafe, streamingAkka, testTags, sketch
|
||||
).contains(x)
|
||||
}.foreach { x =>
|
||||
enable(MimaBuild.mimaSettings(sparkHome, x))(x)
|
||||
}
|
||||
|
||||
/* Unsafe settings */
|
||||
enable(Unsafe.settings)(unsafe)
|
||||
|
|
Loading…
Reference in a new issue