Revert "Support concurrent refresh of refresh tokens (#39559)"

This reverts commit e2599214e0.
This commit is contained in:
Tanguy Leroux 2019-03-01 17:54:19 +01:00
parent 39a401b827
commit 0c6b7cfb77
12 changed files with 302 additions and 707 deletions

View File

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.core.security.authc.support; package org.elasticsearch.xpack.core.security.authc.support;
import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
@ -33,9 +32,10 @@ public class TokensInvalidationResult implements ToXContentObject, Writeable {
private final List<String> invalidatedTokens; private final List<String> invalidatedTokens;
private final List<String> previouslyInvalidatedTokens; private final List<String> previouslyInvalidatedTokens;
private final List<ElasticsearchException> errors; private final List<ElasticsearchException> errors;
private final int attemptCount;
public TokensInvalidationResult(List<String> invalidatedTokens, List<String> previouslyInvalidatedTokens, public TokensInvalidationResult(List<String> invalidatedTokens, List<String> previouslyInvalidatedTokens,
@Nullable List<ElasticsearchException> errors) { @Nullable List<ElasticsearchException> errors, int attemptCount) {
Objects.requireNonNull(invalidatedTokens, "invalidated_tokens must be provided"); Objects.requireNonNull(invalidatedTokens, "invalidated_tokens must be provided");
this.invalidatedTokens = invalidatedTokens; this.invalidatedTokens = invalidatedTokens;
Objects.requireNonNull(previouslyInvalidatedTokens, "previously_invalidated_tokens must be provided"); Objects.requireNonNull(previouslyInvalidatedTokens, "previously_invalidated_tokens must be provided");
@ -45,19 +45,18 @@ public class TokensInvalidationResult implements ToXContentObject, Writeable {
} else { } else {
this.errors = Collections.emptyList(); this.errors = Collections.emptyList();
} }
this.attemptCount = attemptCount;
} }
public TokensInvalidationResult(StreamInput in) throws IOException { public TokensInvalidationResult(StreamInput in) throws IOException {
this.invalidatedTokens = in.readStringList(); this.invalidatedTokens = in.readStringList();
this.previouslyInvalidatedTokens = in.readStringList(); this.previouslyInvalidatedTokens = in.readStringList();
this.errors = in.readList(StreamInput::readException); this.errors = in.readList(StreamInput::readException);
if (in.getVersion().before(Version.V_7_1_0)) { this.attemptCount = in.readVInt();
in.readVInt();
}
} }
public static TokensInvalidationResult emptyResult() { public static TokensInvalidationResult emptyResult() {
return new TokensInvalidationResult(Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); return new TokensInvalidationResult(Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), 0);
} }
@ -73,6 +72,10 @@ public class TokensInvalidationResult implements ToXContentObject, Writeable {
return errors; return errors;
} }
public int getAttemptCount() {
return attemptCount;
}
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject() builder.startObject()
@ -97,8 +100,6 @@ public class TokensInvalidationResult implements ToXContentObject, Writeable {
out.writeStringCollection(invalidatedTokens); out.writeStringCollection(invalidatedTokens);
out.writeStringCollection(previouslyInvalidatedTokens); out.writeStringCollection(previouslyInvalidatedTokens);
out.writeCollection(errors, StreamOutput::writeException); out.writeCollection(errors, StreamOutput::writeException);
if (out.getVersion().before(Version.V_7_1_0)) { out.writeVInt(attemptCount);
out.writeVInt(5);
}
} }
} }

View File

@ -199,13 +199,6 @@
"refreshed" : { "refreshed" : {
"type" : "boolean" "type" : "boolean"
}, },
"refresh_time": {
"type": "date",
"format": "epoch_millis"
},
"superseded_by": {
"type": "keyword"
},
"invalidated" : { "invalidated" : {
"type" : "boolean" "type" : "boolean"
}, },

View File

@ -29,7 +29,8 @@ public class InvalidateTokenResponseTests extends ESTestCase {
TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false)), TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false)),
Arrays.asList(generateRandomStringArray(20, 15, false)), Arrays.asList(generateRandomStringArray(20, 15, false)),
Arrays.asList(new ElasticsearchException("foo", new IllegalArgumentException("this is an error message")), Arrays.asList(new ElasticsearchException("foo", new IllegalArgumentException("this is an error message")),
new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2")))); new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2"))),
randomIntBetween(0, 5));
InvalidateTokenResponse response = new InvalidateTokenResponse(result); InvalidateTokenResponse response = new InvalidateTokenResponse(result);
try (BytesStreamOutput output = new BytesStreamOutput()) { try (BytesStreamOutput output = new BytesStreamOutput()) {
response.writeTo(output); response.writeTo(output);
@ -46,7 +47,8 @@ public class InvalidateTokenResponseTests extends ESTestCase {
} }
result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false)), result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false)),
Arrays.asList(generateRandomStringArray(20, 15, false)), Collections.emptyList()); Arrays.asList(generateRandomStringArray(20, 15, false)),
Collections.emptyList(), randomIntBetween(0, 5));
response = new InvalidateTokenResponse(result); response = new InvalidateTokenResponse(result);
try (BytesStreamOutput output = new BytesStreamOutput()) { try (BytesStreamOutput output = new BytesStreamOutput()) {
response.writeTo(output); response.writeTo(output);
@ -66,7 +68,8 @@ public class InvalidateTokenResponseTests extends ESTestCase {
List previouslyInvalidatedTokens = Arrays.asList(generateRandomStringArray(20, 15, false)); List previouslyInvalidatedTokens = Arrays.asList(generateRandomStringArray(20, 15, false));
TokensInvalidationResult result = new TokensInvalidationResult(invalidatedTokens, previouslyInvalidatedTokens, TokensInvalidationResult result = new TokensInvalidationResult(invalidatedTokens, previouslyInvalidatedTokens,
Arrays.asList(new ElasticsearchException("foo", new IllegalArgumentException("this is an error message")), Arrays.asList(new ElasticsearchException("foo", new IllegalArgumentException("this is an error message")),
new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2")))); new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2"))),
randomIntBetween(0, 5));
InvalidateTokenResponse response = new InvalidateTokenResponse(result); InvalidateTokenResponse response = new InvalidateTokenResponse(result);
XContentBuilder builder = XContentFactory.jsonBuilder(); XContentBuilder builder = XContentFactory.jsonBuilder();
response.toXContent(builder, ToXContent.EMPTY_PARAMS); response.toXContent(builder, ToXContent.EMPTY_PARAMS);

View File

@ -63,7 +63,7 @@ public final class TransportSamlAuthenticateAction extends HandledTransportActio
final Map<String, Object> tokenMeta = (Map<String, Object>) result.getMetadata().get(SamlRealm.CONTEXT_TOKEN_DATA); final Map<String, Object> tokenMeta = (Map<String, Object>) result.getMetadata().get(SamlRealm.CONTEXT_TOKEN_DATA);
tokenService.createUserToken(authentication, originatingAuthentication, tokenService.createUserToken(authentication, originatingAuthentication,
ActionListener.wrap(tuple -> { ActionListener.wrap(tuple -> {
final String tokenString = tokenService.getAccessTokenAsString(tuple.v1()); final String tokenString = tokenService.getUserTokenString(tuple.v1());
final TimeValue expiresIn = tokenService.getExpirationDelay(); final TimeValue expiresIn = tokenService.getExpirationDelay();
listener.onResponse( listener.onResponse(
new SamlAuthenticateResponse(authentication.getUser().principal(), tokenString, tuple.v2(), expiresIn)); new SamlAuthenticateResponse(authentication.getUser().principal(), tokenString, tuple.v2(), expiresIn));

View File

@ -89,7 +89,7 @@ public final class TransportCreateTokenAction extends HandledTransportAction<Cre
boolean includeRefreshToken, ActionListener<CreateTokenResponse> listener) { boolean includeRefreshToken, ActionListener<CreateTokenResponse> listener) {
try { try {
tokenService.createUserToken(authentication, originatingAuth, ActionListener.wrap(tuple -> { tokenService.createUserToken(authentication, originatingAuth, ActionListener.wrap(tuple -> {
final String tokenStr = tokenService.getAccessTokenAsString(tuple.v1()); final String tokenStr = tokenService.getUserTokenString(tuple.v1());
final String scope = getResponseScopeValue(request.getScope()); final String scope = getResponseScopeValue(request.getScope());
final CreateTokenResponse response = final CreateTokenResponse response =

View File

@ -31,7 +31,7 @@ public class TransportRefreshTokenAction extends HandledTransportAction<CreateTo
@Override @Override
protected void doExecute(Task task, CreateTokenRequest request, ActionListener<CreateTokenResponse> listener) { protected void doExecute(Task task, CreateTokenRequest request, ActionListener<CreateTokenResponse> listener) {
tokenService.refreshToken(request.getRefreshToken(), ActionListener.wrap(tuple -> { tokenService.refreshToken(request.getRefreshToken(), ActionListener.wrap(tuple -> {
final String tokenStr = tokenService.getAccessTokenAsString(tuple.v1()); final String tokenStr = tokenService.getUserTokenString(tuple.v1());
final String scope = getResponseScopeValue(request.getScope()); final String scope = getResponseScopeValue(request.getScope());
final CreateTokenResponse response = final CreateTokenResponse response =

View File

@ -242,7 +242,7 @@ public class TransportSamlLogoutActionTests extends SamlTestCase {
tokenService.createUserToken(authentication, authentication, future, tokenMetaData, true); tokenService.createUserToken(authentication, authentication, future, tokenMetaData, true);
final UserToken userToken = future.actionGet().v1(); final UserToken userToken = future.actionGet().v1();
mockGetTokenFromId(userToken, false, client); mockGetTokenFromId(userToken, false, client);
final String tokenString = tokenService.getAccessTokenAsString(userToken); final String tokenString = tokenService.getUserTokenString(userToken);
final SamlLogoutRequest request = new SamlLogoutRequest(); final SamlLogoutRequest request = new SamlLogoutRequest();
request.setToken(tokenString); request.setToken(tokenString);

View File

@ -1109,7 +1109,7 @@ public class AuthenticationServiceTests extends ESTestCase {
Authentication originatingAuth = new Authentication(new User("creator"), new RealmRef("test", "test", "test"), null); Authentication originatingAuth = new Authentication(new User("creator"), new RealmRef("test", "test", "test"), null);
tokenService.createUserToken(expected, originatingAuth, tokenFuture, Collections.emptyMap(), true); tokenService.createUserToken(expected, originatingAuth, tokenFuture, Collections.emptyMap(), true);
} }
String token = tokenService.getAccessTokenAsString(tokenFuture.get().v1()); String token = tokenService.getUserTokenString(tokenFuture.get().v1());
when(client.prepareMultiGet()).thenReturn(new MultiGetRequestBuilder(client, MultiGetAction.INSTANCE)); when(client.prepareMultiGet()).thenReturn(new MultiGetRequestBuilder(client, MultiGetAction.INSTANCE));
mockGetTokenFromId(tokenFuture.get().v1(), false, client); mockGetTokenFromId(tokenFuture.get().v1(), false, client);
when(securityIndex.isAvailable()).thenReturn(true); when(securityIndex.isAvailable()).thenReturn(true);
@ -1192,7 +1192,7 @@ public class AuthenticationServiceTests extends ESTestCase {
Authentication originatingAuth = new Authentication(new User("creator"), new RealmRef("test", "test", "test"), null); Authentication originatingAuth = new Authentication(new User("creator"), new RealmRef("test", "test", "test"), null);
tokenService.createUserToken(expected, originatingAuth, tokenFuture, Collections.emptyMap(), true); tokenService.createUserToken(expected, originatingAuth, tokenFuture, Collections.emptyMap(), true);
} }
String token = tokenService.getAccessTokenAsString(tokenFuture.get().v1()); String token = tokenService.getUserTokenString(tokenFuture.get().v1());
mockGetTokenFromId(tokenFuture.get().v1(), true, client); mockGetTokenFromId(tokenFuture.get().v1(), true, client);
doAnswer(invocationOnMock -> { doAnswer(invocationOnMock -> {
((Runnable) invocationOnMock.getArguments()[1]).run(); ((Runnable) invocationOnMock.getArguments()[1]).run();

View File

@ -5,14 +5,11 @@
*/ */
package org.elasticsearch.xpack.security.authc; package org.elasticsearch.xpack.security.authc;
import org.apache.directory.api.util.Strings;
import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse; import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.action.update.UpdateResponse;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ack.ClusterStateUpdateResponse; import org.elasticsearch.cluster.ack.ClusterStateUpdateResponse;
import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.SecureString;
@ -26,7 +23,6 @@ import org.elasticsearch.test.SecuritySettingsSource;
import org.elasticsearch.test.SecuritySettingsSourceField; import org.elasticsearch.test.SecuritySettingsSourceField;
import org.elasticsearch.test.junit.annotations.TestLogging; import org.elasticsearch.test.junit.annotations.TestLogging;
import org.elasticsearch.xpack.core.XPackSettings; import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.core.security.action.token.CreateTokenRequest;
import org.elasticsearch.xpack.core.security.action.token.CreateTokenResponse; import org.elasticsearch.xpack.core.security.action.token.CreateTokenResponse;
import org.elasticsearch.xpack.core.security.action.token.InvalidateTokenRequest; import org.elasticsearch.xpack.core.security.action.token.InvalidateTokenRequest;
import org.elasticsearch.xpack.core.security.action.token.InvalidateTokenResponse; import org.elasticsearch.xpack.core.security.action.token.InvalidateTokenResponse;
@ -42,13 +38,7 @@ import org.junit.Before;
import java.time.Instant; import java.time.Instant;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
@ -340,7 +330,7 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase {
assertEquals("token has been invalidated", e.getHeader("error_description").get(0)); assertEquals("token has been invalidated", e.getHeader("error_description").get(0));
} }
public void testRefreshingMultipleTimesFails() throws Exception { public void testRefreshingMultipleTimes() {
Client client = client().filterWithHeader(Collections.singletonMap("Authorization", Client client = client().filterWithHeader(Collections.singletonMap("Authorization",
UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_USER_NAME, UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_USER_NAME,
SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING))); SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING)));
@ -353,101 +343,12 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase {
assertNotNull(createTokenResponse.getRefreshToken()); assertNotNull(createTokenResponse.getRefreshToken());
CreateTokenResponse refreshResponse = securityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).get(); CreateTokenResponse refreshResponse = securityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).get();
assertNotNull(refreshResponse); assertNotNull(refreshResponse);
// We now have two documents, the original(now refreshed) token doc and the new one with the new access doc
AtomicReference<String> docId = new AtomicReference<>();
assertBusy(() -> {
SearchResponse searchResponse = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME)
.setSource(SearchSourceBuilder.searchSource()
.query(QueryBuilders.boolQuery()
.must(QueryBuilders.termQuery("doc_type", "token"))
.must(QueryBuilders.termQuery("refresh_token.refreshed", "true"))))
.setSize(1)
.setTerminateAfter(1)
.get();
assertThat(searchResponse.getHits().getTotalHits().value, equalTo(1L));
docId.set(searchResponse.getHits().getAt(0).getId());
});
// hack doc to modify the refresh time to 50 seconds ago so that we don't hit the lenient refresh case
Instant refreshed = Instant.now();
Instant aWhileAgo = refreshed.minus(50L, ChronoUnit.SECONDS);
assertTrue(Instant.now().isAfter(aWhileAgo));
UpdateResponse updateResponse = client.prepareUpdate(SecurityIndexManager.SECURITY_INDEX_NAME, "doc", docId.get())
.setDoc("refresh_token", Collections.singletonMap("refresh_time", aWhileAgo.toEpochMilli()))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.setFetchSource("refresh_token", Strings.EMPTY_STRING)
.get();
assertNotNull(updateResponse);
Map<String, Object> refreshTokenMap = (Map<String, Object>) updateResponse.getGetResult().sourceAsMap().get("refresh_token");
assertTrue(
Instant.ofEpochMilli((long) refreshTokenMap.get("refresh_time")).isBefore(Instant.now().minus(30L, ChronoUnit.SECONDS)));
ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class,
() -> securityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).get()); () -> securityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).get());
assertEquals("invalid_grant", e.getMessage()); assertEquals("invalid_grant", e.getMessage());
assertEquals(RestStatus.BAD_REQUEST, e.status()); assertEquals(RestStatus.BAD_REQUEST, e.status());
assertEquals("token has already been refreshed more than 30 seconds in the past", e.getHeader("error_description").get(0)); assertEquals("token has already been refreshed", e.getHeader("error_description").get(0));
}
public void testRefreshingMultipleTimesWithinWindowSucceeds() throws Exception {
Client client = client().filterWithHeader(Collections.singletonMap("Authorization",
UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_USER_NAME,
SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING)));
SecurityClient securityClient = new SecurityClient(client);
Set<String> refreshTokens = new HashSet<>();
Set<String> accessTokens = new HashSet<>();
CreateTokenResponse createTokenResponse = securityClient.prepareCreateToken()
.setGrantType("password")
.setUsername(SecuritySettingsSource.TEST_USER_NAME)
.setPassword(new SecureString(SecuritySettingsSourceField.TEST_PASSWORD.toCharArray()))
.get();
assertNotNull(createTokenResponse.getRefreshToken());
final int numberOfProcessors = Runtime.getRuntime().availableProcessors();
final int numberOfThreads = scaledRandomIntBetween((numberOfProcessors + 1) / 2, numberOfProcessors * 3);
List<Thread> threads = new ArrayList<>(numberOfThreads);
final CountDownLatch readyLatch = new CountDownLatch(numberOfThreads + 1);
final CountDownLatch completedLatch = new CountDownLatch(numberOfThreads);
AtomicBoolean failed = new AtomicBoolean();
for (int i = 0; i < numberOfThreads; i++) {
threads.add(new Thread(() -> {
// Each thread gets its own client so that more than one nodes will be hit
Client threadClient = client().filterWithHeader(Collections.singletonMap("Authorization",
UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_USER_NAME,
SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING)));
SecurityClient threadSecurityClient = new SecurityClient(threadClient);
CreateTokenRequest refreshRequest =
threadSecurityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).request();
readyLatch.countDown();
try {
readyLatch.await();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
completedLatch.countDown();
return;
}
threadSecurityClient.refreshToken(refreshRequest, ActionListener.wrap(result -> {
accessTokens.add(result.getTokenString());
refreshTokens.add(result.getRefreshToken());
logger.info("received access token [{}] and refresh token [{}]", result.getTokenString(), result.getRefreshToken());
completedLatch.countDown();
}, e -> {
failed.set(true);
completedLatch.countDown();
logger.error("caught exception", e);
}));
}));
}
for (Thread thread : threads) {
thread.start();
}
readyLatch.countDown();
readyLatch.await();
for (Thread thread : threads) {
thread.join();
}
completedLatch.await();
assertThat(failed.get(), equalTo(false));
assertThat(accessTokens.size(), equalTo(1));
assertThat(refreshTokens.size(), equalTo(1));
} }
public void testRefreshAsDifferentUser() { public void testRefreshAsDifferentUser() {

View File

@ -6,7 +6,6 @@
package org.elasticsearch.xpack.security.authc; package org.elasticsearch.xpack.security.authc;
import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.NoShardAvailableActionException; import org.elasticsearch.action.NoShardAvailableActionException;
import org.elasticsearch.action.get.GetAction; import org.elasticsearch.action.get.GetAction;
@ -24,8 +23,6 @@ import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.OutputStreamStreamOutput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.util.concurrent.ThreadContext;
@ -54,11 +51,7 @@ import org.junit.AfterClass;
import org.junit.Before; import org.junit.Before;
import org.junit.BeforeClass; import org.junit.BeforeClass;
import java.io.ByteArrayOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.time.Clock; import java.time.Clock;
import java.time.Instant; import java.time.Instant;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
@ -68,7 +61,6 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.function.Consumer; import java.util.function.Consumer;
import javax.crypto.CipherOutputStream;
import javax.crypto.SecretKey; import javax.crypto.SecretKey;
import static java.time.Clock.systemUTC; import static java.time.Clock.systemUTC;
@ -159,7 +151,7 @@ public class TokenServiceTests extends ESTestCase {
authentication = token.getAuthentication(); authentication = token.getAuthentication();
ThreadContext requestContext = new ThreadContext(Settings.EMPTY); ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", randomFrom("Bearer ", "BEARER ", "bearer ") + tokenService.getAccessTokenAsString(token)); requestContext.putHeader("Authorization", randomFrom("Bearer ", "BEARER ", "bearer ") + tokenService.getUserTokenString(token));
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
PlainActionFuture<UserToken> future = new PlainActionFuture<>(); PlainActionFuture<UserToken> future = new PlainActionFuture<>();
@ -206,7 +198,7 @@ public class TokenServiceTests extends ESTestCase {
authentication = token.getAuthentication(); authentication = token.getAuthentication();
ThreadContext requestContext = new ThreadContext(Settings.EMPTY); ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, token)); requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token));
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
PlainActionFuture<UserToken> future = new PlainActionFuture<>(); PlainActionFuture<UserToken> future = new PlainActionFuture<>();
@ -227,10 +219,10 @@ public class TokenServiceTests extends ESTestCase {
tokenService.createUserToken(authentication, authentication, newTokenFuture, Collections.emptyMap(), true); tokenService.createUserToken(authentication, authentication, newTokenFuture, Collections.emptyMap(), true);
final UserToken newToken = newTokenFuture.get().v1(); final UserToken newToken = newTokenFuture.get().v1();
assertNotNull(newToken); assertNotNull(newToken);
assertNotEquals(getDeprecatedAccessTokenString(tokenService, newToken), getDeprecatedAccessTokenString(tokenService, token)); assertNotEquals(tokenService.getUserTokenString(newToken), tokenService.getUserTokenString(token));
requestContext = new ThreadContext(Settings.EMPTY); requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, newToken)); requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(newToken));
mockGetTokenFromId(newToken, false); mockGetTokenFromId(newToken, false);
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
@ -255,7 +247,7 @@ public class TokenServiceTests extends ESTestCase {
rotateKeys(tokenService); rotateKeys(tokenService);
} }
TokenService otherTokenService = new TokenService(tokenServiceEnabledSettings, systemUTC(), client, securityIndex, TokenService otherTokenService = new TokenService(tokenServiceEnabledSettings, systemUTC(), client, securityIndex,
clusterService); clusterService);
otherTokenService.refreshMetaData(tokenService.getTokenMetaData()); otherTokenService.refreshMetaData(tokenService.getTokenMetaData());
Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null);
PlainActionFuture<Tuple<UserToken, String>> tokenFuture = new PlainActionFuture<>(); PlainActionFuture<Tuple<UserToken, String>> tokenFuture = new PlainActionFuture<>();
@ -266,7 +258,7 @@ public class TokenServiceTests extends ESTestCase {
authentication = token.getAuthentication(); authentication = token.getAuthentication();
ThreadContext requestContext = new ThreadContext(Settings.EMPTY); ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, token)); requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token));
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
PlainActionFuture<UserToken> future = new PlainActionFuture<>(); PlainActionFuture<UserToken> future = new PlainActionFuture<>();
otherTokenService.getAndValidateToken(requestContext, future); otherTokenService.getAndValidateToken(requestContext, future);
@ -297,7 +289,7 @@ public class TokenServiceTests extends ESTestCase {
authentication = token.getAuthentication(); authentication = token.getAuthentication();
ThreadContext requestContext = new ThreadContext(Settings.EMPTY); ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, token)); requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token));
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
PlainActionFuture<UserToken> future = new PlainActionFuture<>(); PlainActionFuture<UserToken> future = new PlainActionFuture<>();
@ -324,7 +316,7 @@ public class TokenServiceTests extends ESTestCase {
tokenService.createUserToken(authentication, authentication, newTokenFuture, Collections.emptyMap(), true); tokenService.createUserToken(authentication, authentication, newTokenFuture, Collections.emptyMap(), true);
final UserToken newToken = newTokenFuture.get().v1(); final UserToken newToken = newTokenFuture.get().v1();
assertNotNull(newToken); assertNotNull(newToken);
assertNotEquals(getDeprecatedAccessTokenString(tokenService, newToken), getDeprecatedAccessTokenString(tokenService, token)); assertNotEquals(tokenService.getUserTokenString(newToken), tokenService.getUserTokenString(token));
metaData = tokenService.pruneKeys(1); metaData = tokenService.pruneKeys(1);
tokenService.refreshMetaData(metaData); tokenService.refreshMetaData(metaData);
@ -337,7 +329,7 @@ public class TokenServiceTests extends ESTestCase {
} }
requestContext = new ThreadContext(Settings.EMPTY); requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, newToken)); requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(newToken));
mockGetTokenFromId(newToken, false); mockGetTokenFromId(newToken, false);
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
PlainActionFuture<UserToken> future = new PlainActionFuture<>(); PlainActionFuture<UserToken> future = new PlainActionFuture<>();
@ -359,7 +351,7 @@ public class TokenServiceTests extends ESTestCase {
authentication = token.getAuthentication(); authentication = token.getAuthentication();
ThreadContext requestContext = new ThreadContext(Settings.EMPTY); ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, token)); requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token));
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
PlainActionFuture<UserToken> future = new PlainActionFuture<>(); PlainActionFuture<UserToken> future = new PlainActionFuture<>();
@ -370,8 +362,8 @@ public class TokenServiceTests extends ESTestCase {
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
// verify a second separate token service with its own passphrase cannot verify // verify a second separate token service with its own passphrase cannot verify
TokenService anotherService = new TokenService(tokenServiceEnabledSettings, systemUTC(), client, securityIndex, TokenService anotherService = new TokenService(Settings.EMPTY, systemUTC(), client, securityIndex,
clusterService); clusterService);
PlainActionFuture<UserToken> future = new PlainActionFuture<>(); PlainActionFuture<UserToken> future = new PlainActionFuture<>();
anotherService.getAndValidateToken(requestContext, future); anotherService.getAndValidateToken(requestContext, future);
assertNull(future.get()); assertNull(future.get());
@ -385,10 +377,10 @@ public class TokenServiceTests extends ESTestCase {
PlainActionFuture<Tuple<UserToken, String>> tokenFuture = new PlainActionFuture<>(); PlainActionFuture<Tuple<UserToken, String>> tokenFuture = new PlainActionFuture<>();
tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true); tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true);
UserToken token = tokenFuture.get().v1(); UserToken token = tokenFuture.get().v1();
assertThat(getDeprecatedAccessTokenString(tokenService, token), notNullValue()); assertThat(tokenService.getUserTokenString(token), notNullValue());
tokenService.clearActiveKeyCache(); tokenService.clearActiveKeyCache();
assertThat(getDeprecatedAccessTokenString(tokenService, token), notNullValue()); assertThat(tokenService.getUserTokenString(token), notNullValue());
} }
public void testInvalidatedToken() throws Exception { public void testInvalidatedToken() throws Exception {
@ -403,7 +395,7 @@ public class TokenServiceTests extends ESTestCase {
mockGetTokenFromId(token, true); mockGetTokenFromId(token, true);
ThreadContext requestContext = new ThreadContext(Settings.EMPTY); ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getAccessTokenAsString(token)); requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token));
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
PlainActionFuture<UserToken> future = new PlainActionFuture<>(); PlainActionFuture<UserToken> future = new PlainActionFuture<>();
@ -457,7 +449,7 @@ public class TokenServiceTests extends ESTestCase {
authentication = token.getAuthentication(); authentication = token.getAuthentication();
ThreadContext requestContext = new ThreadContext(Settings.EMPTY); ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getAccessTokenAsString(token)); requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token));
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
// the clock is still frozen, so the cookie should be valid // the clock is still frozen, so the cookie should be valid
@ -567,7 +559,7 @@ public class TokenServiceTests extends ESTestCase {
//mockGetTokenFromId(token, false); //mockGetTokenFromId(token, false);
ThreadContext requestContext = new ThreadContext(Settings.EMPTY); ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getAccessTokenAsString(token)); requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token));
doAnswer(invocationOnMock -> { doAnswer(invocationOnMock -> {
ActionListener<GetResponse> listener = (ActionListener<GetResponse>) invocationOnMock.getArguments()[1]; ActionListener<GetResponse> listener = (ActionListener<GetResponse>) invocationOnMock.getArguments()[1];
@ -606,7 +598,7 @@ public class TokenServiceTests extends ESTestCase {
Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null);
UserToken expired = new UserToken(authentication, Instant.now().minus(3L, ChronoUnit.DAYS)); UserToken expired = new UserToken(authentication, Instant.now().minus(3L, ChronoUnit.DAYS));
mockGetTokenFromId(expired, false); mockGetTokenFromId(expired, false);
String userTokenString = tokenService.getAccessTokenAsString(expired); String userTokenString = tokenService.getUserTokenString(expired);
PlainActionFuture<Tuple<Authentication, Map<String, Object>>> authFuture = new PlainActionFuture<>(); PlainActionFuture<Tuple<Authentication, Map<String, Object>>> authFuture = new PlainActionFuture<>();
tokenService.getAuthenticationAndMetaData(userTokenString, authFuture); tokenService.getAuthenticationAndMetaData(userTokenString, authFuture);
Authentication retrievedAuth = authFuture.actionGet().v1(); Authentication retrievedAuth = authFuture.actionGet().v1();
@ -647,28 +639,4 @@ public class TokenServiceTests extends ESTestCase {
assertEquals(expected.getMetadata(), result.getMetadata()); assertEquals(expected.getMetadata(), result.getMetadata());
assertEquals(AuthenticationType.TOKEN, result.getAuthenticationType()); assertEquals(AuthenticationType.TOKEN, result.getAuthenticationType());
} }
protected String getDeprecatedAccessTokenString(TokenService tokenService, UserToken userToken) throws IOException,
GeneralSecurityException {
try (ByteArrayOutputStream os = new ByteArrayOutputStream(TokenService.MINIMUM_BASE64_BYTES);
OutputStream base64 = Base64.getEncoder().wrap(os);
StreamOutput out = new OutputStreamStreamOutput(base64)) {
out.setVersion(Version.V_7_0_0);
TokenService.KeyAndCache keyAndCache = tokenService.getActiveKeyCache();
Version.writeVersion(Version.V_7_0_0, out);
out.writeByteArray(keyAndCache.getSalt().bytes);
out.writeByteArray(keyAndCache.getKeyHash().bytes);
final byte[] initializationVector = tokenService.getNewInitializationVector();
out.writeByteArray(initializationVector);
try (CipherOutputStream encryptedOutput =
new CipherOutputStream(out, tokenService.getEncryptionCipher(initializationVector, keyAndCache, Version.V_7_0_0));
StreamOutput encryptedStreamOutput = new OutputStreamStreamOutput(encryptedOutput)) {
encryptedStreamOutput.setVersion(Version.V_7_0_0);
encryptedStreamOutput.writeString(userToken.getId());
encryptedStreamOutput.close();
return new String(os.toByteArray(), StandardCharsets.UTF_8);
}
}
}
} }

View File

@ -25,7 +25,8 @@ public class TokensInvalidationResultTests extends ESTestCase {
TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList("token1", "token2"), TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList("token1", "token2"),
Arrays.asList("token3", "token4"), Arrays.asList("token3", "token4"),
Arrays.asList(new ElasticsearchException("foo", new IllegalStateException("bar")), Arrays.asList(new ElasticsearchException("foo", new IllegalStateException("bar")),
new ElasticsearchException("boo", new IllegalStateException("far")))); new ElasticsearchException("boo", new IllegalStateException("far"))),
randomIntBetween(0, 5));
try (XContentBuilder builder = JsonXContent.contentBuilder()) { try (XContentBuilder builder = JsonXContent.contentBuilder()) {
result.toXContent(builder, ToXContent.EMPTY_PARAMS); result.toXContent(builder, ToXContent.EMPTY_PARAMS);
@ -55,8 +56,9 @@ public class TokensInvalidationResultTests extends ESTestCase {
} }
public void testToXcontentWithNoErrors() throws Exception{ public void testToXcontentWithNoErrors() throws Exception{
TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList("token1", "token2"), Collections.emptyList(), TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList("token1", "token2"),
Collections.emptyList()); Collections.emptyList(),
Collections.emptyList(), randomIntBetween(0, 5));
try (XContentBuilder builder = JsonXContent.contentBuilder()) { try (XContentBuilder builder = JsonXContent.contentBuilder()) {
result.toXContent(builder, ToXContent.EMPTY_PARAMS); result.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertThat(Strings.toString(builder), assertThat(Strings.toString(builder),