Replace EmittedBatchCounter and UpdateCounter with ConcurrentAwaitableCounter (#5592)

* Replace EmittedBatchCounter and UpdateCounter with (both not safe for concurrent increments/updates) with ConcurrentAwaitableCounter (safe for concurrent increments)

* Fixes

* Fix EmitterTest

* Added Javadoc and make awaitCount() to throw exceptions on wrong count instead of masking errors
This commit is contained in:
Roman Leventov 2018-04-13 11:07:11 +07:00 committed by Gian Merlino
parent 882b172318
commit 124c89e435
8 changed files with 288 additions and 189 deletions

View File

@ -22,6 +22,7 @@ package io.druid.server.lookup.namespace.cache;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Throwables;
import com.google.inject.Inject;
import io.druid.concurrent.ConcurrentAwaitableCounter;
import io.druid.java.util.emitter.service.ServiceEmitter;
import io.druid.java.util.emitter.service.ServiceMetricEvent;
import io.druid.guice.LazySingleton;
@ -51,11 +52,11 @@ import java.util.concurrent.atomic.AtomicReference;
* // cacheState could be either NoCache or VersionedCache.
* if (cacheState instanceof NoCache) {
* // the cache is not yet created, or already closed
* } else if (cacheState instanceof VersionedCache) {
* } else {
* Map<String, String> cache = ((VersionedCache) cacheState).getCache(); // use the cache
* // Although VersionedCache implements AutoCloseable, versionedCache shouldn't be manually closed
* // when obtained from entry.getCacheState(). If the namespace updates should be ceased completely,
* // entry.close() (see below) should be called, it will close the last VersionedCache itself.
* // entry.close() (see below) should be called, it will close the last VersionedCache as well.
* // On scheduled updates, outdated VersionedCaches are also closed automatically.
* }
* ...
@ -105,14 +106,16 @@ public final class CacheScheduler
return impl.updaterFuture;
}
@VisibleForTesting
public void awaitTotalUpdates(int totalUpdates) throws InterruptedException
{
impl.updateCounter.awaitTotalUpdates(totalUpdates);
impl.updateCounter.awaitCount(totalUpdates);
}
@VisibleForTesting
void awaitNextUpdates(int nextUpdates) throws InterruptedException
{
impl.updateCounter.awaitNextUpdates(nextUpdates);
impl.updateCounter.awaitNextIncrements(nextUpdates);
}
/**
@ -145,7 +148,7 @@ public final class CacheScheduler
private final Future<?> updaterFuture;
private final Cleaner entryCleaner;
private final CacheGenerator<T> cacheGenerator;
private final UpdateCounter updateCounter = new UpdateCounter();
private final ConcurrentAwaitableCounter updateCounter = new ConcurrentAwaitableCounter();
private final CountDownLatch startLatch = new CountDownLatch(1);
private EntryImpl(final T namespace, final Entry<T> entry, final CacheGenerator<T> cacheGenerator)
@ -276,7 +279,7 @@ public final class CacheScheduler
return lastCacheState;
}
} while (!cacheStateHolder.compareAndSet(lastCacheState, newVersionedCache));
updateCounter.update();
updateCounter.increment();
return lastCacheState;
}
@ -485,7 +488,7 @@ public final class CacheScheduler
log.debug("Scheduled new %s", entry);
boolean success = false;
try {
success = entry.impl.updateCounter.awaitFirstUpdate(waitForFirstRunMs, TimeUnit.MILLISECONDS);
success = entry.impl.updateCounter.awaitFirstIncrement(waitForFirstRunMs, TimeUnit.MILLISECONDS);
if (success) {
return entry;
} else {

View File

@ -1,88 +0,0 @@
/*
* Licensed to Metamarkets Group Inc. (Metamarkets) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Metamarkets licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.druid.server.lookup.namespace.cache;
import java.util.concurrent.Phaser;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
final class UpdateCounter
{
/**
* Max {@link Phaser}'s phase, specified in it's javadoc. Then it wraps to zero.
*/
private static final int MAX_PHASE = Integer.MAX_VALUE;
private final Phaser phaser = new Phaser(1);
void update()
{
phaser.arrive();
}
void awaitTotalUpdates(int totalUpdates) throws InterruptedException
{
totalUpdates &= MAX_PHASE;
int currentUpdates = phaser.getPhase();
checkNotTerminated(currentUpdates);
while (comparePhases(totalUpdates, currentUpdates) > 0) {
currentUpdates = phaser.awaitAdvanceInterruptibly(currentUpdates);
checkNotTerminated(currentUpdates);
}
}
private static int comparePhases(int phase1, int phase2)
{
int diff = (phase1 - phase2) & MAX_PHASE;
if (diff == 0) {
return 0;
}
return diff < MAX_PHASE / 2 ? 1 : -1;
}
private void checkNotTerminated(int phase)
{
if (phase < 0) {
throw new IllegalStateException("Phaser[" + phaser + "] unexpectedly terminated.");
}
}
void awaitNextUpdates(int nextUpdates) throws InterruptedException
{
if (nextUpdates <= 0) {
throw new IllegalArgumentException("nextUpdates is not positive: " + nextUpdates);
}
if (nextUpdates > MAX_PHASE / 4) {
throw new UnsupportedOperationException("Couldn't wait for so many updates: " + nextUpdates);
}
awaitTotalUpdates(phaser.getPhase() + nextUpdates);
}
boolean awaitFirstUpdate(long timeout, TimeUnit unit) throws InterruptedException
{
try {
phaser.awaitAdvanceInterruptibly(0, timeout, unit);
return true;
}
catch (TimeoutException e) {
return false;
}
}
}

View File

@ -0,0 +1,166 @@
/*
* Licensed to Metamarkets Group Inc. (Metamarkets) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Metamarkets licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.druid.concurrent;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.locks.AbstractQueuedLongSynchronizer;
/**
* This synchronization object allows to {@link #increment} a counter without blocking, potentially from multiple
* threads (although in some use cases there is just one incrementer thread), and block in other thread(s), awaiting
* when the count reaches the provided value: see {@link #awaitCount}, or the specified number of events since the
* call: see {@link #awaitNextIncrements}.
*
* This counter wraps around {@link Long#MAX_VALUE} and starts from 0 again, so "next" count should be generally
* obtained by calling {@link #nextCount nextCount(currentCount)} rather than {@code currentCount + 1}.
*
* Memory consistency effects: actions in threads prior to calling {@link #increment} while the count was less than the
* awaited value happen-before actions following count awaiting methods such as {@link #awaitCount}.
*/
public final class ConcurrentAwaitableCounter
{
private static final long MAX_COUNT = Long.MAX_VALUE;
/**
* This method should be called to obtain the next total increment count to be passed to {@link #awaitCount} methods,
* instead of just adding 1 to the previous count, because the count must wrap around {@link Long#MAX_VALUE} and start
* from 0 again.
*/
public static long nextCount(long prevCount)
{
return (prevCount + 1) & MAX_COUNT;
}
private static class Sync extends AbstractQueuedLongSynchronizer
{
@Override
protected long tryAcquireShared(long countWhenWaitStarted)
{
long currentCount = getState();
return compareCounts(currentCount, countWhenWaitStarted) > 0 ? 1 : -1;
}
@Override
protected boolean tryReleaseShared(long increment)
{
long count;
long nextCount;
do {
count = getState();
nextCount = (count + increment) & MAX_COUNT;
} while (!compareAndSetState(count, nextCount));
return true;
}
long getCount()
{
return getState();
}
}
private final Sync sync = new Sync();
/**
* Increment the count. This method could be safely called from concurrent threads.
*/
public void increment()
{
sync.releaseShared(1);
}
/**
* Await until the {@link #increment} is called on this counter object the specified number of times from the creation
* of this counter object.
*/
public void awaitCount(long totalCount) throws InterruptedException
{
checkTotalCount(totalCount);
long currentCount = sync.getCount();
while (compareCounts(totalCount, currentCount) > 0) {
sync.acquireSharedInterruptibly(currentCount);
currentCount = sync.getCount();
}
}
private static void checkTotalCount(long totalCount)
{
if (totalCount < 0) {
throw new AssertionError(
"Total count must always be >= 0, even in the face of overflow. "
+ "The next count should always be obtained by calling ConcurrentAwaitableCounter.nextCount(prevCount), "
+ "not just +1"
);
}
}
/**
* Await until the {@link #increment} is called on this counter object the specified number of times from the creation
* of this counter object, for not longer than the specified period of time. If by this time the target increment
* count is not reached, {@link TimeoutException} is thrown.
*/
public void awaitCount(long totalCount, long timeout, TimeUnit unit) throws InterruptedException, TimeoutException
{
checkTotalCount(totalCount);
long nanos = unit.toNanos(timeout);
long currentCount = sync.getCount();
while (compareCounts(totalCount, currentCount) > 0) {
if (!sync.tryAcquireSharedNanos(currentCount, nanos)) {
throw new TimeoutException();
}
currentCount = sync.getCount();
}
}
private static int compareCounts(long count1, long count2)
{
long diff = (count1 - count2) & MAX_COUNT;
if (diff == 0) {
return 0;
}
return diff < MAX_COUNT / 2 ? 1 : -1;
}
/**
* Somewhat loosely defined wait for "next N increments", because the starting point is not defined from the Java
* Memory Model perspective.
*/
public void awaitNextIncrements(long nextIncrements) throws InterruptedException
{
if (nextIncrements <= 0) {
throw new IllegalArgumentException("nextIncrements is not positive: " + nextIncrements);
}
if (nextIncrements > MAX_COUNT / 4) {
throw new UnsupportedOperationException("Couldn't wait for so many increments: " + nextIncrements);
}
awaitCount((sync.getCount() + nextIncrements) & MAX_COUNT);
}
/**
* The difference between this method and {@link #awaitCount(long, long, TimeUnit)} with argument 1 is that {@code
* awaitFirstIncrement()} returns boolean designating whether the count was await (while waiting for no longer than
* for the specified period of time), while {@code awaitCount()} throws {@link TimeoutException} if the count was not
* awaited.
*/
public boolean awaitFirstIncrement(long timeout, TimeUnit unit) throws InterruptedException
{
return sync.tryAcquireSharedNanos(0, unit.toNanos(timeout));
}
}

View File

@ -90,12 +90,12 @@ class Batch extends AbstractQueuedLongSynchronizer
/**
* Ordering number of this batch, as they filled & emitted in {@link HttpPostEmitter} serially, starting from 0.
* It's a boxed Integer rather than int, because we want to minimize the number of allocations done in
* It's a boxed Long rather than primitive long, because we want to minimize the number of allocations done in
* {@link HttpPostEmitter#onSealExclusive} and so the probability of {@link OutOfMemoryError}.
* @see HttpPostEmitter#onSealExclusive
* @see HttpPostEmitter#concurrentBatch
*/
final Integer batchNumber;
final Long batchNumber;
/**
* The number of events in this batch, needed for event count-based batch emitting.
@ -107,7 +107,7 @@ class Batch extends AbstractQueuedLongSynchronizer
*/
private long firstEventTimestamp = -1;
Batch(HttpPostEmitter emitter, byte[] buffer, int batchNumber)
Batch(HttpPostEmitter emitter, byte[] buffer, long batchNumber)
{
this.emitter = emitter;
this.buffer = buffer;

View File

@ -1,73 +0,0 @@
/*
* Licensed to Metamarkets Group Inc. (Metamarkets) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Metamarkets licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.druid.java.util.emitter.core;
import java.util.concurrent.Phaser;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
final class EmittedBatchCounter
{
/**
* Max {@link Phaser}'s phase, specified in it's javadoc. Then it wraps to zero.
*/
private static final int MAX_PHASE = Integer.MAX_VALUE;
static int nextBatchNumber(int prevBatchNumber)
{
return (prevBatchNumber + 1) & MAX_PHASE;
}
private final Phaser phaser = new Phaser(1);
void batchEmitted()
{
phaser.arrive();
}
void awaitBatchEmitted(int batchNumberToAwait, long timeout, TimeUnit unit)
throws InterruptedException, TimeoutException
{
batchNumberToAwait &= MAX_PHASE;
int currentBatch = phaser.getPhase();
checkNotTerminated(currentBatch);
while (comparePhases(batchNumberToAwait, currentBatch) >= 0) {
currentBatch = phaser.awaitAdvanceInterruptibly(currentBatch, timeout, unit);
checkNotTerminated(currentBatch);
}
}
private static int comparePhases(int phase1, int phase2)
{
int diff = (phase1 - phase2) & MAX_PHASE;
if (diff == 0) {
return 0;
}
return diff < MAX_PHASE / 2 ? 1 : -1;
}
private void checkNotTerminated(int phase)
{
if (phase < 0) {
throw new IllegalStateException("Phaser[" + phaser + "] unexpectedly terminated.");
}
}
}

View File

@ -25,6 +25,7 @@ import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import com.google.common.base.Throwables;
import com.google.common.primitives.Ints;
import io.druid.concurrent.ConcurrentAwaitableCounter;
import io.druid.java.util.common.ISE;
import io.druid.java.util.common.RetryUtils;
import io.druid.java.util.common.StringUtils;
@ -133,7 +134,7 @@ public class HttpPostEmitter implements Flushable, Closeable, Emitter
*/
private final AtomicInteger approximateLargeEventsToEmitCount = new AtomicInteger();
private final EmittedBatchCounter emittedBatchCounter = new EmittedBatchCounter();
private final ConcurrentAwaitableCounter emittedBatchCounter = new ConcurrentAwaitableCounter();
private final EmittingThread emittingThread;
private final AtomicLong totalEmittedEvents = new AtomicLong();
private final AtomicInteger allocatedBuffers = new AtomicInteger();
@ -177,7 +178,8 @@ public class HttpPostEmitter implements Flushable, Closeable, Emitter
throw new ISE(e, "Bad URL: %s", config.getRecipientBaseUrl());
}
emittingThread = new EmittingThread(config);
concurrentBatch.set(new Batch(this, acquireBuffer(), 0));
long firstBatchNumber = 1;
concurrentBatch.set(new Batch(this, acquireBuffer(), firstBatchNumber));
// lastFillTimeMillis must not be 0, minHttpTimeoutMillis could be.
lastFillTimeMillis = Math.max(config.minHttpTimeoutMillis, 1);
}
@ -331,7 +333,7 @@ public class HttpPostEmitter implements Flushable, Closeable, Emitter
addBatchToEmitQueue(batch);
wakeUpEmittingThread();
if (!isTerminated()) {
int nextBatchNumber = EmittedBatchCounter.nextBatchNumber(batch.batchNumber);
long nextBatchNumber = ConcurrentAwaitableCounter.nextCount(batch.batchNumber);
byte[] newBuffer = acquireBuffer();
if (!concurrentBatch.compareAndSet(batch, new Batch(this, newBuffer, nextBatchNumber))) {
buffersToReuse.add(newBuffer);
@ -345,7 +347,7 @@ public class HttpPostEmitter implements Flushable, Closeable, Emitter
private void tryRecoverCurrentBatch(Integer failedBatchNumber)
{
log.info("Trying to recover currentBatch");
int nextBatchNumber = EmittedBatchCounter.nextBatchNumber(failedBatchNumber);
long nextBatchNumber = ConcurrentAwaitableCounter.nextCount(failedBatchNumber);
byte[] newBuffer = acquireBuffer();
if (concurrentBatch.compareAndSet(failedBatchNumber, new Batch(this, newBuffer, nextBatchNumber))) {
log.info("Successfully recovered currentBatch");
@ -383,7 +385,7 @@ public class HttpPostEmitter implements Flushable, Closeable, Emitter
private void batchFinalized()
{
// Notify HttpPostEmitter.flush(), that the batch is emitted, or failed, or dropped.
emittedBatchCounter.batchEmitted();
emittedBatchCounter.increment();
}
private Batch pollBatchFromEmitQueue()
@ -422,7 +424,7 @@ public class HttpPostEmitter implements Flushable, Closeable, Emitter
// This check doesn't always awaits for this exact batch to be emitted, because another batch could be dropped
// from the queue ahead of this one, in limitBuffersToEmitSize(). But there is no better way currently to wait for
// the exact batch, and it's not that important.
emittedBatchCounter.awaitBatchEmitted(batch.batchNumber, config.getFlushTimeOut(), TimeUnit.MILLISECONDS);
emittedBatchCounter.awaitCount(batch.batchNumber, config.getFlushTimeOut(), TimeUnit.MILLISECONDS);
}
catch (TimeoutException e) {
String message = StringUtils.format("Timed out after [%d] millis during flushing", config.getFlushTimeOut());
@ -923,7 +925,7 @@ public class HttpPostEmitter implements Flushable, Closeable, Emitter
@VisibleForTesting
void waitForEmission(int batchNumber) throws Exception
{
emittedBatchCounter.awaitBatchEmitted(batchNumber, 10, TimeUnit.SECONDS);
emittedBatchCounter.awaitCount(batchNumber, 10, TimeUnit.SECONDS);
}
@VisibleForTesting

View File

@ -0,0 +1,89 @@
/*
* Licensed to Metamarkets Group Inc. (Metamarkets) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Metamarkets licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.druid.concurrent;
import org.junit.Assert;
import org.junit.Test;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
public class ConcurrentAwaitableCounterTest
{
@Test(timeout = 1000)
public void smokeTest() throws InterruptedException
{
ConcurrentAwaitableCounter counter = new ConcurrentAwaitableCounter();
CountDownLatch start = new CountDownLatch(1);
CountDownLatch finish = new CountDownLatch(7);
for (int i = 0; i < 2; i++) {
new Thread(() -> {
try {
start.await();
for (int j = 0; j < 10_000; j++) {
counter.increment();
}
finish.countDown();
}
catch (InterruptedException e) {
throw new RuntimeException(e);
}
}).start();
}
for (int awaitCount : new int[] {0, 1, 100, 10_000, 20_000}) {
new Thread(() -> {
try {
start.await();
counter.awaitCount(awaitCount);
finish.countDown();
}
catch (InterruptedException e) {
throw new RuntimeException(e);
}
}).start();
}
start.countDown();
finish.await();
}
@Test
public void testAwaitFirstUpdate() throws InterruptedException
{
int[] value = new int[1];
ConcurrentAwaitableCounter counter = new ConcurrentAwaitableCounter();
Thread t = new Thread(() -> {
try {
Assert.assertTrue(counter.awaitFirstIncrement(10, TimeUnit.SECONDS));
Assert.assertEquals(1, value[0]);
}
catch (InterruptedException e) {
throw new RuntimeException(e);
}
});
t.start();
Thread.sleep(2_000);
value[0] = 1;
counter.increment();
t.join();
}
}

View File

@ -239,7 +239,7 @@ public class EmitterTest
for (UnitEvent event : events) {
emitter.emit(event);
}
waitForEmission(emitter, 0);
waitForEmission(emitter, 1);
closeNoFlush(emitter);
Assert.assertTrue(httpClient.succeeded());
}
@ -281,7 +281,7 @@ public class EmitterTest
for (UnitEvent event : events) {
emitter.emit(event);
}
waitForEmission(emitter, 0);
waitForEmission(emitter, 1);
closeNoFlush(emitter);
Assert.assertTrue(httpClient.succeeded());
}
@ -297,7 +297,7 @@ public class EmitterTest
httpClient.setGoHandler(GoHandlers.passingHandler(okResponse()).times(1));
emitter.emit(new UnitEvent("test", 3));
waitForEmission(emitter, 0);
waitForEmission(emitter, 1);
httpClient.setGoHandler(GoHandlers.failingHandler());
emitter.emit(new UnitEvent("test", 4));
@ -337,7 +337,7 @@ public class EmitterTest
timeWaited < timeBetweenEmissions * 2
);
waitForEmission(emitter, 0);
waitForEmission(emitter, 1);
final CountDownLatch thisLatch = new CountDownLatch(1);
httpClient.setGoHandler(
@ -362,7 +362,7 @@ public class EmitterTest
timeWaited < timeBetweenEmissions * 2
);
waitForEmission(emitter, 1);
waitForEmission(emitter, 2);
closeNoFlush(emitter);
Assert.assertTrue("httpClient.succeeded()", httpClient.succeeded());
}
@ -388,7 +388,7 @@ public class EmitterTest
);
emitter.emit(event1);
emitter.flush();
waitForEmission(emitter, 0);
waitForEmission(emitter, 1);
Assert.assertTrue(httpClient.succeeded());
// Failed to emit the first event.
@ -407,7 +407,7 @@ public class EmitterTest
emitter.emit(event2);
emitter.flush();
waitForEmission(emitter, 1);
waitForEmission(emitter, 2);
closeNoFlush(emitter);
// Failed event is emitted inside emitter thread, there is no other way to wait for it other than joining the
// emitterThread
@ -461,7 +461,7 @@ public class EmitterTest
emitter.emit(event);
}
emitter.flush();
waitForEmission(emitter, 0);
waitForEmission(emitter, 1);
closeNoFlush(emitter);
Assert.assertTrue(httpClient.succeeded());
}
@ -512,11 +512,11 @@ public class EmitterTest
for (UnitEvent event : events) {
emitter.emit(event);
}
waitForEmission(emitter, 0);
waitForEmission(emitter, 1);
Assert.assertEquals(2, emitter.getTotalEmittedEvents());
emitter.flush();
waitForEmission(emitter, 1);
waitForEmission(emitter, 2);
Assert.assertEquals(4, emitter.getTotalEmittedEvents());
closeNoFlush(emitter);
Assert.assertTrue(httpClient.succeeded());
@ -571,7 +571,7 @@ public class EmitterTest
for (UnitEvent event : events) {
emitter.emit(event);
}
waitForEmission(emitter, 0);
waitForEmission(emitter, 1);
closeNoFlush(emitter);
Assert.assertTrue(httpClient.succeeded());
}