[7.x] Rely on the computeIfAbsent logic to prevent duplicated compilation of scripts (#55467) (#58123)

Instead of serializing compilation using a plain lock / mutex combined with a double check, rely on the computeIfAbsent logic to prevent duplicated compilation of scripts. Made checkCompilationLimit to be thread-safe and lock free.

Backport: 865acad

Co-authored-by: Michael Bischoff <michael.bischoff@elastic.co>
This commit is contained in:
Stuart Tettemer 2020-06-15 12:01:22 -06:00 committed by GitHub
parent e268a89ef2
commit 71a42dbde9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 70 additions and 58 deletions

View File

@ -32,6 +32,8 @@ import org.elasticsearch.common.unit.TimeValue;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicReference;
/**
* Script cache and compilation rate limiter.
@ -44,12 +46,7 @@ public class ScriptCache {
private final Cache<CacheKey, Object> cache;
private final ScriptMetrics scriptMetrics;
private final Object lock = new Object();
// Mutable fields, visible for tests
long lastInlineCompileTime;
double scriptsPerTimeWindow;
final AtomicReference<TokenBucketState> tokenBucketState;
// Cache settings or derived from settings
final int cacheSize;
@ -81,11 +78,9 @@ public class ScriptCache {
this.cache = cacheBuilder.removalListener(new ScriptCacheRemovalListener()).build();
this.rate = maxCompilationRate;
this.scriptsPerTimeWindow = this.rate.v1();
this.compilesAllowedPerNano = ((double) rate.v1()) / rate.v2().nanos();
this.lastInlineCompileTime = System.nanoTime();
this.scriptMetrics = new ScriptMetrics();
this.tokenBucketState = new AtomicReference<TokenBucketState>(new TokenBucketState(this.rate.v1()));
}
<FactoryType> FactoryType compile(
@ -98,47 +93,43 @@ public class ScriptCache {
) {
String lang = scriptEngine.getType();
CacheKey cacheKey = new CacheKey(lang, idOrCode, context.name, options);
Object compiledScript = cache.get(cacheKey);
if (compiledScript != null) {
return context.factoryClazz.cast(compiledScript);
}
// Synchronize so we don't compile scripts many times during multiple shards all compiling a script
synchronized (lock) {
// Retrieve it again in case it has been put by a different thread
compiledScript = cache.get(cacheKey);
if (compiledScript == null) {
try {
// Either an un-cached inline script or indexed script
// If the script type is inline the name will be the same as the code for identification in exceptions
// but give the script engine the chance to be better, give it separate name + source code
// for the inline case, then its anonymous: null.
if (logger.isTraceEnabled()) {
logger.trace("context [{}]: compiling script, type: [{}], lang: [{}], options: [{}]", context.name, type,
lang, options);
}
// Check whether too many compilations have happened
checkCompilationLimit();
compiledScript = scriptEngine.compile(id, idOrCode, context, options);
} catch (ScriptException good) {
// TODO: remove this try-catch completely, when all script engines have good exceptions!
throw good; // its already good
} catch (Exception exception) {
throw new GeneralScriptException("Failed to compile " + type + " script [" + id + "] using lang [" + lang + "]",
exception);
// Relying on computeIfAbsent to avoid multiple threads from compiling the same script
try {
return context.factoryClazz.cast(cache.computeIfAbsent(cacheKey, key -> {
// Either an un-cached inline script or indexed script
// If the script type is inline the name will be the same as the code for identification in exceptions
// but give the script engine the chance to be better, give it separate name + source code
// for the inline case, then its anonymous: null.
if (logger.isTraceEnabled()) {
logger.trace("context [{}]: compiling script, type: [{}], lang: [{}], options: [{}]", context.name, type,
lang, options);
}
// Check whether too many compilations have happened
checkCompilationLimit();
Object compiledScript = scriptEngine.compile(id, idOrCode, context, options);
// Since the cache key is the script content itself we don't need to
// invalidate/check the cache if an indexed script changes.
scriptMetrics.onCompilation();
cache.put(cacheKey, compiledScript);
return compiledScript;
}));
} catch (ExecutionException executionException) {
Throwable cause = executionException.getCause();
if (cause instanceof ScriptException) {
throw (ScriptException) cause;
} else if (cause instanceof Exception) {
throw new GeneralScriptException("Failed to compile " + type + " script [" + id + "] using lang [" + lang + "]", cause);
} else {
rethrow(cause);
throw new AssertionError(cause);
}
}
}
return context.factoryClazz.cast(compiledScript);
/** Hack to rethrow unknown Exceptions from compile: */
@SuppressWarnings("unchecked")
static <T extends Throwable> void rethrow(Throwable t) throws T {
throw (T) t;
}
public ScriptStats stats() {
@ -159,21 +150,26 @@ public class ScriptCache {
return;
}
long now = System.nanoTime();
long timePassed = now - lastInlineCompileTime;
lastInlineCompileTime = now;
TokenBucketState tokenBucketState = this.tokenBucketState.updateAndGet(current -> {
long now = System.nanoTime();
long timePassed = now - current.lastInlineCompileTime;
double scriptsPerTimeWindow = current.availableTokens + (timePassed) * compilesAllowedPerNano;
scriptsPerTimeWindow += (timePassed) * compilesAllowedPerNano;
// It's been over the time limit anyway, readjust the bucket to be level
if (scriptsPerTimeWindow > rate.v1()) {
scriptsPerTimeWindow = rate.v1();
}
// It's been over the time limit anyway, readjust the bucket to be level
if (scriptsPerTimeWindow > rate.v1()) {
scriptsPerTimeWindow = rate.v1();
}
// If there is enough tokens in the bucket, allow the request and decrease the tokens by 1
if (scriptsPerTimeWindow >= 1) {
scriptsPerTimeWindow -= 1.0;
return new TokenBucketState(now, scriptsPerTimeWindow, true);
} else {
return new TokenBucketState(now, scriptsPerTimeWindow, false);
}
});
// If there is enough tokens in the bucket, allow the request and decrease the tokens by 1
if (scriptsPerTimeWindow >= 1) {
scriptsPerTimeWindow -= 1.0;
} else {
if(!tokenBucketState.tokenSuccessfullyTaken) {
scriptMetrics.onCompilationLimit();
// Otherwise reject the request
throw new CircuitBreakingException("[script] Too many dynamic script compilations within, max: [" +
@ -231,4 +227,20 @@ public class ScriptCache {
return Objects.hash(lang, idOrCode, context, options);
}
}
static class TokenBucketState {
public final long lastInlineCompileTime;
public final double availableTokens;
public final boolean tokenSuccessfullyTaken;
private TokenBucketState(double availableTokens) {
this(System.nanoTime(), availableTokens, false);
}
private TokenBucketState(long lastInlineCompileTime, double availableTokens, boolean tokenSuccessfullyTaken) {
this.lastInlineCompileTime = lastInlineCompileTime;
this.availableTokens = availableTokens;
this.tokenSuccessfullyTaken = tokenSuccessfullyTaken;
}
}
}

View File

@ -59,12 +59,12 @@ public class ScriptCacheTests extends ESTestCase {
final TimeValue expire = ScriptService.SCRIPT_GENERAL_CACHE_EXPIRE_SETTING.get(Settings.EMPTY);
String settingName = ScriptService.SCRIPT_GENERAL_MAX_COMPILATIONS_RATE_SETTING.getKey();
ScriptCache cache = new ScriptCache(size, expire, ScriptCache.UNLIMITED_COMPILATION_RATE, settingName);
long lastInlineCompileTime = cache.lastInlineCompileTime;
double scriptsPerTimeWindow = cache.scriptsPerTimeWindow;
ScriptCache.TokenBucketState initialState = cache.tokenBucketState.get();
for(int i=0; i < 3000; i++) {
cache.checkCompilationLimit();
assertEquals(lastInlineCompileTime, cache.lastInlineCompileTime);
assertEquals(scriptsPerTimeWindow, cache.scriptsPerTimeWindow, 0.0); // delta of 0.0 because it should never change
ScriptCache.TokenBucketState currentState = cache.tokenBucketState.get();
assertEquals(initialState.lastInlineCompileTime, currentState.lastInlineCompileTime);
assertEquals(initialState.availableTokens, currentState.availableTokens, 0.0); // delta of 0.0 because it should never change
}
}
}