Only call listener once (SP template registration) (#60567)

This fixes a bug in the IdP's template registration that would
sometimes call the listener twice.

Resolves: #54285
Resolves: #54423

Backport of: #60497
This commit is contained in:
Tim Vernum 2020-08-03 13:45:16 +10:00 committed by GitHub
parent ac258f10d6
commit 1a373b0c21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 20 deletions

View File

@ -205,6 +205,7 @@ public class SamlServiceProviderIndex implements Closeable {
final ClusterState state = clusterService.state(); final ClusterState state = clusterService.state();
if (isTemplateUpToDate(state)) { if (isTemplateUpToDate(state)) {
listener.onResponse(false); listener.onResponse(false);
return;
} }
final String template = TemplateUtils.loadTemplate(TEMPLATE_RESOURCE, Version.CURRENT.toString(), TEMPLATE_VERSION_SUBSTITUTE); final String template = TemplateUtils.loadTemplate(TEMPLATE_RESOURCE, Version.CURRENT.toString(), TEMPLATE_VERSION_SUBSTITUTE);
final PutIndexTemplateRequest request = new PutIndexTemplateRequest(TEMPLATE_NAME).source(template, XContentType.JSON); final PutIndexTemplateRequest request = new PutIndexTemplateRequest(TEMPLATE_NAME).source(template, XContentType.JSON);

View File

@ -282,13 +282,13 @@ public class SamlIdentityProviderTests extends IdentityProviderIntegTestCase {
assertThat(e.getResponse().getStatusLine().getStatusCode(), equalTo(RestStatus.BAD_REQUEST.getStatus())); assertThat(e.getResponse().getStatusLine().getStatusCode(), equalTo(RestStatus.BAD_REQUEST.getStatus()));
} }
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/54285")
public void testSpInitiatedSsoFailsForMalformedRequest() throws Exception { public void testSpInitiatedSsoFailsForMalformedRequest() throws Exception {
String acsUrl = "https://" + randomAlphaOfLength(12) + ".elastic-cloud.com/saml/acs"; String acsUrl = "https://" + randomAlphaOfLength(12) + ".elastic-cloud.com/saml/acs";
String entityId = SP_ENTITY_ID; String entityId = SP_ENTITY_ID;
registerServiceProvider(entityId, acsUrl); registerServiceProvider(entityId, acsUrl);
registerApplicationPrivileges(); registerApplicationPrivileges();
ensureGreen(SamlServiceProviderIndex.INDEX_NAME); ensureGreen(SamlServiceProviderIndex.INDEX_NAME);
// Validate incoming authentication request // Validate incoming authentication request
Request validateRequest = new Request("POST", "/_idp/saml/validate"); Request validateRequest = new Request("POST", "/_idp/saml/validate");
validateRequest.setOptions(REQUEST_OPTIONS_AS_CONSOLE_USER); validateRequest.setOptions(REQUEST_OPTIONS_AS_CONSOLE_USER);
@ -298,12 +298,14 @@ public class SamlIdentityProviderTests extends IdentityProviderIntegTestCase {
final AuthnRequest authnRequest = buildAuthnRequest(entityId + randomAlphaOfLength(4), new URL(acsUrl), final AuthnRequest authnRequest = buildAuthnRequest(entityId + randomAlphaOfLength(4), new URL(acsUrl),
new URL("https://idp.org/sso/redirect"), nameIdFormat, forceAuthn); new URL("https://idp.org/sso/redirect"), nameIdFormat, forceAuthn);
final String query = getQueryString(authnRequest, relayString, false, null); final String query = getQueryString(authnRequest, relayString, false, null);
// Skip http parameter name // Skip http parameter name
final String queryWithoutParam = query.substring(12); final String queryWithoutParam = query.substring(12);
validateRequest.setJsonEntity("{\"authn_request_query\":\"" + queryWithoutParam + "\"}"); validateRequest.setJsonEntity("{\"authn_request_query\":\"" + queryWithoutParam + "\"}");
ResponseException e = expectThrows(ResponseException.class, () -> getRestClient().performRequest(validateRequest)); ResponseException e = expectThrows(ResponseException.class, () -> getRestClient().performRequest(validateRequest));
assertThat(e.getMessage(), containsString("does not contain a SAMLRequest parameter")); assertThat(e.getMessage(), containsString("does not contain a SAMLRequest parameter"));
assertThat(e.getResponse().getStatusLine().getStatusCode(), equalTo(RestStatus.BAD_REQUEST.getStatus())); assertThat(e.getResponse().getStatusLine().getStatusCode(), equalTo(RestStatus.BAD_REQUEST.getStatus()));
// arbitrarily trim the request // arbitrarily trim the request
final String malformedRequestQuery = query.substring(0, query.length() - randomIntBetween(10, 15)); final String malformedRequestQuery = query.substring(0, query.length() - randomIntBetween(10, 15));
validateRequest.setJsonEntity("{\"authn_request_query\":\"" + malformedRequestQuery + "\"}"); validateRequest.setJsonEntity("{\"authn_request_query\":\"" + malformedRequestQuery + "\"}");

View File

@ -36,6 +36,7 @@ import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Set; import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static org.elasticsearch.xpack.idp.saml.idp.SamlIdentityProviderBuilder.IDP_ENTITY_ID; import static org.elasticsearch.xpack.idp.saml.idp.SamlIdentityProviderBuilder.IDP_ENTITY_ID;
@ -132,9 +133,7 @@ public class SamlServiceProviderIndexTests extends ESSingleNodeTestCase {
} }
public void testWritesViaAliasIfItExists() { public void testWritesViaAliasIfItExists() {
final PlainActionFuture<Boolean> installTemplate = new PlainActionFuture<>(); assertTrue(installTemplate());
serviceProviderIndex.installIndexTemplate(installTemplate);
assertTrue(installTemplate.actionGet());
// Create an index that will trigger the template, but isn't the standard index name // Create an index that will trigger the template, but isn't the standard index name
final String customIndexName = SamlServiceProviderIndex.INDEX_NAME + "-test"; final String customIndexName = SamlServiceProviderIndex.INDEX_NAME + "-test";
@ -172,9 +171,7 @@ public class SamlServiceProviderIndexTests extends ESSingleNodeTestCase {
assertBusy(() -> assertThat("template should have been installed", templateMeta, notNullValue())); assertBusy(() -> assertThat("template should have been installed", templateMeta, notNullValue()));
final PlainActionFuture<Boolean> installTemplate = new PlainActionFuture<>(); assertFalse("Template is already installed, should not install again", installTemplate());
serviceProviderIndex.installIndexTemplate(installTemplate);
assertFalse("Template is already installed, should not install again", installTemplate.actionGet());
} }
public void testInstallTemplateAutomaticallyOnDocumentWrite() { public void testInstallTemplateAutomaticallyOnDocumentWrite() {
@ -186,45 +183,44 @@ public class SamlServiceProviderIndexTests extends ESSingleNodeTestCase {
IndexTemplateMetadata templateMeta = clusterService.state().metadata().templates().get(SamlServiceProviderIndex.TEMPLATE_NAME); IndexTemplateMetadata templateMeta = clusterService.state().metadata().templates().get(SamlServiceProviderIndex.TEMPLATE_NAME);
assertThat("template should have been installed", templateMeta, notNullValue()); assertThat("template should have been installed", templateMeta, notNullValue());
final PlainActionFuture<Boolean> installTemplate = new PlainActionFuture<>(); assertFalse("Template is already installed, should not install again", installTemplate());
serviceProviderIndex.installIndexTemplate(installTemplate);
assertFalse("Template is already installed, should not install again", installTemplate.actionGet());
} }
private boolean installTemplate() { private boolean installTemplate() {
final PlainActionFuture<Boolean> installTemplate = new PlainActionFuture<>(); final PlainActionFuture<Boolean> installTemplate = new PlainActionFuture<>();
serviceProviderIndex.installIndexTemplate(installTemplate); serviceProviderIndex.installIndexTemplate(assertListenerIsOnlyCalledOnce(installTemplate));
return installTemplate.actionGet(); return installTemplate.actionGet();
} }
private Set<SamlServiceProviderDocument> getAllDocs() { private Set<SamlServiceProviderDocument> getAllDocs() {
final PlainActionFuture<Set<SamlServiceProviderDocument>> future = new PlainActionFuture<>(); final PlainActionFuture<Set<SamlServiceProviderDocument>> future = new PlainActionFuture<>();
serviceProviderIndex.findAll(ActionListener.wrap( serviceProviderIndex.findAll(assertListenerIsOnlyCalledOnce(ActionListener.wrap(
set -> future.onResponse(set.stream().map(doc -> doc.document.get()) set -> future.onResponse(set.stream().map(doc -> doc.document.get())
.collect(Collectors.collectingAndThen(Collectors.toSet(), Collections::unmodifiableSet))), .collect(Collectors.collectingAndThen(Collectors.toSet(), Collections::unmodifiableSet))),
future::onFailure future::onFailure
)); )));
return future.actionGet(); return future.actionGet();
} }
private SamlServiceProviderDocument readDocument(String docId) { private SamlServiceProviderDocument readDocument(String docId) {
final PlainActionFuture<SamlServiceProviderIndex.DocumentSupplier> future = new PlainActionFuture<>(); final PlainActionFuture<SamlServiceProviderIndex.DocumentSupplier> future = new PlainActionFuture<>();
serviceProviderIndex.readDocument(docId, future); serviceProviderIndex.readDocument(docId, assertListenerIsOnlyCalledOnce(future));
final SamlServiceProviderIndex.DocumentSupplier supplier = future.actionGet(); final SamlServiceProviderIndex.DocumentSupplier supplier = future.actionGet();
return supplier == null ? null : supplier.getDocument(); return supplier == null ? null : supplier.getDocument();
} }
private void writeDocument(SamlServiceProviderDocument doc) { private void writeDocument(SamlServiceProviderDocument doc) {
final PlainActionFuture<DocWriteResponse> future = new PlainActionFuture<>(); final PlainActionFuture<DocWriteResponse> future = new PlainActionFuture<>();
serviceProviderIndex.writeDocument(doc, DocWriteRequest.OpType.INDEX, WriteRequest.RefreshPolicy.WAIT_UNTIL, future); serviceProviderIndex.writeDocument(doc, DocWriteRequest.OpType.INDEX, WriteRequest.RefreshPolicy.WAIT_UNTIL,
assertListenerIsOnlyCalledOnce(future));
doc.setDocId(future.actionGet().getId()); doc.setDocId(future.actionGet().getId());
} }
private DeleteResponse deleteDocument(SamlServiceProviderDocument doc) { private DeleteResponse deleteDocument(SamlServiceProviderDocument doc) {
final PlainActionFuture<DeleteResponse> future = new PlainActionFuture<>(); final PlainActionFuture<DeleteResponse> future = new PlainActionFuture<>();
serviceProviderIndex.readDocument(doc.docId, ActionListener.wrap( serviceProviderIndex.readDocument(doc.docId, assertListenerIsOnlyCalledOnce(ActionListener.wrap(
info -> serviceProviderIndex.deleteDocument(info.version, WriteRequest.RefreshPolicy.IMMEDIATE, future), info -> serviceProviderIndex.deleteDocument(info.version, WriteRequest.RefreshPolicy.IMMEDIATE, future),
future::onFailure)); future::onFailure)));
return future.actionGet(); return future.actionGet();
} }
@ -236,17 +232,17 @@ public class SamlServiceProviderIndexTests extends ESSingleNodeTestCase {
private Set<SamlServiceProviderDocument> findAllByEntityId(String entityId) { private Set<SamlServiceProviderDocument> findAllByEntityId(String entityId) {
final PlainActionFuture<Set<SamlServiceProviderDocument>> future = new PlainActionFuture<>(); final PlainActionFuture<Set<SamlServiceProviderDocument>> future = new PlainActionFuture<>();
serviceProviderIndex.findByEntityId(entityId, ActionListener.wrap( serviceProviderIndex.findByEntityId(entityId, assertListenerIsOnlyCalledOnce(ActionListener.wrap(
set -> future.onResponse(set.stream().map(doc -> doc.document.get()) set -> future.onResponse(set.stream().map(doc -> doc.document.get())
.collect(Collectors.collectingAndThen(Collectors.toSet(), Collections::unmodifiableSet))), .collect(Collectors.collectingAndThen(Collectors.toSet(), Collections::unmodifiableSet))),
future::onFailure future::onFailure
)); )));
return future.actionGet(); return future.actionGet();
} }
private void refresh() { private void refresh() {
PlainActionFuture<Void> future = new PlainActionFuture<>(); PlainActionFuture<Void> future = new PlainActionFuture<>();
serviceProviderIndex.refresh(future); serviceProviderIndex.refresh(assertListenerIsOnlyCalledOnce(future));
future.actionGet(); future.actionGet();
} }
@ -303,4 +299,13 @@ public class SamlServiceProviderIndexTests extends ESSingleNodeTestCase {
+ randomAlphaOfLengthBetween(4, 8) + "." + randomAlphaOfLengthBetween(2, 4) + "/"; + randomAlphaOfLengthBetween(4, 8) + "." + randomAlphaOfLengthBetween(2, 4) + "/";
} }
private static <T> ActionListener<T> assertListenerIsOnlyCalledOnce(ActionListener<T> delegate) {
final AtomicInteger callCount = new AtomicInteger(0);
return ActionListener.runBefore(delegate, () -> {
if (callCount.incrementAndGet() != 1) {
fail("Listener was called twice");
}
});
}
} }