From aa61494fc3a68b4806784f67ad837ee821d26da4 Mon Sep 17 00:00:00 2001 From: Tamas Palfy Date: Tue, 27 Jul 2021 15:19:53 +0200 Subject: [PATCH] NIFI-9065 Add support for OAuth2AccessTokenProvider in InvokeHTTP Signed-off-by: Pierre Villard This closes #5319. --- .../nifi-standard-processors/pom.xml | 5 + .../nifi/processors/standard/InvokeHTTP.java | 52 ++- .../processors/standard/InvokeHTTPTest.java | 95 ++++ .../org/apache/nifi/oauth2/AccessToken.java | 63 ++- .../oauth2/OAuth2AccessTokenProvider.java | 30 ++ .../nifi/oauth2/OAuth2TokenProvider.java | 3 + .../nifi/oauth2/OAuth2TokenProviderImpl.java | 3 + .../StandardOauth2AccessTokenProvider.java | 300 +++++++++++++ ...g.apache.nifi.controller.ControllerService | 2 + .../oauth2/OAuth2TokenProviderImplTest.java | 4 +- ...StandardOauth2AccessTokenProviderTest.java | 411 ++++++++++++++++++ 11 files changed, 944 insertions(+), 24 deletions(-) create mode 100644 nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-api/src/main/java/org/apache/nifi/oauth2/OAuth2AccessTokenProvider.java create mode 100644 nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProvider.java create mode 100644 nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProviderTest.java diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/pom.xml b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/pom.xml index 84fffaf243..8f6ca2d22a 100644 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/pom.xml +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/pom.xml @@ -374,6 +374,11 @@ nifi-database-utils 1.16.0-SNAPSHOT + + org.apache.nifi + nifi-oauth2-provider-api + 1.16.0-SNAPSHOT + org.apache.sshd sshd-core diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/InvokeHTTP.java b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/InvokeHTTP.java index f9977bcf62..fbfea08c8c 100644 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/InvokeHTTP.java +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/InvokeHTTP.java @@ -42,6 +42,7 @@ import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; @@ -90,6 +91,7 @@ import org.apache.nifi.expression.ExpressionLanguageScope; import org.apache.nifi.flowfile.FlowFile; import org.apache.nifi.flowfile.attributes.CoreAttributes; import org.apache.nifi.logging.ComponentLog; +import org.apache.nifi.oauth2.OAuth2AccessTokenProvider; import org.apache.nifi.processor.AbstractProcessor; import org.apache.nifi.processor.DataUnit; import org.apache.nifi.processor.ProcessContext; @@ -494,6 +496,13 @@ public class InvokeHTTP extends AbstractProcessor { .allowableValues("True", "False") .build(); + public static final PropertyDescriptor OAUTH2_ACCESS_TOKEN_PROVIDER = new PropertyDescriptor.Builder() + .name("oauth2-access-token-provider") + .displayName("OAuth2 Access Token provider") + .identifiesControllerService(OAuth2AccessTokenProvider.class) + .required(false) + .build(); + public static final PropertyDescriptor FLOW_FILE_NAMING_STRATEGY = new PropertyDescriptor.Builder() .name("flow-file-naming-strategy") .description("Determines the strategy used for setting the filename attribute of the FlowFile.") @@ -527,6 +536,7 @@ public class InvokeHTTP extends AbstractProcessor { PROP_USERAGENT, PROP_BASIC_AUTH_USERNAME, PROP_BASIC_AUTH_PASSWORD, + OAUTH2_ACCESS_TOKEN_PROVIDER, PROXY_CONFIGURATION_SERVICE, PROP_PROXY_HOST, PROP_PROXY_PORT, @@ -595,6 +605,8 @@ public class InvokeHTTP extends AbstractProcessor { private volatile boolean useChunked = false; + private volatile Optional oauth2AccessTokenProviderOptional; + private final AtomicReference okHttpClientAtomicReference = new AtomicReference<>(); @Override @@ -728,6 +740,19 @@ public class InvokeHTTP extends AbstractProcessor { .build()); } + boolean usingUserNamePasswordAuthorization = validationContext.getProperty(PROP_BASIC_AUTH_USERNAME).isSet() + || validationContext.getProperty(PROP_BASIC_AUTH_PASSWORD).isSet(); + + boolean usingOAuth2Authorization = validationContext.getProperty(OAUTH2_ACCESS_TOKEN_PROVIDER).isSet(); + + if (usingUserNamePasswordAuthorization && usingOAuth2Authorization) { + results.add(new ValidationResult.Builder() + .subject("Authorization properties") + .valid(false) + .explanation("OAuth2 Authorization cannot be configured together with Username and Password properties") + .build()); + } + return results; } @@ -806,6 +831,19 @@ public class InvokeHTTP extends AbstractProcessor { okHttpClientAtomicReference.set(okHttpClientBuilder.build()); } + @OnScheduled + public void initOauth2AccessTokenProvider(final ProcessContext context) { + if (context.getProperty(OAUTH2_ACCESS_TOKEN_PROVIDER).isSet()) { + OAuth2AccessTokenProvider oauth2AccessTokenProvider = context.getProperty(OAUTH2_ACCESS_TOKEN_PROVIDER).asControllerService(OAuth2AccessTokenProvider.class); + + oauth2AccessTokenProvider.getAccessDetails(); + + oauth2AccessTokenProviderOptional = Optional.of(oauth2AccessTokenProvider); + } else { + oauth2AccessTokenProviderOptional = Optional.empty(); + } + } + private void setAuthenticator(OkHttpClient.Builder okHttpClientBuilder, ProcessContext context) { final String authUser = trimToEmpty(context.getProperty(PROP_BASIC_AUTH_USERNAME).getValue()); @@ -1034,11 +1072,17 @@ public class InvokeHTTP extends AbstractProcessor { final String authUser = trimToEmpty(context.getProperty(PROP_BASIC_AUTH_USERNAME).getValue()); // If the username/password properties are set then check if digest auth is being used - if (!authUser.isEmpty() && "false".equalsIgnoreCase(context.getProperty(PROP_DIGEST_AUTH).getValue())) { - final String authPass = trimToEmpty(context.getProperty(PROP_BASIC_AUTH_PASSWORD).getValue()); + if ("false".equalsIgnoreCase(context.getProperty(PROP_DIGEST_AUTH).getValue())) { + if (!authUser.isEmpty()) { + final String authPass = trimToEmpty(context.getProperty(PROP_BASIC_AUTH_PASSWORD).getValue()); - String credential = Credentials.basic(authUser, authPass); - requestBuilder.header("Authorization", credential); + String credential = Credentials.basic(authUser, authPass); + requestBuilder.header("Authorization", credential); + } else { + oauth2AccessTokenProviderOptional.ifPresent(oauth2AccessTokenProvider -> + requestBuilder.addHeader("Authorization", "Bearer " + oauth2AccessTokenProvider.getAccessDetails().getAccessToken()) + ); + } } // set the request method diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/InvokeHTTPTest.java b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/InvokeHTTPTest.java index 469ba88000..ed5d0b832d 100644 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/InvokeHTTPTest.java +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/InvokeHTTPTest.java @@ -21,6 +21,7 @@ import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; import org.apache.commons.lang3.StringUtils; import org.apache.nifi.flowfile.attributes.CoreAttributes; +import org.apache.nifi.oauth2.OAuth2AccessTokenProvider; import org.apache.nifi.processor.Relationship; import org.apache.nifi.processors.standard.http.FlowFileNamingStrategy; import org.apache.nifi.reporting.InitializationException; @@ -34,6 +35,7 @@ import org.apache.nifi.util.MockFlowFile; import org.apache.nifi.util.TestRunner; import org.apache.nifi.util.TestRunners; import org.apache.nifi.web.util.ssl.SslContextUtils; +import org.mockito.Answers; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSocketFactory; @@ -774,6 +776,99 @@ public class InvokeHTTPTest { ); } + @Test + public void testValidWhenOAuth2Set() throws Exception { + // GIVEN + String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId"; + + OAuth2AccessTokenProvider oauth2AccessTokenProvider = mock(OAuth2AccessTokenProvider.class, Answers.RETURNS_DEEP_STUBS); + when(oauth2AccessTokenProvider.getIdentifier()).thenReturn(oauth2AccessTokenProviderId); + + runner.addControllerService(oauth2AccessTokenProviderId, oauth2AccessTokenProvider); + runner.enableControllerService(oauth2AccessTokenProvider); + + setUrlProperty(); + + // WHEN + runner.setProperty(InvokeHTTP.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProviderId); + + // THEN + runner.assertValid(); + } + + @Test + public void testInvalidWhenOAuth2AndUserNameSet() throws Exception { + // GIVEN + String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId"; + + OAuth2AccessTokenProvider oauth2AccessTokenProvider = mock(OAuth2AccessTokenProvider.class, Answers.RETURNS_DEEP_STUBS); + when(oauth2AccessTokenProvider.getIdentifier()).thenReturn(oauth2AccessTokenProviderId); + + runner.addControllerService(oauth2AccessTokenProviderId, oauth2AccessTokenProvider); + runner.enableControllerService(oauth2AccessTokenProvider); + + setUrlProperty(); + + // WHEN + runner.setProperty(InvokeHTTP.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProviderId); + runner.setProperty(InvokeHTTP.PROP_BASIC_AUTH_USERNAME, "userName"); + + // THEN + runner.assertNotValid(); + } + + @Test + public void testInvalidWhenOAuth2AndPasswordSet() throws Exception { + // GIVEN + String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId"; + + OAuth2AccessTokenProvider oauth2AccessTokenProvider = mock(OAuth2AccessTokenProvider.class, Answers.RETURNS_DEEP_STUBS); + when(oauth2AccessTokenProvider.getIdentifier()).thenReturn(oauth2AccessTokenProviderId); + + runner.addControllerService(oauth2AccessTokenProviderId, oauth2AccessTokenProvider); + runner.enableControllerService(oauth2AccessTokenProvider); + + setUrlProperty(); + + // WHEN + runner.setProperty(InvokeHTTP.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProviderId); + runner.setProperty(InvokeHTTP.PROP_BASIC_AUTH_PASSWORD, "password"); + + // THEN + runner.assertNotValid(); + } + + @Test + public void testOAuth2AuthorizationHeader() throws Exception { + // GIVEN + String accessToken = "access_token"; + + String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId"; + + OAuth2AccessTokenProvider oauth2AccessTokenProvider = mock(OAuth2AccessTokenProvider.class, Answers.RETURNS_DEEP_STUBS); + when(oauth2AccessTokenProvider.getIdentifier()).thenReturn(oauth2AccessTokenProviderId); + when(oauth2AccessTokenProvider.getAccessDetails().getAccessToken()).thenReturn(accessToken); + + runner.addControllerService(oauth2AccessTokenProviderId, oauth2AccessTokenProvider); + runner.enableControllerService(oauth2AccessTokenProvider); + + setUrlProperty(); + + mockWebServer.enqueue(new MockResponse()); + + // WHEN + runner.setProperty(InvokeHTTP.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProviderId); + runner.enqueue("unimportant"); + runner.run(); + + // THEN + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + + String actualAuthorizationHeader = recordedRequest.getHeader("Authorization"); + assertEquals("Bearer " + accessToken, actualAuthorizationHeader); + + } + private void setUrlProperty() { runner.setProperty(InvokeHTTP.PROP_URL, getMockWebServerUrl()); } diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-api/src/main/java/org/apache/nifi/oauth2/AccessToken.java b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-api/src/main/java/org/apache/nifi/oauth2/AccessToken.java index 2e261ddce0..622c9b0934 100644 --- a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-api/src/main/java/org/apache/nifi/oauth2/AccessToken.java +++ b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-api/src/main/java/org/apache/nifi/oauth2/AccessToken.java @@ -17,53 +17,80 @@ package org.apache.nifi.oauth2; +import java.time.Duration; +import java.time.Instant; + public class AccessToken { private String accessToken; private String refreshToken; private String tokenType; - private Integer expires; - private String scope; + private long expiresIn; + private String scopes; - private Long fetchTime; + private final Instant fetchTime; - public AccessToken(String accessToken, - String refreshToken, - String tokenType, - Integer expires, - String scope) { + public static final int EXPIRY_MARGIN = 5000; + + public AccessToken() { + this.fetchTime = Instant.now(); + } + + public AccessToken(String accessToken, String refreshToken, String tokenType, long expiresIn, String scopes) { + this(); this.accessToken = accessToken; - this.tokenType = tokenType; this.refreshToken = refreshToken; - this.expires = expires; - this.scope = scope; - this.fetchTime = System.currentTimeMillis(); + this.tokenType = tokenType; + this.expiresIn = expiresIn; + this.scopes = scopes; } public String getAccessToken() { return accessToken; } + public void setAccessToken(String accessToken) { + this.accessToken = accessToken; + } + public String getRefreshToken() { return refreshToken; } + public void setRefreshToken(String refreshToken) { + this.refreshToken = refreshToken; + } + public String getTokenType() { return tokenType; } - public Integer getExpires() { - return expires; + public void setTokenType(String tokenType) { + this.tokenType = tokenType; } - public String getScope() { - return scope; + public long getExpiresIn() { + return expiresIn; } - public Long getFetchTime() { + public void setExpiresIn(long expiresIn) { + this.expiresIn = expiresIn; + } + + public String getScopes() { + return scopes; + } + + public void setScopes(String scopes) { + this.scopes = scopes; + } + + public Instant getFetchTime() { return fetchTime; } public boolean isExpired() { - return System.currentTimeMillis() >= ( fetchTime + (expires * 1000) ); + boolean expired = Duration.between(Instant.now(), fetchTime.plusSeconds(expiresIn - EXPIRY_MARGIN)).isNegative(); + + return expired; } } diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-api/src/main/java/org/apache/nifi/oauth2/OAuth2AccessTokenProvider.java b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-api/src/main/java/org/apache/nifi/oauth2/OAuth2AccessTokenProvider.java new file mode 100644 index 0000000000..9cfc7578f5 --- /dev/null +++ b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-api/src/main/java/org/apache/nifi/oauth2/OAuth2AccessTokenProvider.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.oauth2; + +import org.apache.nifi.controller.ControllerService; + +/** + * Controller service that provides OAuth2 access details + */ +public interface OAuth2AccessTokenProvider extends ControllerService { + + /** + * @return A valid access token (refreshed automatically if needed) and additional metadata (provided by the OAuth2 access server) + */ + AccessToken getAccessDetails(); +} diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-api/src/main/java/org/apache/nifi/oauth2/OAuth2TokenProvider.java b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-api/src/main/java/org/apache/nifi/oauth2/OAuth2TokenProvider.java index 964ae439d5..707ff6a946 100644 --- a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-api/src/main/java/org/apache/nifi/oauth2/OAuth2TokenProvider.java +++ b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-api/src/main/java/org/apache/nifi/oauth2/OAuth2TokenProvider.java @@ -29,7 +29,10 @@ import java.util.List; /** * Interface for defining a credential-providing controller service for oauth2 processes. + * + * @deprecated use {@link OAuth2AccessTokenProvider} instead */ +@Deprecated public interface OAuth2TokenProvider extends ControllerService { PropertyDescriptor SSL_CONTEXT = new PropertyDescriptor.Builder() .name("oauth2-ssl-context") diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/OAuth2TokenProviderImpl.java b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/OAuth2TokenProviderImpl.java index 056c4f2356..cc06309c2e 100644 --- a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/OAuth2TokenProviderImpl.java +++ b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/OAuth2TokenProviderImpl.java @@ -30,6 +30,7 @@ import okhttp3.Request; import okhttp3.RequestBody; import okhttp3.Response; import org.apache.nifi.annotation.documentation.CapabilityDescription; +import org.apache.nifi.annotation.documentation.DeprecationNotice; import org.apache.nifi.annotation.documentation.Tags; import org.apache.nifi.annotation.lifecycle.OnEnabled; import org.apache.nifi.components.PropertyDescriptor; @@ -39,6 +40,8 @@ import org.apache.nifi.processor.exception.ProcessException; import org.apache.nifi.ssl.SSLContextService; import org.apache.nifi.util.StringUtils; +@Deprecated +@DeprecationNotice(alternatives = {StandardOauth2AccessTokenProvider.class}) @Tags({"oauth2", "provider", "authorization" }) @CapabilityDescription("This controller service provides a way of working with access and refresh tokens via the " + "password and client_credential grant flows in the OAuth2 specification. It is meant to provide a way for components " + diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProvider.java b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProvider.java new file mode 100644 index 0000000000..6c1b05e381 --- /dev/null +++ b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProvider.java @@ -0,0 +1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.oauth2; + +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import okhttp3.FormBody; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; +import org.apache.nifi.annotation.documentation.CapabilityDescription; +import org.apache.nifi.annotation.documentation.Tags; +import org.apache.nifi.annotation.lifecycle.OnDisabled; +import org.apache.nifi.annotation.lifecycle.OnEnabled; +import org.apache.nifi.components.AllowableValue; +import org.apache.nifi.components.PropertyDescriptor; +import org.apache.nifi.components.ValidationContext; +import org.apache.nifi.components.ValidationResult; +import org.apache.nifi.components.Validator; +import org.apache.nifi.controller.AbstractControllerService; +import org.apache.nifi.controller.ConfigurationContext; +import org.apache.nifi.expression.ExpressionLanguageScope; +import org.apache.nifi.processor.exception.ProcessException; +import org.apache.nifi.processor.util.StandardValidators; +import org.apache.nifi.ssl.SSLContextService; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.X509TrustManager; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +@Tags({"oauth2", "provider", "authorization", "access token", "http"}) +@CapabilityDescription("Provides OAuth 2.0 access tokens that can be used as Bearer authorization header in HTTP requests." + + " Uses Resource Owner Password Credentials Grant.") +public class StandardOauth2AccessTokenProvider extends AbstractControllerService implements OAuth2AccessTokenProvider { + public static final PropertyDescriptor AUTHORIZATION_SERVER_URL = new PropertyDescriptor.Builder() + .name("authorization-server-url") + .displayName("Authorization Server URL") + .description("The URL of the authorization server that issues access tokens.") + .required(true) + .addValidator(StandardValidators.URL_VALIDATOR) + .expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY) + .build(); + + public static AllowableValue RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE = new AllowableValue( + "password", + "User Password", + "Resource Owner Password Credentials Grant. Used to access resources available to users. Requires username and password and usually Client ID and Client Secret" + ); + public static AllowableValue CLIENT_CREDENTIALS_GRANT_TYPE = new AllowableValue( + "client_credentials", + "Client Credentials", + "Client Credentials Grant. Used to access resources available to clients. Requires Client ID and Client Secret" + ); + + public static final PropertyDescriptor GRANT_TYPE = new PropertyDescriptor.Builder() + .name("grant-type") + .displayName("Grant Type") + .description("The OAuth2 Grant Type to be used when acquiring an access token.") + .required(true) + .allowableValues(RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE, CLIENT_CREDENTIALS_GRANT_TYPE) + .defaultValue(RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE.getValue()) + .build(); + + public static final PropertyDescriptor USERNAME = new PropertyDescriptor.Builder() + .name("service-user-name") + .displayName("Username") + .description("Username on the service that is being accessed.") + .dependsOn(GRANT_TYPE, RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE) + .required(true) + .addValidator(StandardValidators.NON_BLANK_VALIDATOR) + .expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY) + .build(); + + public static final PropertyDescriptor PASSWORD = new PropertyDescriptor.Builder() + .name("service-password") + .displayName("Password") + .description("Password for the username on the service that is being accessed.") + .dependsOn(GRANT_TYPE, RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE) + .required(true) + .sensitive(true) + .addValidator(StandardValidators.NON_BLANK_VALIDATOR) + .build(); + + public static final PropertyDescriptor CLIENT_ID = new PropertyDescriptor.Builder() + .name("client-id") + .displayName("Client ID") + .required(false) + .addValidator(StandardValidators.NON_BLANK_VALIDATOR) + .expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY) + .build(); + + public static final PropertyDescriptor CLIENT_SECRET = new PropertyDescriptor.Builder() + .name("client-secret") + .displayName("Client secret") + .dependsOn(CLIENT_ID) + .required(true) + .sensitive(true) + .addValidator(StandardValidators.NON_BLANK_VALIDATOR) + .build(); + + public static final PropertyDescriptor SSL_CONTEXT = new PropertyDescriptor.Builder() + .name("ssl-context-service") + .displayName("SSL Context Servuce") + .addValidator(Validator.VALID) + .identifiesControllerService(SSLContextService.class) + .required(false) + .build(); + + private static final List PROPERTIES = Collections.unmodifiableList(Arrays.asList( + AUTHORIZATION_SERVER_URL, + GRANT_TYPE, + USERNAME, + PASSWORD, + CLIENT_ID, + CLIENT_SECRET, + SSL_CONTEXT + )); + + public static final ObjectMapper ACCESS_DETAILS_MAPPER = new ObjectMapper() + .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) + .setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE); + + private volatile String authorizationServerUrl; + private volatile OkHttpClient httpClient; + + private volatile String grantType; + private volatile String username; + private volatile String password; + private volatile String clientId; + private volatile String clientSecret; + + private volatile AccessToken accessDetails; + + @Override + public List getSupportedPropertyDescriptors() { + return PROPERTIES; + } + + @OnEnabled + public void onEnabled(ConfigurationContext context) { + authorizationServerUrl = context.getProperty(AUTHORIZATION_SERVER_URL).evaluateAttributeExpressions().getValue(); + + httpClient = createHttpClient(context); + + grantType = context.getProperty(GRANT_TYPE).getValue(); + username = context.getProperty(USERNAME).evaluateAttributeExpressions().getValue(); + password = context.getProperty(PASSWORD).getValue(); + clientId = context.getProperty(CLIENT_ID).evaluateAttributeExpressions().getValue(); + clientSecret = context.getProperty(CLIENT_SECRET).getValue(); + } + + @OnDisabled + public void onDisabled() { + accessDetails = null; + } + + @Override + protected Collection customValidate(ValidationContext validationContext) { + final List validationResults = new ArrayList<>(super.customValidate(validationContext)); + + if ( + validationContext.getProperty(GRANT_TYPE).getValue().equals(CLIENT_CREDENTIALS_GRANT_TYPE.getValue()) + && !validationContext.getProperty(CLIENT_ID).isSet() + ) { + validationResults.add(new ValidationResult.Builder().subject(CLIENT_ID.getDisplayName()) + .valid(false) + .explanation(String.format( + "When '%s' is set to '%s', '%s' is required", + GRANT_TYPE.getDisplayName(), + CLIENT_CREDENTIALS_GRANT_TYPE.getDisplayName(), + CLIENT_ID.getDisplayName()) + ) + .build()); + } + + return validationResults; + } + + protected OkHttpClient createHttpClient(ConfigurationContext context) { + OkHttpClient.Builder clientBuilder = new OkHttpClient.Builder(); + + SSLContextService sslService = context.getProperty(SSL_CONTEXT).asControllerService(SSLContextService.class); + if (sslService != null) { + final X509TrustManager trustManager = sslService.createTrustManager(); + SSLContext sslContext = sslService.createContext(); + clientBuilder.sslSocketFactory(sslContext.getSocketFactory(), trustManager); + } + + return clientBuilder.build(); + } + + @Override + public AccessToken getAccessDetails() { + if (this.accessDetails == null) { + acquireAccessDetails(); + } else if (this.accessDetails.isExpired()) { + if (this.accessDetails.getRefreshToken() == null) { + acquireAccessDetails(); + } else { + try { + refreshAccessDetails(); + } catch (Exception e) { + getLogger().info("Couldn't refresh access token", e); + acquireAccessDetails(); + } + } + } + + return accessDetails; + } + + private void acquireAccessDetails() { + getLogger().debug("Getting a new access token"); + + FormBody.Builder acquireTokenBuilder = new FormBody.Builder(); + + if (grantType.equals(RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE.getValue())) { + acquireTokenBuilder.add("grant_type", "password") + .add("username", username) + .add("password", password); + } else if (grantType.equals(CLIENT_CREDENTIALS_GRANT_TYPE.getValue())) { + acquireTokenBuilder.add("grant_type", "client_credentials"); + } + + if (clientId != null) { + acquireTokenBuilder.add("client_id", clientId); + acquireTokenBuilder.add("client_secret", clientSecret); + } + + RequestBody acquireTokenRequestBody = acquireTokenBuilder.build(); + + Request acquireTokenRequest = new Request.Builder() + .url(authorizationServerUrl) + .post(acquireTokenRequestBody) + .build(); + + this.accessDetails = getAccessDetails(acquireTokenRequest); + } + + private void refreshAccessDetails() { + getLogger().debug("Refreshing access token"); + + FormBody.Builder refreshTokenBuilder = new FormBody.Builder() + .add("grant_type", "refresh_token") + .add("refresh_token", this.accessDetails.getRefreshToken()); + + if (clientId != null) { + refreshTokenBuilder.add("client_id", clientId); + refreshTokenBuilder.add("client_secret", clientSecret); + } + + RequestBody refreshTokenRequestBody = refreshTokenBuilder.build(); + + Request refreshRequest = new Request.Builder() + .url(authorizationServerUrl) + .post(refreshTokenRequestBody) + .build(); + + this.accessDetails = getAccessDetails(refreshRequest); + } + + private AccessToken getAccessDetails(Request newRequest) { + try { + Response response = httpClient.newCall(newRequest).execute(); + String responseBody = response.body().string(); + if (response.code() != 200) { + getLogger().error(String.format("OAuth2 access token request failed [HTTP %d], response:%n%s", response.code(), responseBody)); + throw new ProcessException(String.format("OAuth2 access token request failed [HTTP %d]", response.code())); + } + + AccessToken accessDetails = ACCESS_DETAILS_MAPPER.readValue(responseBody, AccessToken.class); + + return accessDetails; + } catch (IOException e) { + throw new UncheckedIOException("OAuth2 access token request failed", e); + } + } +} diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/resources/META-INF/services/org.apache.nifi.controller.ControllerService b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/resources/META-INF/services/org.apache.nifi.controller.ControllerService index 75e29d02a5..b1de8ed022 100644 --- a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/resources/META-INF/services/org.apache.nifi.controller.ControllerService +++ b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/resources/META-INF/services/org.apache.nifi.controller.ControllerService @@ -13,3 +13,5 @@ # See the License for the specific language governing permissions and # limitations under the License. org.apache.nifi.oauth2.OAuth2TokenProviderImpl +org.apache.nifi.oauth2.StandardOauth2AccessTokenProvider + diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/OAuth2TokenProviderImplTest.java b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/OAuth2TokenProviderImplTest.java index 25f04492e2..b8dc62e4a3 100644 --- a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/OAuth2TokenProviderImplTest.java +++ b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/OAuth2TokenProviderImplTest.java @@ -107,7 +107,7 @@ public class OAuth2TokenProviderImplTest { private void assertAccessTokenFound(final AccessToken accessToken) { assertNotNull(accessToken); assertEquals("access token", accessToken.getAccessToken()); - assertEquals(300, accessToken.getExpires().intValue()); + assertEquals(5300, accessToken.getExpiresIn()); assertEquals("BEARER", accessToken.getTokenType()); assertFalse(accessToken.isExpired()); } @@ -117,7 +117,7 @@ public class OAuth2TokenProviderImplTest { token.put("access_token", "access token"); token.put("refresh_token", "refresh token"); token.put("token_type", "BEARER"); - token.put("expires_in", 300); + token.put("expires_in", 5300); token.put("scope", "test scope"); final String accessToken = new ObjectMapper().writeValueAsString(token); mockWebServer.enqueue(new MockResponse().setResponseCode(200).addHeader("Content-Type", "application/json").setBody(accessToken)); diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProviderTest.java b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProviderTest.java new file mode 100644 index 0000000000..54de7a0f92 --- /dev/null +++ b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProviderTest.java @@ -0,0 +1,411 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.oauth2; + +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Protocol; +import okhttp3.Request; +import okhttp3.Response; +import okhttp3.ResponseBody; +import org.apache.nifi.controller.ConfigurationContext; +import org.apache.nifi.logging.ComponentLog; +import org.apache.nifi.processor.Processor; +import org.apache.nifi.processor.exception.ProcessException; +import org.apache.nifi.util.NoOpProcessor; +import org.apache.nifi.util.TestRunner; +import org.apache.nifi.util.TestRunners; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Answers; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class StandardOauth2AccessTokenProviderTest { + private static final String AUTHORIZATION_SERVER_URL = "http://authorizationServerUrl"; + private static final String USERNAME = "username"; + private static final String PASSWORD = "password"; + private static final String CLIENT_ID = "clientId"; + private static final String CLIENT_SECRET = "clientSecret"; + + private StandardOauth2AccessTokenProvider testSubject; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private OkHttpClient mockHttpClient; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private ConfigurationContext mockContext; + + @Mock + private ComponentLog mockLogger; + @Captor + private ArgumentCaptor debugCaptor; + @Captor + private ArgumentCaptor errorCaptor; + @Captor + private ArgumentCaptor throwableCaptor; + + @Before + public void setUp() throws Exception { + MockitoAnnotations.initMocks(this); + + testSubject = new StandardOauth2AccessTokenProvider() { + @Override + protected OkHttpClient createHttpClient(ConfigurationContext context) { + return mockHttpClient; + } + + @Override + protected ComponentLog getLogger() { + return mockLogger; + } + }; + + when(mockContext.getProperty(StandardOauth2AccessTokenProvider.GRANT_TYPE).getValue()).thenReturn(StandardOauth2AccessTokenProvider.RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE.getValue()); + when(mockContext.getProperty(StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL).evaluateAttributeExpressions().getValue()).thenReturn(AUTHORIZATION_SERVER_URL); + when(mockContext.getProperty(StandardOauth2AccessTokenProvider.USERNAME).evaluateAttributeExpressions().getValue()).thenReturn(USERNAME); + when(mockContext.getProperty(StandardOauth2AccessTokenProvider.PASSWORD).getValue()).thenReturn(PASSWORD); + when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_ID).evaluateAttributeExpressions().getValue()).thenReturn(CLIENT_ID); + when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_SECRET).getValue()).thenReturn(CLIENT_SECRET); + + testSubject.onEnabled(mockContext); + } + + @Test + public void testInvalidWhenClientCredentialsGrantTypeSetWithoutClientId() throws Exception { + // GIVEN + Processor processor = new NoOpProcessor(); + TestRunner runner = TestRunners.newTestRunner(processor); + + runner.addControllerService("testSubject", testSubject); + + runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, "http://unimportant"); + + // WHEN + runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.GRANT_TYPE, StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE); + + // THEN + runner.assertNotValid(testSubject); + } + + @Test + public void testValidWhenClientCredentialsGrantTypeSetWithClientId() throws Exception { + // GIVEN + Processor processor = new NoOpProcessor(); + TestRunner runner = TestRunners.newTestRunner(processor); + + runner.addControllerService("testSubject", testSubject); + + runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, "http://unimportant"); + + // WHEN + runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.GRANT_TYPE, StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE); + runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_ID, "clientId"); + runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_SECRET, "clientSecret"); + + // THEN + runner.assertValid(testSubject); + } + + @Test + public void testAcquireNewToken() throws Exception { + String accessTokenValue = "access_token_value"; + + // GIVEN + Response response = buildResponse( + 200, + "{ \"access_token\":\"" + accessTokenValue + "\" }" + ); + + when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response); + + // WHEN + String actual = testSubject.getAccessDetails().getAccessToken(); + + // THEN + assertEquals(accessTokenValue, actual); + } + + @Test + public void testRefreshToken() throws Exception { + // GIVEN + String firstToken = "first_token"; + String expectedToken = "second_token"; + + Response response1 = buildResponse( + 200, + "{ \"access_token\":\"" + firstToken + "\", \"expires_in\":\"0\", \"refresh_token\":\"not_checking_in_this_test\" }" + ); + + Response response2 = buildResponse( + 200, + "{ \"access_token\":\"" + expectedToken + "\" }" + ); + + when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response1, response2); + + // WHEN + testSubject.getAccessDetails(); + String actualToken = testSubject.getAccessDetails().getAccessToken(); + + // THEN + assertEquals(expectedToken, actualToken); + } + + @Test + public void testIOExceptionDuringRefreshAndSubsequentAcquire() throws Exception { + // GIVEN + String refreshErrorMessage = "refresh_error"; + String acquireErrorMessage = "acquire_error"; + + AtomicInteger callCounter = new AtomicInteger(0); + when(mockHttpClient.newCall(any(Request.class)).execute()).thenAnswer(invocation -> { + callCounter.incrementAndGet(); + + if (callCounter.get() == 1) { + return buildSuccessfulInitResponse(); + } else if (callCounter.get() == 2) { + throw new IOException(refreshErrorMessage); + } else if (callCounter.get() == 3) { + throw new IOException(acquireErrorMessage); + } + + throw new IllegalStateException("Test improperly defined mock HTTP responses."); + }); + + // Get a good accessDetails so we can have a refresh a second time + testSubject.getAccessDetails(); + + // WHEN + UncheckedIOException actualException = assertThrows( + UncheckedIOException.class, + () -> testSubject.getAccessDetails() + ); + + // THEN + checkLoggedDebugWhenRefreshFails(); + + checkLoggedRefreshError(new UncheckedIOException("OAuth2 access token request failed", new IOException(refreshErrorMessage))); + + checkError(new UncheckedIOException("OAuth2 access token request failed", new IOException(acquireErrorMessage)), actualException); + } + + @Test + public void testIOExceptionDuringRefreshSuccessfulSubsequentAcquire() throws Exception { + // GIVEN + String refreshErrorMessage = "refresh_error"; + String expectedToken = "expected_token"; + + Response successfulAcquireResponse = buildResponse( + 200, + "{ \"access_token\":\"" + expectedToken + "\", \"expires_in\":\"0\", \"refresh_token\":\"not_checking_in_this_test\" }" + ); + + AtomicInteger callCounter = new AtomicInteger(0); + when(mockHttpClient.newCall(any(Request.class)).execute()).thenAnswer(invocation -> { + callCounter.incrementAndGet(); + + if (callCounter.get() == 1) { + return buildSuccessfulInitResponse(); + } else if (callCounter.get() == 2) { + throw new IOException(refreshErrorMessage); + } else if (callCounter.get() == 3) { + return successfulAcquireResponse; + } + + throw new IllegalStateException("Test improperly defined mock HTTP responses."); + }); + + // Get a good accessDetails so we can have a refresh a second time + testSubject.getAccessDetails(); + + // WHEN + String actualToken = testSubject.getAccessDetails().getAccessToken(); + + // THEN + checkLoggedDebugWhenRefreshFails(); + + checkLoggedRefreshError(new UncheckedIOException("OAuth2 access token request failed", new IOException(refreshErrorMessage))); + + assertEquals(expectedToken, actualToken); + } + + @Test + public void testHTTPErrorDuringRefreshAndSubsequentAcquire() throws Exception { + // GIVEN + String errorRefreshResponseBody = "{ \"error_response\":\"refresh_error\" }"; + String errorAcquireResponseBody = "{ \"error_response\":\"acquire_error\" }"; + + Response errorRefreshResponse = buildResponse(500, errorRefreshResponseBody); + Response errorAcquireResponse = buildResponse(503, errorAcquireResponseBody); + + AtomicInteger callCounter = new AtomicInteger(0); + when(mockHttpClient.newCall(any(Request.class)).execute()).thenAnswer(invocation -> { + callCounter.incrementAndGet(); + + if (callCounter.get() == 1) { + return buildSuccessfulInitResponse(); + } else if (callCounter.get() == 2) { + return errorRefreshResponse; + } else if (callCounter.get() == 3) { + return errorAcquireResponse; + } + + throw new IllegalStateException("Test improperly defined mock HTTP responses."); + }); + + List expectedLoggedError = Arrays.asList( + String.format("OAuth2 access token request failed [HTTP %d], response:%n%s", 500, errorRefreshResponseBody), + String.format("OAuth2 access token request failed [HTTP %d], response:%n%s", 503, errorAcquireResponseBody) + ); + + // Get a good accessDetails so we can have a refresh a second time + testSubject.getAccessDetails(); + + // WHEN + ProcessException actualException = assertThrows( + ProcessException.class, + () -> testSubject.getAccessDetails() + ); + + // THEN + checkLoggedDebugWhenRefreshFails(); + + checkLoggedRefreshError(new ProcessException("OAuth2 access token request failed [HTTP 500]")); + + checkedLoggedErrorWhenRefreshReturnsBadHTTPResponse(expectedLoggedError); + + checkError(new ProcessException("OAuth2 access token request failed [HTTP 503]"), actualException); + } + + @Test + public void testHTTPErrorDuringRefreshSuccessfulSubsequentAcquire() throws Exception { + // GIVEN + String expectedRefreshErrorResponse = "{ \"error_response\":\"refresh_error\" }"; + String expectedToken = "expected_token"; + + Response errorRefreshResponse = buildResponse(500, expectedRefreshErrorResponse); + Response successfulAcquireResponse = buildResponse( + 200, + "{ \"access_token\":\"" + expectedToken + "\", \"expires_in\":\"0\", \"refresh_token\":\"not_checking_in_this_test\" }" + ); + + AtomicInteger callCounter = new AtomicInteger(0); + when(mockHttpClient.newCall(any(Request.class)).execute()).thenAnswer(invocation -> { + callCounter.incrementAndGet(); + + if (callCounter.get() == 1) { + return buildSuccessfulInitResponse(); + } else if (callCounter.get() == 2) { + return errorRefreshResponse; + } else if (callCounter.get() == 3) { + return successfulAcquireResponse; + } + + throw new IllegalStateException("Test improperly defined mock HTTP responses."); + }); + + List expectedLoggedError = Arrays.asList(String.format("OAuth2 access token request failed [HTTP %d], response:%n%s", 500, expectedRefreshErrorResponse)); + + // Get a good accessDetails so we can have a refresh a second time + testSubject.getAccessDetails(); + + // WHEN + String actualToken = testSubject.getAccessDetails().getAccessToken(); + + // THEN + checkLoggedDebugWhenRefreshFails(); + + checkLoggedRefreshError(new ProcessException("OAuth2 access token request failed [HTTP 500]")); + + checkedLoggedErrorWhenRefreshReturnsBadHTTPResponse(expectedLoggedError); + + assertEquals(expectedToken, actualToken); + } + + private Response buildSuccessfulInitResponse() { + return buildResponse( + 200, + "{ \"access_token\":\"exists_but_value_irrelevant\", \"expires_in\":\"0\", \"refresh_token\":\"not_checking_in_this_test\" }" + ); + } + + private Response buildResponse(int code, String body) { + return new Response.Builder() + .request(new Request.Builder() + .url("http://unimportant_but_required") + .build() + ) + .protocol(Protocol.HTTP_2) + .message("unimportant_but_required") + .code(code) + .body(ResponseBody.create( + body.getBytes(), + MediaType.parse("application/json")) + ) + .build(); + } + + private void checkLoggedDebugWhenRefreshFails() { + verify(mockLogger, times(3)).debug(debugCaptor.capture()); + List actualDebugMessages = debugCaptor.getAllValues(); + + assertEquals( + Arrays.asList("Getting a new access token", "Refreshing access token", "Getting a new access token"), + actualDebugMessages + ); + } + + private void checkedLoggedErrorWhenRefreshReturnsBadHTTPResponse(List expectedLoggedError) { + verify(mockLogger, times(expectedLoggedError.size())).error(errorCaptor.capture()); + List actualLoggedError = errorCaptor.getAllValues(); + + assertEquals(expectedLoggedError, actualLoggedError); + } + + private void checkLoggedRefreshError(Throwable expectedRefreshError) { + verify(mockLogger).info(eq("Couldn't refresh access token"), throwableCaptor.capture()); + Throwable actualRefreshError = throwableCaptor.getValue(); + + checkError(expectedRefreshError, actualRefreshError); + } + + private void checkError(Throwable expectedError, Throwable actualError) { + assertEquals(expectedError.getClass(), actualError.getClass()); + assertEquals(expectedError.getMessage(), actualError.getMessage()); + if (expectedError.getCause() != null || actualError.getCause() != null) { + assertEquals(expectedError.getCause().getClass(), actualError.getCause().getClass()); + assertEquals(expectedError.getCause().getMessage(), actualError.getCause().getMessage()); + } + } +}