Authentication should set the proper version on the stream

This commit ensures that the authentication sets the correct version on the stream when it is serialized over the wire
so that there is not a version mismatch between the authentication and the connection it came from.

Original commit: elastic/x-pack-elasticsearch@267d7068f4
This commit is contained in:
jaymode 2017-02-16 07:22:02 -05:00
parent 3572cff0a8
commit ff96939c5f
8 changed files with 43 additions and 22 deletions

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.security; package org.elasticsearch.xpack.security;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.elasticsearch.Version;
import org.elasticsearch.common.logging.Loggers; import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.util.concurrent.ThreadContext;
@ -60,7 +61,7 @@ public class SecurityContext {
* Sets the user forcefully to the provided user. There must not be an existing user in the ThreadContext otherwise an exception * Sets the user forcefully to the provided user. There must not be an existing user in the ThreadContext otherwise an exception
* will be thrown. This method is package private for testing. * will be thrown. This method is package private for testing.
*/ */
void setUser(User user) { void setUser(User user, Version version) {
Objects.requireNonNull(user); Objects.requireNonNull(user);
final Authentication.RealmRef lookedUpBy; final Authentication.RealmRef lookedUpBy;
if (user.runAs() == null) { if (user.runAs() == null) {
@ -71,7 +72,7 @@ public class SecurityContext {
try { try {
Authentication authentication = Authentication authentication =
new Authentication(user, new Authentication.RealmRef("__attach", "__attach", nodeName), lookedUpBy); new Authentication(user, new Authentication.RealmRef("__attach", "__attach", nodeName), lookedUpBy, version);
authentication.writeToContext(threadContext); authentication.writeToContext(threadContext);
} catch (IOException e) { } catch (IOException e) {
throw new AssertionError("how can we have a IOException with a user we set", e); throw new AssertionError("how can we have a IOException with a user we set", e);
@ -82,10 +83,10 @@ public class SecurityContext {
* Runs the consumer in a new context as the provided user. The original constext is provided to the consumer. When this method * Runs the consumer in a new context as the provided user. The original constext is provided to the consumer. When this method
* returns, the original context is restored. * returns, the original context is restored.
*/ */
public void executeAsUser(User user, Consumer<StoredContext> consumer) { public void executeAsUser(User user, Consumer<StoredContext> consumer, Version version) {
final StoredContext original = threadContext.newStoredContext(true); final StoredContext original = threadContext.newStoredContext(true);
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
setUser(user); setUser(user, version);
consumer.accept(original); consumer.accept(original);
} }
} }

View File

@ -5,6 +5,7 @@
*/ */
package org.elasticsearch.xpack.security.action.filter; package org.elasticsearch.xpack.security.action.filter;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionResponse;
@ -121,7 +122,7 @@ public class SecurityActionFilter extends AbstractComponent implements ActionFil
} catch (IOException e) { } catch (IOException e) {
listener.onFailure(e); listener.onFailure(e);
} }
}); }, Version.CURRENT);
} else { } else {
try (ThreadContext.StoredContext ignore = threadContext.newStoredContext(true)) { try (ThreadContext.StoredContext ignore = threadContext.newStoredContext(true)) {
applyInternal(action, request, authenticatedListener); applyInternal(action, request, authenticatedListener);

View File

@ -27,10 +27,14 @@ public class Authentication {
private final Version version; private final Version version;
public Authentication(User user, RealmRef authenticatedBy, RealmRef lookedUpBy) { public Authentication(User user, RealmRef authenticatedBy, RealmRef lookedUpBy) {
this(user, authenticatedBy, lookedUpBy, Version.CURRENT);
}
public Authentication(User user, RealmRef authenticatedBy, RealmRef lookedUpBy, Version version) {
this.user = Objects.requireNonNull(user); this.user = Objects.requireNonNull(user);
this.authenticatedBy = Objects.requireNonNull(authenticatedBy); this.authenticatedBy = Objects.requireNonNull(authenticatedBy);
this.lookedUpBy = lookedUpBy; this.lookedUpBy = lookedUpBy;
this.version = Version.CURRENT; this.version = version;
} }
public Authentication(StreamInput in) throws IOException { public Authentication(StreamInput in) throws IOException {
@ -147,7 +151,7 @@ public class Authentication {
String encode() throws IOException { String encode() throws IOException {
BytesStreamOutput output = new BytesStreamOutput(); BytesStreamOutput output = new BytesStreamOutput();
Version.writeVersion(Version.CURRENT, output); Version.writeVersion(version, output);
writeTo(output); writeTo(output);
return Base64.getEncoder().encodeToString(BytesReference.toBytes(output.bytes())); return Base64.getEncoder().encodeToString(BytesReference.toBytes(output.bytes()));
} }

View File

@ -92,7 +92,7 @@ public class SecurityServerTransportInterceptor extends AbstractComponent implem
if (AuthorizationUtils.shouldReplaceUserWithSystem(threadPool.getThreadContext(), action)) { if (AuthorizationUtils.shouldReplaceUserWithSystem(threadPool.getThreadContext(), action)) {
securityContext.executeAsUser(SystemUser.INSTANCE, (original) -> sendWithUser(connection, action, request, options, securityContext.executeAsUser(SystemUser.INSTANCE, (original) -> sendWithUser(connection, action, request, options,
new TransportService.ContextRestoreResponseHandler<>(threadPool.getThreadContext().wrapRestorable(original) new TransportService.ContextRestoreResponseHandler<>(threadPool.getThreadContext().wrapRestorable(original)
, handler), sender)); , handler), sender), connection.getVersion());
} else if (reservedRealmEnabled && connection.getVersion().before(Version.V_5_2_0_UNRELEASED) && } else if (reservedRealmEnabled && connection.getVersion().before(Version.V_5_2_0_UNRELEASED) &&
KibanaUser.NAME.equals(securityContext.getUser().principal())) { KibanaUser.NAME.equals(securityContext.getUser().principal())) {
final User kibanaUser = securityContext.getUser(); final User kibanaUser = securityContext.getUser();
@ -100,7 +100,15 @@ public class SecurityServerTransportInterceptor extends AbstractComponent implem
kibanaUser.email(), kibanaUser.metadata(), kibanaUser.enabled()); kibanaUser.email(), kibanaUser.metadata(), kibanaUser.enabled());
securityContext.executeAsUser(bwcKibanaUser, (original) -> sendWithUser(connection, action, request, options, securityContext.executeAsUser(bwcKibanaUser, (original) -> sendWithUser(connection, action, request, options,
new TransportService.ContextRestoreResponseHandler<>(threadPool.getThreadContext().wrapRestorable(original), new TransportService.ContextRestoreResponseHandler<>(threadPool.getThreadContext().wrapRestorable(original),
handler), sender)); handler), sender), connection.getVersion());
} else if (securityContext.getAuthentication() != null &&
securityContext.getAuthentication().getVersion().equals(connection.getVersion()) == false) {
// re-write the authentication since we want the authentication version to match the version of the connection
securityContext.executeAsUser(securityContext.getUser(),
(original) -> sendWithUser(connection, action, request, options,
new TransportService.ContextRestoreResponseHandler<>(
threadPool.getThreadContext().wrapRestorable(original), handler), sender),
connection.getVersion());
} else { } else {
sendWithUser(connection, action, request, options, handler, sender); sendWithUser(connection, action, request, options, handler, sender);
} }

View File

@ -137,7 +137,7 @@ public interface ServerTransportFilter {
listener.onResponse(null); listener.onResponse(null);
}); });
asyncAuthorizer.authorize(authzService); asyncAuthorizer.authorize(authzService);
}); }, transportChannel.getVersion());
} else { } else {
throw new IllegalStateException("a disabled user should never be sent. " + kibanaUser); throw new IllegalStateException("a disabled user should never be sent. " + kibanaUser);
} }
@ -151,7 +151,7 @@ public interface ServerTransportFilter {
listener.onResponse(null); listener.onResponse(null);
}); });
asyncAuthorizer.authorize(authzService); asyncAuthorizer.authorize(authzService);
}); }, transportChannel.getVersion());
} else { } else {
final AuthorizationUtils.AsyncAuthorizer asyncAuthorizer = final AuthorizationUtils.AsyncAuthorizer asyncAuthorizer =
new AuthorizationUtils.AsyncAuthorizer(authentication, listener, (userRoles, runAsRoles) -> { new AuthorizationUtils.AsyncAuthorizer(authentication, listener, (userRoles, runAsRoles) -> {

View File

@ -5,6 +5,7 @@
*/ */
package org.elasticsearch.xpack.security; package org.elasticsearch.xpack.security;
import org.elasticsearch.Version;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.util.concurrent.ThreadContext.StoredContext; import org.elasticsearch.common.util.concurrent.ThreadContext.StoredContext;
@ -51,11 +52,11 @@ public class SecurityContextTests extends ESTestCase {
final User user = new User("test"); final User user = new User("test");
assertNull(securityContext.getAuthentication()); assertNull(securityContext.getAuthentication());
assertNull(securityContext.getUser()); assertNull(securityContext.getUser());
securityContext.setUser(user); securityContext.setUser(user, Version.CURRENT);
assertEquals(user, securityContext.getUser()); assertEquals(user, securityContext.getUser());
IllegalStateException e = expectThrows(IllegalStateException.class, IllegalStateException e = expectThrows(IllegalStateException.class,
() -> securityContext.setUser(randomFrom(user, SystemUser.INSTANCE))); () -> securityContext.setUser(randomFrom(user, SystemUser.INSTANCE), Version.CURRENT));
assertEquals("authentication is already present in the context", e.getMessage()); assertEquals("authentication is already present in the context", e.getMessage());
} }
@ -74,7 +75,7 @@ public class SecurityContextTests extends ESTestCase {
securityContext.executeAsUser(executionUser, (originalCtx) -> { securityContext.executeAsUser(executionUser, (originalCtx) -> {
assertEquals(executionUser, securityContext.getUser()); assertEquals(executionUser, securityContext.getUser());
contextAtomicReference.set(originalCtx); contextAtomicReference.set(originalCtx);
}); }, Version.CURRENT);
final User userAfterExecution = securityContext.getUser(); final User userAfterExecution = securityContext.getUser();
assertEquals(original, userAfterExecution); assertEquals(original, userAfterExecution);

View File

@ -14,6 +14,7 @@ import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.Transport.Connection;
import org.elasticsearch.transport.TransportException; import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportInterceptor.AsyncSender; import org.elasticsearch.transport.TransportInterceptor.AsyncSender;
import org.elasticsearch.transport.TransportRequest; import org.elasticsearch.transport.TransportRequest;
@ -40,6 +41,7 @@ import java.util.function.Consumer;
import static org.hamcrest.Matchers.arrayContaining; import static org.hamcrest.Matchers.arrayContaining;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
@ -120,7 +122,7 @@ public class SecurityServerTransportInterceptorTests extends ESTestCase {
assertEquals(user, sendingUser.get()); assertEquals(user, sendingUser.get());
assertEquals(user, securityContext.getUser()); assertEquals(user, securityContext.getUser());
verify(xPackLicenseState).isAuthAllowed(); verify(xPackLicenseState).isAuthAllowed();
verify(securityContext, never()).executeAsUser(any(User.class), any(Consumer.class)); verify(securityContext, never()).executeAsUser(any(User.class), any(Consumer.class), any(Version.class));
verifyNoMoreInteractions(xPackLicenseState); verifyNoMoreInteractions(xPackLicenseState);
} }
@ -147,13 +149,15 @@ public class SecurityServerTransportInterceptorTests extends ESTestCase {
sendingUser.set(securityContext.getUser()); sendingUser.set(securityContext.getUser());
} }
}); });
sender.sendRequest(null, "internal:foo", null, null, null); Connection connection = mock(Connection.class);
when(connection.getVersion()).thenReturn(Version.CURRENT);
sender.sendRequest(connection, "internal:foo", null, null, null);
assertTrue(calledWrappedSender.get()); assertTrue(calledWrappedSender.get());
assertNotEquals(user, sendingUser.get()); assertNotEquals(user, sendingUser.get());
assertEquals(SystemUser.INSTANCE, sendingUser.get()); assertEquals(SystemUser.INSTANCE, sendingUser.get());
assertEquals(user, securityContext.getUser()); assertEquals(user, securityContext.getUser());
verify(xPackLicenseState).isAuthAllowed(); verify(xPackLicenseState).isAuthAllowed();
verify(securityContext).executeAsUser(any(User.class), any(Consumer.class)); verify(securityContext).executeAsUser(any(User.class), any(Consumer.class), eq(Version.CURRENT));
verifyNoMoreInteractions(xPackLicenseState); verifyNoMoreInteractions(xPackLicenseState);
} }
@ -178,7 +182,7 @@ public class SecurityServerTransportInterceptorTests extends ESTestCase {
assertEquals("there should always be a user when sending a message", e.getMessage()); assertEquals("there should always be a user when sending a message", e.getMessage());
assertNull(securityContext.getUser()); assertNull(securityContext.getUser());
verify(xPackLicenseState).isAuthAllowed(); verify(xPackLicenseState).isAuthAllowed();
verify(securityContext, never()).executeAsUser(any(User.class), any(Consumer.class)); verify(securityContext, never()).executeAsUser(any(User.class), any(Consumer.class), any(Version.class));
verifyNoMoreInteractions(xPackLicenseState); verifyNoMoreInteractions(xPackLicenseState);
} }
@ -207,7 +211,8 @@ public class SecurityServerTransportInterceptorTests extends ESTestCase {
}; };
AsyncSender sender = interceptor.interceptSender(intercepted); AsyncSender sender = interceptor.interceptSender(intercepted);
Transport.Connection connection = mock(Transport.Connection.class); Transport.Connection connection = mock(Transport.Connection.class);
when(connection.getVersion()).thenReturn(Version.fromId(randomIntBetween(Version.V_5_0_0_ID, Version.V_5_2_0_ID_UNRELEASED - 100))); final Version version = Version.fromId(randomIntBetween(Version.V_5_0_0_ID, Version.V_5_2_0_ID_UNRELEASED - 100));
when(connection.getVersion()).thenReturn(version);
sender.sendRequest(connection, "indices:foo[s]", null, null, null); sender.sendRequest(connection, "indices:foo[s]", null, null, null);
assertTrue(calledWrappedSender.get()); assertTrue(calledWrappedSender.get());
assertNotEquals(user, sendingUser.get()); assertNotEquals(user, sendingUser.get());
@ -238,7 +243,7 @@ public class SecurityServerTransportInterceptorTests extends ESTestCase {
assertEquals(user, sendingUser.get()); assertEquals(user, sendingUser.get());
verify(xPackLicenseState, times(3)).isAuthAllowed(); verify(xPackLicenseState, times(3)).isAuthAllowed();
verify(securityContext, times(1)).executeAsUser(any(User.class), any(Consumer.class)); verify(securityContext, times(1)).executeAsUser(any(User.class), any(Consumer.class), eq(version));
verifyNoMoreInteractions(xPackLicenseState); verifyNoMoreInteractions(xPackLicenseState);
} }

View File

@ -225,8 +225,8 @@ public class ServerTransportFilterTests extends ESTestCase {
TransportRequest request = mock(TransportRequest.class); TransportRequest request = mock(TransportRequest.class);
User user = new User("kibana", "kibana"); User user = new User("kibana", "kibana");
Authentication authentication = mock(Authentication.class); Authentication authentication = mock(Authentication.class);
when(authentication.getVersion()) final Version version = Version.fromId(randomIntBetween(Version.V_5_0_0_ID, Version.V_5_2_0_ID_UNRELEASED - 100));
.thenReturn(Version.fromId(randomIntBetween(Version.V_5_0_0_ID, Version.V_5_2_0_ID_UNRELEASED - 100))); when(authentication.getVersion()).thenReturn(version);
when(authentication.getUser()).thenReturn(user); when(authentication.getUser()).thenReturn(user);
when(authentication.getRunAsUser()).thenReturn(user); when(authentication.getRunAsUser()).thenReturn(user);
doAnswer((i) -> { doAnswer((i) -> {
@ -246,6 +246,7 @@ public class ServerTransportFilterTests extends ESTestCase {
}).when(authzService).roles(any(User.class), any(ActionListener.class)); }).when(authzService).roles(any(User.class), any(ActionListener.class));
ServerTransportFilter filter = getClientOrNodeFilter(); ServerTransportFilter filter = getClientOrNodeFilter();
PlainActionFuture<Void> future = new PlainActionFuture<>(); PlainActionFuture<Void> future = new PlainActionFuture<>();
when(channel.getVersion()).thenReturn(version);
filter.inbound("_action", request, channel, future); filter.inbound("_action", request, channel, future);
assertNotNull(rolesRef.get()); assertNotNull(rolesRef.get());
assertThat(rolesRef.get(), arrayContaining("kibana_system")); assertThat(rolesRef.get(), arrayContaining("kibana_system"));