[SPARK-12933][SQL] Initial implementation of Count-Min sketch

This PR adds an initial implementation of Count-Min sketch, contained in a new module, spark-sketch, under `common/sketch`. The implementation is based on the [`CountMinSketch` class in stream-lib][1].

As required by the [design doc][2], spark-sketch should have no external dependencies.
Two classes, `Murmur3_x86_32` and `Platform`, are copied to spark-sketch from spark-unsafe to provide hashing facilities. They will also be used by the upcoming Bloom filter implementation.

The following features will be added in follow-up PRs:

- Serialization support
- DataFrame API integration

[1]: aac6b4d23a/src/main/java/com/clearspring/analytics/stream/frequency/CountMinSketch.java
[2]: https://issues.apache.org/jira/secure/attachment/12782378/BloomFilterandCount-MinSketchinSpark2.0.pdf

Author: Cheng Lian <lian@databricks.com>

Closes #10851 from liancheng/count-min-sketch.
Committed by Reynold Xin on 2016-01-23 00:34:55 -08:00
commit 1c690ddafa (parent 5af5a02160)
9 changed files with 892 additions and 12 deletions

common/sketch/pom.xml (new file, 42 lines)

@@ -0,0 +1,42 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ Licensed to the Apache Software Foundation (ASF) under one or more
~ contributor license agreements. See the NOTICE file distributed with
~ this work for additional information regarding copyright ownership.
~ The ASF licenses this file to You under the Apache License, Version 2.0
~ (the "License"); you may not use this file except in compliance with
~ the License. You may obtain a copy of the License at
~
~ http://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS,
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.10</artifactId>
<version>2.0.0-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
</parent>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sketch_2.10</artifactId>
<packaging>jar</packaging>
<name>Spark Project Sketch</name>
<url>http://spark.apache.org/</url>
<properties>
<sbt.project.name>sketch</sbt.project.name>
</properties>
<build>
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
<testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
</build>
</project>
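
With this POM in place, the module can be built and tested on its own once it is registered in the root pom.xml (done further below), e.g. via build/mvn -pl common/sketch package. (Invocation shown for illustration; the exact command depends on the local setup.)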

common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java (new file, 132 lines)

@@ -0,0 +1,132 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util.sketch;

import java.io.InputStream;
import java.io.OutputStream;

/**
* A Count-Min sketch is a probabilistic data structure used for summarizing streams of data in
* sub-linear space. Currently, supported data types include:
* <ul>
* <li>{@link Byte}</li>
* <li>{@link Short}</li>
* <li>{@link Integer}</li>
* <li>{@link Long}</li>
* <li>{@link String}</li>
* </ul>
* Each {@link CountMinSketch} is initialized with a random seed and a pair of parameters:
* <ol>
* <li>relative error (or {@code eps}), and</li>
* <li>confidence (or {@code delta}).</li>
* </ol>
* Suppose you want to estimate the number of times an element {@code x} has appeared in a data
* stream so far. With probability {@code delta}, the estimate of this frequency is within the
range {@code true frequency <= estimate <= true frequency + eps * N}, where {@code N} is the
* total count of items that have appeared in the data stream so far.
*
* Under the covers, a {@link CountMinSketch} is essentially a two-dimensional {@code long} array
* with depth {@code d} and width {@code w}, where
* <ul>
* <li>{@code w = ceil(2 / eps)}</li>
* <li>{@code d = ceil(-log(1 - confidence) / log(2))}</li>
* </ul>
*
* See http://www.eecs.harvard.edu/~michaelm/CS222/countmin.pdf for technical details,
* including proofs of the estimates and error bounds used in this implementation.
*
* This implementation is largely based on the {@code CountMinSketch} class from stream-lib.
*/
public abstract class CountMinSketch {
/**
* Returns the relative error (or {@code eps}) of this {@link CountMinSketch}.
*/
public abstract double relativeError();
/**
* Returns the confidence (or {@code delta}) of this {@link CountMinSketch}.
*/
public abstract double confidence();
/**
* Depth of this {@link CountMinSketch}.
*/
public abstract int depth();
/**
* Width of this {@link CountMinSketch}.
*/
public abstract int width();
/**
* Total count of items added to this {@link CountMinSketch} so far.
*/
public abstract long totalCount();
/**
* Increments {@code item}'s count by one.
*/
public abstract void add(Object item);
/**
* Increments {@code item}'s count by {@code count}.
*/
public abstract void add(Object item, long count);
/**
* Returns the estimated frequency of {@code item}.
*/
public abstract long estimateCount(Object item);
/**
* Merges another {@link CountMinSketch} with this one in place.
*
* Note that only Count-Min sketches with the same {@code depth}, {@code width}, and random seed
* can be merged.
*/
public abstract CountMinSketch mergeInPlace(CountMinSketch other);
/**
* Writes out this {@link CountMinSketch} to an output stream in binary format.
*/
public abstract void writeTo(OutputStream out);
/**
* Reads in a {@link CountMinSketch} from an input stream.
*/
public static CountMinSketch readFrom(InputStream in) {
throw new UnsupportedOperationException("Not implemented yet");
}
/**
* Creates a {@link CountMinSketch} with the given {@code depth}, {@code width}, and random
* {@code seed}.
*/
public static CountMinSketch create(int depth, int width, int seed) {
return new CountMinSketchImpl(depth, width, seed);
}
/**
* Creates a {@link CountMinSketch} with the given relative error ({@code eps}), {@code confidence},
* and random {@code seed}.
*/
public static CountMinSketch create(double eps, double confidence, int seed) {
return new CountMinSketchImpl(eps, confidence, seed);
}
}
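
For illustration, a minimal usage sketch of the public API above (not part of this patch; the demo class and values are made up). With eps = 0.001 and confidence = 0.99, create(eps, confidence, seed) yields width = ceil(2 / 0.001) = 2000 and depth = ceil(-log(0.01) / log(2)) = 7.

import org.apache.spark.util.sketch.CountMinSketch;

public class CountMinSketchExample {
  public static void main(String[] args) {
    // 0.1% relative error, 99% confidence, fixed seed for reproducibility.
    CountMinSketch sketch = CountMinSketch.create(0.001, 0.99, 42);

    sketch.add("spark");       // increments the count of "spark" by 1
    sketch.add("spark", 10L);  // increments it by an explicit count
    sketch.add(12345L);        // integral types are supported as well

    // Count-Min never under-estimates:
    // true frequency <= estimate <= true frequency + eps * totalCount.
    long estimate = sketch.estimateCount("spark");  // >= 11
    System.out.println(estimate + " out of " + sketch.totalCount() + " items");
  }
}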

common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java (new file, 268 lines)

@@ -0,0 +1,268 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util.sketch;

import java.io.OutputStream;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import java.util.Random;

class CountMinSketchImpl extends CountMinSketch {
public static final long PRIME_MODULUS = (1L << 31) - 1;
private int depth;
private int width;
private long[][] table;
private long[] hashA;
private long totalCount;
private double eps;
private double confidence;
public CountMinSketchImpl(int depth, int width, int seed) {
this.depth = depth;
this.width = width;
this.eps = 2.0 / width;
this.confidence = 1 - 1 / Math.pow(2, depth);
initTablesWith(depth, width, seed);
}
public CountMinSketchImpl(double eps, double confidence, int seed) {
// 2/w = eps ; w = 2/eps
// 1/2^depth <= 1-confidence ; depth >= -log2 (1-confidence)
this.eps = eps;
this.confidence = confidence;
this.width = (int) Math.ceil(2 / eps);
this.depth = (int) Math.ceil(-Math.log(1 - confidence) / Math.log(2));
initTablesWith(depth, width, seed);
}
private void initTablesWith(int depth, int width, int seed) {
this.table = new long[depth][width];
this.hashA = new long[depth];
Random r = new Random(seed);
// We're using linear hash functions of the form (a*x+b) mod p,
// where a and b are chosen independently for each hash function.
// However, we can set b = 0, since all it does is shift the results
// without compromising their uniformity or independence from the
// other hashes.
for (int i = 0; i < depth; ++i) {
hashA[i] = r.nextInt(Integer.MAX_VALUE);
}
}
@Override
public double relativeError() {
return eps;
}
@Override
public double confidence() {
return confidence;
}
@Override
public int depth() {
return depth;
}
@Override
public int width() {
return width;
}
@Override
public long totalCount() {
return totalCount;
}
@Override
public void add(Object item) {
add(item, 1);
}
@Override
public void add(Object item, long count) {
if (item instanceof String) {
addString((String) item, count);
} else {
long longValue;
if (item instanceof Long) {
longValue = (Long) item;
} else if (item instanceof Integer) {
longValue = ((Integer) item).longValue();
} else if (item instanceof Short) {
longValue = ((Short) item).longValue();
} else if (item instanceof Byte) {
longValue = ((Byte) item).longValue();
} else {
throw new IllegalArgumentException(
"Support for " + item.getClass().getName() + " not implemented"
);
}
addLong(longValue, count);
}
}
private void addString(String item, long count) {
if (count < 0) {
throw new IllegalArgumentException("Negative increments not implemented");
}
int[] buckets = getHashBuckets(item, depth, width);
for (int i = 0; i < depth; ++i) {
table[i][buckets[i]] += count;
}
totalCount += count;
}
private void addLong(long item, long count) {
if (count < 0) {
throw new IllegalArgumentException("Negative increments not implemented");
}
for (int i = 0; i < depth; ++i) {
table[i][hash(item, i)] += count;
}
totalCount += count;
}
private int hash(long item, int count) {
long hash = hashA[count] * item;
// A super fast way of computing x mod 2^p-1
// See http://www.cs.princeton.edu/courses/archive/fall09/cos521/Handouts/universalclasses.pdf
// page 149, right after Proposition 7.
hash += hash >> 32;
hash &= PRIME_MODULUS;
// Doing "%" after (int) conversion is ~2x faster than %'ing longs.
return ((int) hash) % width;
}
private static int[] getHashBuckets(String key, int hashCount, int max) {
byte[] b;
try {
b = key.getBytes("UTF-8");
} catch (UnsupportedEncodingException e) {
throw new RuntimeException(e);
}
return getHashBuckets(b, hashCount, max);
}
private static int[] getHashBuckets(byte[] b, int hashCount, int max) {
int[] result = new int[hashCount];
int hash1 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, 0);
int hash2 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, hash1);
for (int i = 0; i < hashCount; i++) {
result[i] = Math.abs((hash1 + i * hash2) % max);
}
return result;
}
@Override
public long estimateCount(Object item) {
if (item instanceof String) {
return estimateCountForStringItem((String) item);
} else {
long longValue;
if (item instanceof Long) {
longValue = (Long) item;
} else if (item instanceof Integer) {
longValue = ((Integer) item).longValue();
} else if (item instanceof Short) {
longValue = ((Short) item).longValue();
} else if (item instanceof Byte) {
longValue = ((Byte) item).longValue();
} else {
throw new IllegalArgumentException(
"Support for " + item.getClass().getName() + " not implemented"
);
}
return estimateCountForLongItem(longValue);
}
}
private long estimateCountForLongItem(long item) {
long res = Long.MAX_VALUE;
for (int i = 0; i < depth; ++i) {
res = Math.min(res, table[i][hash(item, i)]);
}
return res;
}
private long estimateCountForStringItem(String item) {
long res = Long.MAX_VALUE;
int[] buckets = getHashBuckets(item, depth, width);
for (int i = 0; i < depth; ++i) {
res = Math.min(res, table[i][buckets[i]]);
}
return res;
}
@Override
public CountMinSketch mergeInPlace(CountMinSketch other) {
if (other == null) {
throw new CMSMergeException("Cannot merge null estimator");
}
if (!(other instanceof CountMinSketchImpl)) {
throw new CMSMergeException("Cannot merge estimator of class " + other.getClass().getName());
}
CountMinSketchImpl that = (CountMinSketchImpl) other;
if (this.depth != that.depth) {
throw new CMSMergeException("Cannot merge estimators of different depth");
}
if (this.width != that.width) {
throw new CMSMergeException("Cannot merge estimators of different width");
}
if (!Arrays.equals(this.hashA, that.hashA)) {
throw new CMSMergeException("Cannot merge estimators of different seed");
}
for (int i = 0; i < this.table.length; ++i) {
for (int j = 0; j < this.table[i].length; ++j) {
this.table[i][j] = this.table[i][j] + that.table[i][j];
}
}
this.totalCount += that.totalCount;
return this;
}
@Override
public void writeTo(OutputStream out) {
throw new UnsupportedOperationException("Not implemented yet");
}
protected static class CMSMergeException extends RuntimeException {
public CMSMergeException(String message) {
super(message);
}
}
}
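
The string path above derives all depth bucket indices from only two Murmur3 hashes, using the double-hashing scheme h_i = (hash1 + i * hash2) mod width described by Kirsch and Mitzenmacher. A standalone sketch of that derivation, with made-up hash values, for illustration only:

public class DoubleHashingDemo {
  public static void main(String[] args) {
    // Stand-ins for the two Murmur3 hashes of some key.
    int hash1 = 0x9747b28c;
    int hash2 = 0x1b873593;
    int depth = 7;
    int width = 2000;

    // Row i of the table gets bucket (hash1 + i * hash2) mod width;
    // Math.abs guards against negative remainders from int overflow.
    for (int i = 0; i < depth; i++) {
      int bucket = Math.abs((hash1 + i * hash2) % width);
      System.out.println("row " + i + " -> bucket " + bucket);
    }
  }
}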

common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java (new file, 126 lines)

@@ -0,0 +1,126 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util.sketch;

/**
 * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction.
 */
// This class is duplicated from `org.apache.spark.unsafe.hash.Murmur3_x86_32` to make sure
// spark-sketch has no external dependencies.
final class Murmur3_x86_32 {
private static final int C1 = 0xcc9e2d51;
private static final int C2 = 0x1b873593;
private final int seed;
public Murmur3_x86_32(int seed) {
this.seed = seed;
}
@Override
public String toString() {
return "Murmur3_32(seed=" + seed + ")";
}
public int hashInt(int input) {
return hashInt(input, seed);
}
public static int hashInt(int input, int seed) {
int k1 = mixK1(input);
int h1 = mixH1(seed, k1);
return fmix(h1, 4);
}
public int hashUnsafeWords(Object base, long offset, int lengthInBytes) {
return hashUnsafeWords(base, offset, lengthInBytes, seed);
}
public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) {
// This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method.
assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)";
int h1 = hashBytesByInt(base, offset, lengthInBytes, seed);
return fmix(h1, lengthInBytes);
}
public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) {
assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
int lengthAligned = lengthInBytes - lengthInBytes % 4;
int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
for (int i = lengthAligned; i < lengthInBytes; i++) {
int halfWord = Platform.getByte(base, offset + i);
int k1 = mixK1(halfWord);
h1 = mixH1(h1, k1);
}
return fmix(h1, lengthInBytes);
}
private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) {
assert (lengthInBytes % 4 == 0);
int h1 = seed;
for (int i = 0; i < lengthInBytes; i += 4) {
int halfWord = Platform.getInt(base, offset + i);
int k1 = mixK1(halfWord);
h1 = mixH1(h1, k1);
}
return h1;
}
public int hashLong(long input) {
return hashLong(input, seed);
}
public static int hashLong(long input, int seed) {
int low = (int) input;
int high = (int) (input >>> 32);
int k1 = mixK1(low);
int h1 = mixH1(seed, k1);
k1 = mixK1(high);
h1 = mixH1(h1, k1);
return fmix(h1, 8);
}
private static int mixK1(int k1) {
k1 *= C1;
k1 = Integer.rotateLeft(k1, 15);
k1 *= C2;
return k1;
}
private static int mixH1(int h1, int k1) {
h1 ^= k1;
h1 = Integer.rotateLeft(h1, 13);
h1 = h1 * 5 + 0xe6546b64;
return h1;
}
// Finalization mix - force all bits of a hash block to avalanche
private static int fmix(int h1, int length) {
h1 ^= length;
h1 ^= h1 >>> 16;
h1 *= 0x85ebca6b;
h1 ^= h1 >>> 13;
h1 *= 0xc2b2ae35;
h1 ^= h1 >>> 16;
return h1;
}
}
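
Because the class is package-private, callers must live in org.apache.spark.util.sketch. A minimal sketch of exercising the hasher (illustrative only; the demo class is made up):

package org.apache.spark.util.sketch;

import java.nio.charset.StandardCharsets;

public class Murmur3Demo {
  public static void main(String[] args) {
    int seed = 42;

    // Fixed-width primitives are hashed directly.
    System.out.println(Murmur3_x86_32.hashInt(1, seed));
    System.out.println(Murmur3_x86_32.hashLong(1L, seed));

    // Arbitrary byte sequences go through the Platform array offset.
    byte[] bytes = "spark".getBytes(StandardCharsets.UTF_8);
    System.out.println(
      Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, seed));
  }
}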

common/sketch/src/main/java/org/apache/spark/util/sketch/Platform.java (new file, 172 lines)

@@ -0,0 +1,172 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util.sketch;

import java.lang.reflect.Field;

import sun.misc.Unsafe;

// This class is duplicated from `org.apache.spark.unsafe.Platform` to make sure spark-sketch has no
// external dependencies.
final class Platform {
private static final Unsafe _UNSAFE;
public static final int BYTE_ARRAY_OFFSET;
public static final int INT_ARRAY_OFFSET;
public static final int LONG_ARRAY_OFFSET;
public static final int DOUBLE_ARRAY_OFFSET;
public static int getInt(Object object, long offset) {
return _UNSAFE.getInt(object, offset);
}
public static void putInt(Object object, long offset, int value) {
_UNSAFE.putInt(object, offset, value);
}
public static boolean getBoolean(Object object, long offset) {
return _UNSAFE.getBoolean(object, offset);
}
public static void putBoolean(Object object, long offset, boolean value) {
_UNSAFE.putBoolean(object, offset, value);
}
public static byte getByte(Object object, long offset) {
return _UNSAFE.getByte(object, offset);
}
public static void putByte(Object object, long offset, byte value) {
_UNSAFE.putByte(object, offset, value);
}
public static short getShort(Object object, long offset) {
return _UNSAFE.getShort(object, offset);
}
public static void putShort(Object object, long offset, short value) {
_UNSAFE.putShort(object, offset, value);
}
public static long getLong(Object object, long offset) {
return _UNSAFE.getLong(object, offset);
}
public static void putLong(Object object, long offset, long value) {
_UNSAFE.putLong(object, offset, value);
}
public static float getFloat(Object object, long offset) {
return _UNSAFE.getFloat(object, offset);
}
public static void putFloat(Object object, long offset, float value) {
_UNSAFE.putFloat(object, offset, value);
}
public static double getDouble(Object object, long offset) {
return _UNSAFE.getDouble(object, offset);
}
public static void putDouble(Object object, long offset, double value) {
_UNSAFE.putDouble(object, offset, value);
}
public static Object getObjectVolatile(Object object, long offset) {
return _UNSAFE.getObjectVolatile(object, offset);
}
public static void putObjectVolatile(Object object, long offset, Object value) {
_UNSAFE.putObjectVolatile(object, offset, value);
}
public static long allocateMemory(long size) {
return _UNSAFE.allocateMemory(size);
}
public static void freeMemory(long address) {
_UNSAFE.freeMemory(address);
}
public static void copyMemory(
Object src, long srcOffset, Object dst, long dstOffset, long length) {
// Check if dstOffset is before or after srcOffset to determine if we should copy
// forward or backwards. This is necessary in case src and dst overlap.
if (dstOffset < srcOffset) {
while (length > 0) {
long size = Math.min(length, UNSAFE_COPY_THRESHOLD);
_UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size);
length -= size;
srcOffset += size;
dstOffset += size;
}
} else {
srcOffset += length;
dstOffset += length;
while (length > 0) {
long size = Math.min(length, UNSAFE_COPY_THRESHOLD);
srcOffset -= size;
dstOffset -= size;
_UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size);
length -= size;
}
}
}
/**
* Raises an exception bypassing compiler checks for checked exceptions.
*/
public static void throwException(Throwable t) {
_UNSAFE.throwException(t);
}
/**
* Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to
* allow safepoint polling during a large copy.
*/
private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L;
static {
sun.misc.Unsafe unsafe;
try {
Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe");
unsafeField.setAccessible(true);
unsafe = (sun.misc.Unsafe) unsafeField.get(null);
} catch (Throwable cause) {
unsafe = null;
}
_UNSAFE = unsafe;
if (_UNSAFE != null) {
BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class);
INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class);
LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class);
DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class);
} else {
BYTE_ARRAY_OFFSET = 0;
INT_ARRAY_OFFSET = 0;
LONG_ARRAY_OFFSET = 0;
DOUBLE_ARRAY_OFFSET = 0;
}
}
}
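
A minimal sketch of the array-to-array copy path (illustrative only; PlatformCopyDemo is made up and, like the hasher demo, must sit in the same package because Platform is package-private):

package org.apache.spark.util.sketch;

import java.util.Arrays;

public class PlatformCopyDemo {
  public static void main(String[] args) {
    byte[] src = {1, 2, 3, 4};
    byte[] dst = new byte[4];

    // Copies four bytes starting at each array's base offset; overlapping
    // ranges are handled by the forward/backward logic in copyMemory.
    Platform.copyMemory(
      src, Platform.BYTE_ARRAY_OFFSET, dst, Platform.BYTE_ARRAY_OFFSET, src.length);

    System.out.println(Arrays.toString(dst));  // [1, 2, 3, 4]
  }
}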

common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala (new file, 112 lines)

@@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util.sketch

import scala.reflect.ClassTag
import scala.util.Random

import org.scalatest.FunSuite // scalastyle:ignore funsuite

class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
private val epsOfTotalCount = 0.0001
private val confidence = 0.99
private val seed = 42
def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
test(s"accuracy - $typeName") {
val r = new Random()
val numAllItems = 1000000
val allItems = Array.fill(numAllItems)(itemGenerator(r))
val numSamples = numAllItems / 10
val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems))
val exactFreq = {
val sampledItems = sampledItemIndices.map(allItems)
sampledItems.groupBy(identity).mapValues(_.length.toLong)
}
val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
sampledItemIndices.foreach(i => sketch.add(allItems(i)))
val probCorrect = {
val numErrors = allItems.map { item =>
val count = exactFreq.getOrElse(item, 0L)
val ratio = (sketch.estimateCount(item) - count).toDouble / numAllItems
if (ratio > epsOfTotalCount) 1 else 0
}.sum
1D - numErrors.toDouble / numAllItems
}
assert(
probCorrect > confidence,
s"Confidence not reached: required $confidence, reached $probCorrect"
)
}
}
def testMergeInPlace[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
test(s"mergeInPlace - $typeName") {
val r = new Random()
val numToMerge = 5
val numItemsPerSketch = 100000
val perSketchItems = Array.fill(numToMerge, numItemsPerSketch) {
itemGenerator(r)
}
val sketches = perSketchItems.map { items =>
val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
items.foreach(sketch.add)
sketch
}
val mergedSketch = sketches.reduce(_ mergeInPlace _)
val expectedSketch = {
val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
perSketchItems.foreach(_.foreach(sketch.add))
sketch
}
perSketchItems.foreach {
_.foreach { item =>
assert(mergedSketch.estimateCount(item) === expectedSketch.estimateCount(item))
}
}
}
}
def testItemType[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
testAccuracy[T](typeName)(itemGenerator)
testMergeInPlace[T](typeName)(itemGenerator)
}
testItemType[Byte]("Byte") { _.nextInt().toByte }
testItemType[Short]("Short") { _.nextInt().toShort }
testItemType[Int]("Int") { _.nextInt() }
testItemType[Long]("Long") { _.nextLong() }
testItemType[String]("String") { r => r.nextString(r.nextInt(20)) }
}

dev/sparktestsupport/modules.py

@@ -113,6 +113,18 @@ hive_thriftserver = Module(
 )

+sketch = Module(
+    name="sketch",
+    dependencies=[],
+    source_file_regexes=[
+        "common/sketch/",
+    ],
+    sbt_test_goals=[
+        "sketch/test"
+    ]
+)
+
 graphx = Module(
     name="graphx",
     dependencies=[],
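
With the module registered here, its tests can be selected directly, e.g. with the sbt goal listed above (build/sbt sketch/test), or picked up automatically by dev/run-tests when files under common/sketch/ change. (Invocations shown for illustration.)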

pom.xml

@@ -86,6 +86,7 @@
   </mailingLists>
   <modules>
+    <module>common/sketch</module>
     <module>tags</module>
     <module>core</module>
     <module>graphx</module>

project/SparkBuild.scala

@@ -34,13 +34,24 @@ object BuildCommons {
   private val buildLocation = file(".").getAbsoluteFile.getParentFile

-  val allProjects@Seq(catalyst, core, graphx, hive, hiveThriftServer, mllib, repl,
-    sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingAkka, streamingKafka,
-    streamingMqtt, streamingTwitter, streamingZeromq, launcher, unsafe, testTags) =
-    Seq("catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl",
-      "sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink",
-      "streaming-flume", "streaming-akka", "streaming-kafka", "streaming-mqtt", "streaming-twitter",
-      "streaming-zeromq", "launcher", "unsafe", "test-tags").map(ProjectRef(buildLocation, _))
+  val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer) = Seq(
+    "catalyst", "sql", "hive", "hive-thriftserver"
+  ).map(ProjectRef(buildLocation, _))
+
+  val streamingProjects@Seq(
+    streaming, streamingFlumeSink, streamingFlume, streamingAkka, streamingKafka, streamingMqtt,
+    streamingTwitter, streamingZeromq
+  ) = Seq(
+    "streaming", "streaming-flume-sink", "streaming-flume", "streaming-akka", "streaming-kafka",
+    "streaming-mqtt", "streaming-twitter", "streaming-zeromq"
+  ).map(ProjectRef(buildLocation, _))
+
+  val allProjects@Seq(
+    core, graphx, mllib, repl, networkCommon, networkShuffle, launcher, unsafe, testTags, sketch, _*
+  ) = Seq(
+    "core", "graphx", "mllib", "repl", "network-common", "network-shuffle", "launcher", "unsafe",
+    "test-tags", "sketch"
+  ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects

   val optionallyEnabledProjects@Seq(yarn, java8Tests, sparkGangliaLgpl,
     streamingKinesisAsl, dockerIntegrationTests) =
@@ -232,11 +243,15 @@ object SparkBuild extends PomBuild {
   /* Enable tests settings for all projects except examples, assembly and tools */
   (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings))

-  // TODO: remove streamingAkka from this list after 2.0.0
-  allProjects.filterNot(x => Seq(spark, hive, hiveThriftServer, catalyst, repl,
-    networkCommon, networkShuffle, networkYarn, unsafe, streamingAkka, testTags).contains(x)).foreach {
-    x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)
-  }
+  // TODO: remove streamingAkka and sketch from this list after 2.0.0
+  allProjects.filterNot { x =>
+    Seq(
+      spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn,
+      unsafe, streamingAkka, testTags, sketch
+    ).contains(x)
+  }.foreach { x =>
+    enable(MimaBuild.mimaSettings(sparkHome, x))(x)
+  }

   /* Unsafe settings */
   enable(Unsafe.settings)(unsafe)