diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/InMemoryStore.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/InMemoryStore.java index 5ca4371285..6af45aec3c 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/InMemoryStore.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/InMemoryStore.java @@ -21,16 +21,18 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.Iterator; +import java.util.HashSet; import java.util.List; import java.util.NoSuchElementException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.function.BiConsumer; +import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; import com.google.common.base.Objects; import com.google.common.base.Preconditions; -import com.google.common.base.Throwables; import org.apache.spark.annotation.Private; @@ -43,7 +45,7 @@ import org.apache.spark.annotation.Private; public class InMemoryStore implements KVStore { private Object metadata; - private ConcurrentMap, InstanceList> data = new ConcurrentHashMap<>(); + private InMemoryLists inMemoryLists = new InMemoryLists(); @Override public T getMetadata(Class klass) { @@ -57,13 +59,13 @@ public class InMemoryStore implements KVStore { @Override public long count(Class type) { - InstanceList list = data.get(type); + InstanceList list = inMemoryLists.get(type); return list != null ? list.size() : 0; } @Override public long count(Class type, String index, Object indexedValue) throws Exception { - InstanceList list = data.get(type); + InstanceList list = inMemoryLists.get(type); int count = 0; Object comparable = asKey(indexedValue); KVTypeInfo.Accessor accessor = list.getIndexAccessor(index); @@ -77,29 +79,22 @@ public class InMemoryStore implements KVStore { @Override public T read(Class klass, Object naturalKey) { - InstanceList list = data.get(klass); - Object value = list != null ? list.get(naturalKey) : null; + InstanceList list = inMemoryLists.get(klass); + T value = list != null ? list.get(naturalKey) : null; if (value == null) { throw new NoSuchElementException(); } - return klass.cast(value); + return value; } @Override public void write(Object value) throws Exception { - InstanceList list = data.computeIfAbsent(value.getClass(), key -> { - try { - return new InstanceList(key); - } catch (Exception e) { - throw Throwables.propagate(e); - } - }); - list.put(value); + inMemoryLists.write(value); } @Override public void delete(Class type, Object naturalKey) { - InstanceList list = data.get(type); + InstanceList list = inMemoryLists.get(type); if (list != null) { list.delete(naturalKey); } @@ -107,15 +102,28 @@ public class InMemoryStore implements KVStore { @Override public KVStoreView view(Class type){ - InstanceList list = data.get(type); - return list != null ? list.view(type) - : new InMemoryView<>(type, Collections.emptyList(), null); + InstanceList list = inMemoryLists.get(type); + return list != null ? list.view() : emptyView(); } @Override public void close() { metadata = null; - data.clear(); + inMemoryLists.clear(); + } + + @Override + public boolean removeAllByIndexValues( + Class klass, + String index, + Collection indexValues) { + InstanceList list = inMemoryLists.get(klass); + + if (list != null) { + return list.countingRemoveAllByIndexValues(index, indexValues) > 0; + } else { + return false; + } } @SuppressWarnings("unchecked") @@ -126,64 +134,150 @@ public class InMemoryStore implements KVStore { return (Comparable) in; } - private static class InstanceList { + @SuppressWarnings("unchecked") + private static KVStoreView emptyView() { + return (InMemoryView) InMemoryView.EMPTY_VIEW; + } + + /** + * Encapsulates ConcurrentHashMap so that the typing in and out of the map strictly maps a + * class of type T to an InstanceList of type T. + */ + private static class InMemoryLists { + private final ConcurrentMap, InstanceList> data = new ConcurrentHashMap<>(); + + @SuppressWarnings("unchecked") + public InstanceList get(Class type) { + return (InstanceList) data.get(type); + } + + @SuppressWarnings("unchecked") + public void write(T value) throws Exception { + InstanceList list = + (InstanceList) data.computeIfAbsent(value.getClass(), InstanceList::new); + list.put(value); + } + + public void clear() { + data.clear(); + } + } + + private static class InstanceList { + + /** + * A BiConsumer to control multi-entity removal. We use this in a forEach rather than an + * iterator because there is a bug in jdk8 which affects remove() on all concurrent map + * iterators. https://bugs.openjdk.java.net/browse/JDK-8078645 + */ + private static class CountingRemoveIfForEach implements BiConsumer, T> { + private final ConcurrentMap, T> data; + private final Predicate filter; + + /** + * Keeps a count of the number of elements removed. This count is not currently surfaced + * to clients of KVStore as Java's generic removeAll() construct returns only a boolean, + * but I found it handy to have the count of elements removed while debugging; a count being + * no more complicated than a boolean, I've retained that behavior here, even though there + * is no current requirement. + */ + private int count = 0; + + CountingRemoveIfForEach( + ConcurrentMap, T> data, + Predicate filter) { + this.data = data; + this.filter = filter; + } + + @Override + public void accept(Comparable key, T value) { + if (filter.test(value)) { + if (data.remove(key, value)) { + count++; + } + } + } + + public int count() { return count; } + } private final KVTypeInfo ti; private final KVTypeInfo.Accessor naturalKey; - private final ConcurrentMap, Object> data; + private final ConcurrentMap, T> data; - private int size; - - private InstanceList(Class type) throws Exception { - this.ti = new KVTypeInfo(type); + private InstanceList(Class klass) { + this.ti = new KVTypeInfo(klass); this.naturalKey = ti.getAccessor(KVIndex.NATURAL_INDEX_NAME); this.data = new ConcurrentHashMap<>(); - this.size = 0; } KVTypeInfo.Accessor getIndexAccessor(String indexName) { return ti.getAccessor(indexName); } - public Object get(Object key) { + int countingRemoveAllByIndexValues(String index, Collection indexValues) { + Predicate filter = getPredicate(ti.getAccessor(index), indexValues); + CountingRemoveIfForEach callback = new CountingRemoveIfForEach<>(data, filter); + + data.forEach(callback); + return callback.count(); + } + + public T get(Object key) { return data.get(asKey(key)); } - public void put(Object value) throws Exception { - Preconditions.checkArgument(ti.type().equals(value.getClass()), - "Unexpected type: %s", value.getClass()); - if (data.put(asKey(naturalKey.get(value)), value) == null) { - size++; - } + public void put(T value) throws Exception { + data.put(asKey(naturalKey.get(value)), value); } public void delete(Object key) { - if (data.remove(asKey(key)) != null) { - size--; - } + data.remove(asKey(key)); } public int size() { - return size; + return data.size(); } - @SuppressWarnings("unchecked") - public InMemoryView view(Class type) { - Preconditions.checkArgument(ti.type().equals(type), "Unexpected type: %s", type); - Collection all = (Collection) data.values(); - return new InMemoryView<>(type, all, ti); + public InMemoryView view() { + return new InMemoryView<>(data.values(), ti); } + private static Predicate getPredicate( + KVTypeInfo.Accessor getter, + Collection values) { + if (Comparable.class.isAssignableFrom(getter.getType())) { + HashSet set = new HashSet<>(values); + + return (value) -> set.contains(indexValueForEntity(getter, value)); + } else { + HashSet set = new HashSet<>(values.size()); + for (Object key : values) { + set.add(asKey(key)); + } + return (value) -> set.contains(asKey(indexValueForEntity(getter, value))); + } + } + + private static Object indexValueForEntity(KVTypeInfo.Accessor getter, Object entity) { + try { + return getter.get(entity); + } catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } + } } private static class InMemoryView extends KVStoreView { + private static final InMemoryView EMPTY_VIEW = + new InMemoryView<>(Collections.emptyList(), null); private final Collection elements; private final KVTypeInfo ti; private final KVTypeInfo.Accessor natural; - InMemoryView(Class type, Collection elements, KVTypeInfo ti) { - super(type); + InMemoryView(Collection elements, KVTypeInfo ti) { this.elements = elements; this.ti = ti; this.natural = ti != null ? ti.getAccessor(KVIndex.NATURAL_INDEX_NAME) : null; @@ -195,34 +289,32 @@ public class InMemoryStore implements KVStore { return new InMemoryIterator<>(elements.iterator()); } - try { - KVTypeInfo.Accessor getter = index != null ? ti.getAccessor(index) : null; - int modifier = ascending ? 1 : -1; + KVTypeInfo.Accessor getter = index != null ? ti.getAccessor(index) : null; + int modifier = ascending ? 1 : -1; - final List sorted = copyElements(); - Collections.sort(sorted, (e1, e2) -> modifier * compare(e1, e2, getter)); - Stream stream = sorted.stream(); + final List sorted = copyElements(); + sorted.sort((e1, e2) -> modifier * compare(e1, e2, getter)); + Stream stream = sorted.stream(); - if (first != null) { - stream = stream.filter(e -> modifier * compare(e, getter, first) >= 0); - } - - if (last != null) { - stream = stream.filter(e -> modifier * compare(e, getter, last) <= 0); - } - - if (skip > 0) { - stream = stream.skip(skip); - } - - if (max < sorted.size()) { - stream = stream.limit((int) max); - } - - return new InMemoryIterator<>(stream.iterator()); - } catch (Exception e) { - throw Throwables.propagate(e); + if (first != null) { + Comparable firstKey = asKey(first); + stream = stream.filter(e -> modifier * compare(e, getter, firstKey) >= 0); } + + if (last != null) { + Comparable lastKey = asKey(last); + stream = stream.filter(e -> modifier * compare(e, getter, lastKey) <= 0); + } + + if (skip > 0) { + stream = stream.skip(skip); + } + + if (max < sorted.size()) { + stream = stream.limit((int) max); + } + + return new InMemoryIterator<>(stream.iterator()); } /** @@ -232,9 +324,10 @@ public class InMemoryStore implements KVStore { if (parent != null) { KVTypeInfo.Accessor parentGetter = ti.getParentAccessor(index); Preconditions.checkArgument(parentGetter != null, "Parent filter for non-child index."); + Comparable parentKey = asKey(parent); return elements.stream() - .filter(e -> compare(e, parentGetter, parent) == 0) + .filter(e -> compare(e, parentGetter, parentKey) == 0) .collect(Collectors.toList()); } else { return new ArrayList<>(elements); @@ -243,24 +336,23 @@ public class InMemoryStore implements KVStore { private int compare(T e1, T e2, KVTypeInfo.Accessor getter) { try { - int diff = compare(e1, getter, getter.get(e2)); + int diff = compare(e1, getter, asKey(getter.get(e2))); if (diff == 0 && getter != natural) { - diff = compare(e1, natural, natural.get(e2)); + diff = compare(e1, natural, asKey(natural.get(e2))); } return diff; - } catch (Exception e) { - throw Throwables.propagate(e); + } catch (ReflectiveOperationException e) { + throw new RuntimeException(e); } } - private int compare(T e1, KVTypeInfo.Accessor getter, Object v2) { + private int compare(T e1, KVTypeInfo.Accessor getter, Comparable v2) { try { - return asKey(getter.get(e1)).compareTo(asKey(v2)); - } catch (Exception e) { - throw Throwables.propagate(e); + return asKey(getter.get(e1)).compareTo(v2); + } catch (ReflectiveOperationException e) { + throw new RuntimeException(e); } } - } private static class InMemoryIterator implements KVStoreIterator { diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStore.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStore.java index 72d06a8ca8..ac159eb431 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStore.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStore.java @@ -18,6 +18,7 @@ package org.apache.spark.util.kvstore; import java.io.Closeable; +import java.util.Collection; import org.apache.spark.annotation.Private; @@ -126,4 +127,9 @@ public interface KVStore extends Closeable { */ long count(Class type, String index, Object indexedValue) throws Exception; + /** + * A cheaper way to remove multiple items from the KVStore + */ + boolean removeAllByIndexValues(Class klass, String index, Collection indexValues) + throws Exception; } diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreView.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreView.java index 8ea79bbe16..90135268fd 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreView.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreView.java @@ -38,8 +38,6 @@ import org.apache.spark.annotation.Private; @Private public abstract class KVStoreView implements Iterable { - final Class type; - boolean ascending = true; String index = KVIndex.NATURAL_INDEX_NAME; Object first = null; @@ -48,10 +46,6 @@ public abstract class KVStoreView implements Iterable { long skip = 0L; long max = Long.MAX_VALUE; - public KVStoreView(Class type) { - this.type = type; - } - /** * Reverses the order of iteration. By default, iterates in ascending order. */ diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java index 870b484f99..b8c5fab870 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java @@ -37,7 +37,7 @@ public class KVTypeInfo { private final Map indices; private final Map accessors; - public KVTypeInfo(Class type) throws Exception { + public KVTypeInfo(Class type) { this.type = type; this.accessors = new HashMap<>(); this.indices = new HashMap<>(); @@ -122,8 +122,9 @@ public class KVTypeInfo { */ interface Accessor { - Object get(Object instance) throws Exception; + Object get(Object instance) throws ReflectiveOperationException; + Class getType(); } private class FieldAccessor implements Accessor { @@ -135,10 +136,14 @@ public class KVTypeInfo { } @Override - public Object get(Object instance) throws Exception { + public Object get(Object instance) throws ReflectiveOperationException { return field.get(instance); } + @Override + public Class getType() { + return field.getType(); + } } private class MethodAccessor implements Accessor { @@ -150,10 +155,14 @@ public class KVTypeInfo { } @Override - public Object get(Object instance) throws Exception { + public Object get(Object instance) throws ReflectiveOperationException { return method.invoke(instance); } + @Override + public Class getType() { + return method.getReturnType(); + } } } diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java index 58e2a8f25f..2ca4b0b2cb 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java @@ -19,10 +19,7 @@ package org.apache.spark.util.kvstore; import java.io.File; import java.io.IOException; -import java.util.HashMap; -import java.util.Iterator; -import java.util.Map; -import java.util.NoSuchElementException; +import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicReference; @@ -188,11 +185,11 @@ public class LevelDB implements KVStore { @Override public KVStoreView view(Class type) throws Exception { - return new KVStoreView(type) { + return new KVStoreView() { @Override public Iterator iterator() { try { - return new LevelDBIterator<>(LevelDB.this, this); + return new LevelDBIterator<>(type, LevelDB.this, this); } catch (Exception e) { throw Throwables.propagate(e); } @@ -200,6 +197,26 @@ public class LevelDB implements KVStore { }; } + @Override + public boolean removeAllByIndexValues( + Class klass, + String index, + Collection indexValues) throws Exception { + LevelDBTypeInfo.Index naturalIndex = getTypeInfo(klass).naturalIndex(); + boolean removed = false; + KVStoreView view = view(klass).index(index); + + for (Object indexValue : indexValues) { + for (T value: view.first(indexValue).last(indexValue)) { + Object itemKey = naturalIndex.getValue(value); + delete(klass, itemKey); + removed = true; + } + } + + return removed; + } + @Override public long count(Class type) throws Exception { LevelDBTypeInfo.Index idx = getTypeInfo(type).naturalIndex(); diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java index e3efc92c4a..94e8c9fc57 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java @@ -45,11 +45,11 @@ class LevelDBIterator implements KVStoreIterator { private boolean closed; private long count; - LevelDBIterator(LevelDB db, KVStoreView params) throws Exception { + LevelDBIterator(Class type, LevelDB db, KVStoreView params) throws Exception { this.db = db; this.ascending = params.ascending; this.it = db.db().iterator(); - this.type = params.type; + this.type = type; this.ti = db.getTypeInfo(type); this.index = ti.index(params.index); this.max = params.max; @@ -207,47 +207,43 @@ class LevelDBIterator implements KVStoreIterator { return null; } - try { - while (true) { - boolean hasNext = ascending ? it.hasNext() : it.hasPrev(); - if (!hasNext) { - return null; - } - - Map.Entry nextEntry; - try { - // Avoid races if another thread is updating the DB. - nextEntry = ascending ? it.next() : it.prev(); - } catch (NoSuchElementException e) { - return null; - } - - byte[] nextKey = nextEntry.getKey(); - // Next key is not part of the index, stop. - if (!startsWith(nextKey, indexKeyPrefix)) { - return null; - } - - // If the next key is an end marker, then skip it. - if (isEndMarker(nextKey)) { - continue; - } - - // If there's a known end key and iteration has gone past it, stop. - if (end != null) { - int comp = compare(nextKey, end) * (ascending ? 1 : -1); - if (comp > 0) { - return null; - } - } - - count++; - - // Next element is part of the iteration, return it. - return nextEntry.getValue(); + while (true) { + boolean hasNext = ascending ? it.hasNext() : it.hasPrev(); + if (!hasNext) { + return null; } - } catch (Exception e) { - throw Throwables.propagate(e); + + Map.Entry nextEntry; + try { + // Avoid races if another thread is updating the DB. + nextEntry = ascending ? it.next() : it.prev(); + } catch (NoSuchElementException e) { + return null; + } + + byte[] nextKey = nextEntry.getKey(); + // Next key is not part of the index, stop. + if (!startsWith(nextKey, indexKeyPrefix)) { + return null; + } + + // If the next key is an end marker, then skip it. + if (isEndMarker(nextKey)) { + continue; + } + + // If there's a known end key and iteration has gone past it, stop. + if (end != null) { + int comp = compare(nextKey, end) * (ascending ? 1 : -1); + if (comp > 0) { + return null; + } + } + + count++; + + // Next element is part of the iteration, return it. + return nextEntry.getValue(); } } diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java index 9abf26f02f..9e34225e14 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java @@ -19,6 +19,7 @@ package org.apache.spark.util.kvstore; import java.util.NoSuchElementException; +import com.google.common.collect.ImmutableSet; import org.junit.Test; import static org.junit.Assert.*; @@ -132,6 +133,51 @@ public class InMemoryStoreSuite { assertEquals(o, store.view(ArrayKeyIndexType.class).index("id").first(o.id).iterator().next()); } + @Test + public void testRemoveAll() throws Exception { + KVStore store = new InMemoryStore(); + + for (int i = 0; i < 2; i++) { + for (int j = 0; j < 2; j++) { + ArrayKeyIndexType o = new ArrayKeyIndexType(); + o.key = new int[] { i, j, 0 }; + o.id = new String[] { "things" }; + store.write(o); + + o = new ArrayKeyIndexType(); + o.key = new int[] { i, j, 1 }; + o.id = new String[] { "more things" }; + store.write(o); + } + } + + ArrayKeyIndexType o = new ArrayKeyIndexType(); + o.key = new int[] { 2, 2, 2 }; + o.id = new String[] { "things" }; + store.write(o); + + assertEquals(9, store.count(ArrayKeyIndexType.class)); + + + store.removeAllByIndexValues( + ArrayKeyIndexType.class, + KVIndex.NATURAL_INDEX_NAME, + ImmutableSet.of(new int[] {0, 0, 0}, new int[] { 2, 2, 2 })); + assertEquals(7, store.count(ArrayKeyIndexType.class)); + + store.removeAllByIndexValues( + ArrayKeyIndexType.class, + "id", + ImmutableSet.of(new String [] { "things" })); + assertEquals(4, store.count(ArrayKeyIndexType.class)); + + store.removeAllByIndexValues( + ArrayKeyIndexType.class, + "id", + ImmutableSet.of(new String [] { "more things" })); + assertEquals(0, store.count(ArrayKeyIndexType.class)); + } + @Test public void testBasicIteration() throws Exception { KVStore store = new InMemoryStore(); diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java index 39a952f2b0..0b755ba0e8 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java @@ -24,6 +24,7 @@ import java.util.NoSuchElementException; import java.util.stream.Collectors; import java.util.stream.StreamSupport; +import com.google.common.collect.ImmutableSet; import org.apache.commons.io.FileUtils; import org.iq80.leveldb.DBIterator; import org.junit.After; @@ -198,6 +199,48 @@ public class LevelDBSuite { assertEquals(0, db.count(t.getClass(), "name", "name")); } + @Test + public void testRemoveAll() throws Exception { + for (int i = 0; i < 2; i++) { + for (int j = 0; j < 2; j++) { + ArrayKeyIndexType o = new ArrayKeyIndexType(); + o.key = new int[] { i, j, 0 }; + o.id = new String[] { "things" }; + db.write(o); + + o = new ArrayKeyIndexType(); + o.key = new int[] { i, j, 1 }; + o.id = new String[] { "more things" }; + db.write(o); + } + } + + ArrayKeyIndexType o = new ArrayKeyIndexType(); + o.key = new int[] { 2, 2, 2 }; + o.id = new String[] { "things" }; + db.write(o); + + assertEquals(9, db.count(ArrayKeyIndexType.class)); + + db.removeAllByIndexValues( + ArrayKeyIndexType.class, + KVIndex.NATURAL_INDEX_NAME, + ImmutableSet.of(new int[] {0, 0, 0}, new int[] { 2, 2, 2 })); + assertEquals(7, db.count(ArrayKeyIndexType.class)); + + db.removeAllByIndexValues( + ArrayKeyIndexType.class, + "id", + ImmutableSet.of(new String[] { "things" })); + assertEquals(4, db.count(ArrayKeyIndexType.class)); + + db.removeAllByIndexValues( + ArrayKeyIndexType.class, + "id", + ImmutableSet.of(new String[] { "more things" })); + assertEquals(0, db.count(ArrayKeyIndexType.class)); + } + @Test public void testSkip() throws Exception { for (int i = 0; i < 10; i++) { diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index b085f21f2d..0052fd42d2 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -700,6 +700,10 @@ private[spark] class AppStatusListener( val now = System.nanoTime() stage.info = event.stageInfo + // We have to update the stage status AFTER we create all the executorSummaries + // because stage deletion deletes whatever summaries it finds when the status is completed. + stage.executorSummaries.values.foreach(update(_, now)) + // Because of SPARK-20205, old event logs may contain valid stages without a submission time // in their start event. In those cases, we can only detect whether a stage was skipped by // waiting until the completion event, at which point the field would have been set. @@ -728,8 +732,6 @@ private[spark] class AppStatusListener( update(pool, now) } - stage.executorSummaries.values.foreach(update(_, now)) - val executorIdsForStage = stage.blackListedExecutors executorIdsForStage.foreach { executorId => liveExecutors.get(executorId).foreach { exec => @@ -1142,20 +1144,10 @@ private[spark] class AppStatusListener( s.info.status != v1.StageStatus.ACTIVE && s.info.status != v1.StageStatus.PENDING } - stages.foreach { s => + val stageIds = stages.map { s => val key = Array(s.info.stageId, s.info.attemptId) kvstore.delete(s.getClass(), key) - val execSummaries = kvstore.view(classOf[ExecutorStageSummaryWrapper]) - .index("stage") - .first(key) - .last(key) - .asScala - .toSeq - execSummaries.foreach { e => - kvstore.delete(e.getClass(), e.id) - } - // Check whether there are remaining attempts for the same stage. If there aren't, then // also delete the RDD graph data. val remainingAttempts = kvstore.view(classOf[StageDataWrapper]) @@ -1177,16 +1169,14 @@ private[spark] class AppStatusListener( } cleanupCachedQuantiles(key) + key } + // Delete summaries in one pass, as deleting them for each stage is slow + kvstore.removeAllByIndexValues(classOf[ExecutorStageSummaryWrapper], "stage", stageIds) + // Delete tasks for all stages in one pass, as deleting them for each stage individually is slow - val tasks = kvstore.view(classOf[TaskDataWrapper]).asScala - val keys = stages.map { s => (s.info.stageId, s.info.attemptId) }.toSet - tasks.foreach { t => - if (keys.contains((t.stageId, t.stageAttemptId))) { - kvstore.delete(t.getClass(), t.taskId) - } - } + kvstore.removeAllByIndexValues(classOf[TaskDataWrapper], TaskIndexNames.STAGE, stageIds) } private def cleanupTasks(stage: LiveStage): Unit = { diff --git a/core/src/main/scala/org/apache/spark/status/ElementTrackingStore.scala b/core/src/main/scala/org/apache/spark/status/ElementTrackingStore.scala index 5ec7d90bfa..38cb030297 100644 --- a/core/src/main/scala/org/apache/spark/status/ElementTrackingStore.scala +++ b/core/src/main/scala/org/apache/spark/status/ElementTrackingStore.scala @@ -17,14 +17,18 @@ package org.apache.spark.status +import java.util.Collection import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicBoolean +import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, ListBuffer} import com.google.common.util.concurrent.MoreExecutors import org.apache.spark.SparkConf import org.apache.spark.internal.config.Status._ +import org.apache.spark.status.ElementTrackingStore._ import org.apache.spark.util.{ThreadUtils, Utils} import org.apache.spark.util.kvstore._ @@ -46,7 +50,27 @@ import org.apache.spark.util.kvstore._ */ private[spark] class ElementTrackingStore(store: KVStore, conf: SparkConf) extends KVStore { - private val triggers = new HashMap[Class[_], Seq[Trigger[_]]]() + private class LatchedTriggers(val triggers: Seq[Trigger[_]]) { + private val pending = new AtomicBoolean(false) + + def fireOnce(f: Seq[Trigger[_]] => Unit): WriteQueueResult = { + if (pending.compareAndSet(false, true)) { + doAsync { + pending.set(false) + f(triggers) + } + WriteQueued + } else { + WriteSkippedQueue + } + } + + def :+(addlTrigger: Trigger[_]): LatchedTriggers = { + new LatchedTriggers(triggers :+ addlTrigger) + } + } + + private val triggers = new HashMap[Class[_], LatchedTriggers]() private val flushTriggers = new ListBuffer[() => Unit]() private val executor = if (conf.get(ASYNC_TRACKING_ENABLED)) { ThreadUtils.newDaemonSingleThreadExecutor("element-tracking-store-worker") @@ -66,8 +90,13 @@ private[spark] class ElementTrackingStore(store: KVStore, conf: SparkConf) exten * of elements of the registered type currently known to be in the store. */ def addTrigger(klass: Class[_], threshold: Long)(action: Long => Unit): Unit = { - val existing = triggers.getOrElse(klass, Seq()) - triggers(klass) = existing :+ Trigger(threshold, action) + val newTrigger = Trigger(threshold, action) + triggers.get(klass) match { + case None => + triggers(klass) = new LatchedTriggers(Seq(newTrigger)) + case Some(latchedTrigger) => + triggers(klass) = latchedTrigger :+ newTrigger + } } /** @@ -96,23 +125,35 @@ private[spark] class ElementTrackingStore(store: KVStore, conf: SparkConf) exten override def write(value: Any): Unit = store.write(value) /** Write an element to the store, optionally checking for whether to fire triggers. */ - def write(value: Any, checkTriggers: Boolean): Unit = { + def write(value: Any, checkTriggers: Boolean): WriteQueueResult = { write(value) if (checkTriggers && !stopped) { - triggers.get(value.getClass()).foreach { list => - doAsync { - val count = store.count(value.getClass()) + triggers.get(value.getClass).map { latchedList => + latchedList.fireOnce { list => + val count = store.count(value.getClass) list.foreach { t => if (count > t.threshold) { t.action(count) } } } - } + }.getOrElse(WriteSkippedQueue) + } else { + WriteSkippedQueue } } + def removeAllByIndexValues[T](klass: Class[T], index: String, indexValues: Iterable[_]): Boolean = + removeAllByIndexValues(klass, index, indexValues.asJavaCollection) + + override def removeAllByIndexValues[T]( + klass: Class[T], + index: String, + indexValues: Collection[_]): Boolean = { + store.removeAllByIndexValues(klass, index, indexValues) + } + override def delete(klass: Class[_], naturalKey: Any): Unit = store.delete(klass, naturalKey) override def getMetadata[T](klass: Class[T]): T = store.getMetadata(klass) @@ -157,3 +198,14 @@ private[spark] class ElementTrackingStore(store: KVStore, conf: SparkConf) exten action: Long => Unit) } + +private[spark] object ElementTrackingStore { + /** + * This trait is solely to assist testing the correctness of single-fire execution + * The result of write() is otherwise unused. + */ + sealed trait WriteQueueResult + + object WriteQueued extends WriteQueueResult + object WriteSkippedQueue extends WriteQueueResult +} diff --git a/core/src/test/scala/org/apache/spark/status/ElementTrackingStoreSuite.scala b/core/src/test/scala/org/apache/spark/status/ElementTrackingStoreSuite.scala index a99c1ec7e1..38e88e6a01 100644 --- a/core/src/test/scala/org/apache/spark/status/ElementTrackingStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/ElementTrackingStoreSuite.scala @@ -17,13 +17,60 @@ package org.apache.spark.status +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} + import org.mockito.Mockito._ +import org.scalatest.Matchers._ +import org.scalatest.concurrent.Eventually import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.config.Status._ +import org.apache.spark.status.ElementTrackingStore._ import org.apache.spark.util.kvstore._ -class ElementTrackingStoreSuite extends SparkFunSuite { +class ElementTrackingStoreSuite extends SparkFunSuite with Eventually { + + test("asynchronous tracking single-fire") { + val store = mock(classOf[KVStore]) + val tracking = new ElementTrackingStore(store, new SparkConf() + .set(ASYNC_TRACKING_ENABLED, true)) + + var done = new AtomicBoolean(false) + var type1 = new AtomicInteger(0) + var queued0: WriteQueueResult = null + var queued1: WriteQueueResult = null + var queued2: WriteQueueResult = null + var queued3: WriteQueueResult = null + + tracking.addTrigger(classOf[Type1], 1) { count => + val count = type1.getAndIncrement() + + count match { + case 0 => + // while in the asynchronous thread, attempt to increment twice. The first should + // succeed, the second should be skipped + queued1 = tracking.write(new Type1, checkTriggers = true) + queued2 = tracking.write(new Type1, checkTriggers = true) + case 1 => + // Verify that once we've started deliver again, that we can enqueue another + queued3 = tracking.write(new Type1, checkTriggers = true) + case 2 => + done.set(true) + } + } + + when(store.count(classOf[Type1])).thenReturn(2L) + queued0 = tracking.write(new Type1, checkTriggers = true) + eventually { + done.get() shouldEqual true + } + + tracking.close(false) + assert(queued0 == WriteQueued) + assert(queued1 == WriteQueued) + assert(queued2 == WriteSkippedQueue) + assert(queued3 == WriteQueued) + } test("tracking for multiple types") { val store = mock(classOf[KVStore])