diff --git a/src/main/java/org/elasticsearch/shield/authc/support/CachingUsernamePasswordRealm.java b/src/main/java/org/elasticsearch/shield/authc/support/CachingUsernamePasswordRealm.java index a7ff0221830..79a69aa4c91 100644 --- a/src/main/java/org/elasticsearch/shield/authc/support/CachingUsernamePasswordRealm.java +++ b/src/main/java/org/elasticsearch/shield/authc/support/CachingUsernamePasswordRealm.java @@ -16,7 +16,6 @@ import org.elasticsearch.shield.authc.AuthenticationToken; import org.elasticsearch.shield.authc.Realm; import org.elasticsearch.transport.TransportMessage; -import java.util.Arrays; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -24,8 +23,9 @@ import java.util.concurrent.TimeUnit; public abstract class CachingUsernamePasswordRealm extends AbstractComponent implements Realm { private static final TimeValue DEFAULT_TTL = TimeValue.timeValueHours(1); + private static final int DEFAULT_MAX_USERS = 100000; //100k users - private final Cache cache; + private final Cache cache; protected CachingUsernamePasswordRealm(Settings settings) { super(settings); @@ -33,7 +33,7 @@ public abstract class CachingUsernamePasswordRealm extends AbstractComponent imp if (ttl.millis() > 0) { cache = CacheBuilder.newBuilder() .expireAfterWrite(ttl.getMillis(), TimeUnit.MILLISECONDS) - .maximumSize(settings.getAsInt("cache.max_users", -1)) + .maximumSize(settings.getAsInt("cache.max_users", DEFAULT_MAX_USERS)) .build(); } else { cache = null; @@ -62,54 +62,59 @@ public abstract class CachingUsernamePasswordRealm extends AbstractComponent imp } } + /** + * If the user exists in the cache (keyed by the principle name), then the password is validated + * against a hash also stored in the cache. Otherwise the subclass authenticates the user via + * doAuthenticate + * + * @param token The authentication token + * @return an authenticated user with roles + */ @Override public User authenticate(final UsernamePasswordToken token) { if (cache == null) { return doAuthenticate(token); } - try { - return cache.get(new CacheKey(token), new Callable() { - @Override - public User call() throws Exception { - return doAuthenticate(token); + Callable callback = new Callable() { + @Override + public UserWithHash call() throws Exception { + User user = doAuthenticate(token); + if (user == null) { + throw new AuthenticationException("Could not authenticate ['" + token.principal() + "]"); } - }); + return new UserWithHash(user, token.credentials()); + } + }; + + try { + UserWithHash userWithHash = cache.get(token.principal(), callback); + if (userWithHash.verify(token.credentials())) { + return userWithHash.user; + } + //this handles when a user's password has changed: + expire(token.principal()); + userWithHash = cache.get(token.principal(), callback); + return userWithHash.user; + } catch (ExecutionException ee) { - throw new AuthenticationException("Could not authenticate ['" + token.principal() + "]", ee); + logger.warn("Could not authenticate ['" + token.principal() + "]", ee); + return null; } } protected abstract User doAuthenticate(UsernamePasswordToken token); - static class CacheKey { - - private final String username; - private final char[] passwdHash; - - CacheKey(UsernamePasswordToken token) { - this.username = token.principal(); - this.passwdHash = Hasher.HTPASSWD.hash(token.credentials()); + public static class UserWithHash { + User user; + char[] hash; + public UserWithHash(User user, char[] password){ + this.user = user; + this.hash = Hasher.HTPASSWD.hash(password); } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - CacheKey cacheKey = (CacheKey) o; - - if (!Arrays.equals(passwdHash, cacheKey.passwdHash)) return false; - if (!username.equals(cacheKey.username)) return false; - - return true; - } - - @Override - public int hashCode() { - int result = username.hashCode(); - result = 31 * result + Arrays.hashCode(passwdHash); - return result; + public boolean verify(char[] password){ + return Hasher.HTPASSWD.verify(password, hash); } } } diff --git a/src/test/java/org/elasticsearch/shield/authc/support/CachingUsernamePasswordRealmTests.java b/src/test/java/org/elasticsearch/shield/authc/support/CachingUsernamePasswordRealmTests.java new file mode 100644 index 00000000000..e09b1bd9fbf --- /dev/null +++ b/src/test/java/org/elasticsearch/shield/authc/support/CachingUsernamePasswordRealmTests.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.shield.authc.support; + +import org.elasticsearch.common.settings.ImmutableSettings; +import org.elasticsearch.shield.User; +import org.junit.Test; + +import java.util.concurrent.atomic.AtomicInteger; + +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.assertThat; + +public class CachingUsernamePasswordRealmTests { + public static class AlwaysAuthenticateCachingRealm extends CachingUsernamePasswordRealm { + public AlwaysAuthenticateCachingRealm() { + super(ImmutableSettings.EMPTY); + } + public final AtomicInteger INVOCATION_COUNTER = new AtomicInteger(0); + @Override protected User doAuthenticate(UsernamePasswordToken token) { + INVOCATION_COUNTER.incrementAndGet(); + return new User.Simple(token.principal(), "testRole1", "testRole2"); + } + + @Override public String type() { return "test"; }; + } + + + @Test + public void testCache(){ + AlwaysAuthenticateCachingRealm realm = new AlwaysAuthenticateCachingRealm(); + char[] pass = "pass".toCharArray(); + realm.authenticate(new UsernamePasswordToken("a", pass)); + realm.authenticate(new UsernamePasswordToken("b", pass)); + realm.authenticate(new UsernamePasswordToken("c", pass)); + + assertThat(realm.INVOCATION_COUNTER.intValue(), is(3)); + realm.authenticate(new UsernamePasswordToken("a", pass)); + realm.authenticate(new UsernamePasswordToken("b", pass)); + realm.authenticate(new UsernamePasswordToken("c", pass)); + + assertThat(realm.INVOCATION_COUNTER.intValue(), is(3)); + } + + @Test + public void testCache_changePassword(){ + AlwaysAuthenticateCachingRealm realm = new AlwaysAuthenticateCachingRealm(); + + String user = "testUser"; + char[] pass1 = "pass".toCharArray(); + char[] pass2 = "password".toCharArray(); + + realm.authenticate(new UsernamePasswordToken(user, pass1)); + realm.authenticate(new UsernamePasswordToken(user, pass1)); + + assertThat(realm.INVOCATION_COUNTER.intValue(), is(1)); + + realm.authenticate(new UsernamePasswordToken(user, pass2)); + realm.authenticate(new UsernamePasswordToken(user, pass2)); + + assertThat(realm.INVOCATION_COUNTER.intValue(), is(2)); + } +}