mirror of https://github.com/apache/nifi.git
NIFI-9065 Add support for OAuth2AccessTokenProvider in InvokeHTTP
Signed-off-by: Pierre Villard <pierre.villard.fr@gmail.com> This closes #5319.
This commit is contained in:
parent
e603b0179b
commit
aa61494fc3
|
@ -374,6 +374,11 @@
|
|||
<artifactId>nifi-database-utils</artifactId>
|
||||
<version>1.16.0-SNAPSHOT</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.nifi</groupId>
|
||||
<artifactId>nifi-oauth2-provider-api</artifactId>
|
||||
<version>1.16.0-SNAPSHOT</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.sshd</groupId>
|
||||
<artifactId>sshd-core</artifactId>
|
||||
|
|
|
@ -42,6 +42,7 @@ import java.util.HashSet;
|
|||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
@ -90,6 +91,7 @@ import org.apache.nifi.expression.ExpressionLanguageScope;
|
|||
import org.apache.nifi.flowfile.FlowFile;
|
||||
import org.apache.nifi.flowfile.attributes.CoreAttributes;
|
||||
import org.apache.nifi.logging.ComponentLog;
|
||||
import org.apache.nifi.oauth2.OAuth2AccessTokenProvider;
|
||||
import org.apache.nifi.processor.AbstractProcessor;
|
||||
import org.apache.nifi.processor.DataUnit;
|
||||
import org.apache.nifi.processor.ProcessContext;
|
||||
|
@ -494,6 +496,13 @@ public class InvokeHTTP extends AbstractProcessor {
|
|||
.allowableValues("True", "False")
|
||||
.build();
|
||||
|
||||
public static final PropertyDescriptor OAUTH2_ACCESS_TOKEN_PROVIDER = new PropertyDescriptor.Builder()
|
||||
.name("oauth2-access-token-provider")
|
||||
.displayName("OAuth2 Access Token provider")
|
||||
.identifiesControllerService(OAuth2AccessTokenProvider.class)
|
||||
.required(false)
|
||||
.build();
|
||||
|
||||
public static final PropertyDescriptor FLOW_FILE_NAMING_STRATEGY = new PropertyDescriptor.Builder()
|
||||
.name("flow-file-naming-strategy")
|
||||
.description("Determines the strategy used for setting the filename attribute of the FlowFile.")
|
||||
|
@ -527,6 +536,7 @@ public class InvokeHTTP extends AbstractProcessor {
|
|||
PROP_USERAGENT,
|
||||
PROP_BASIC_AUTH_USERNAME,
|
||||
PROP_BASIC_AUTH_PASSWORD,
|
||||
OAUTH2_ACCESS_TOKEN_PROVIDER,
|
||||
PROXY_CONFIGURATION_SERVICE,
|
||||
PROP_PROXY_HOST,
|
||||
PROP_PROXY_PORT,
|
||||
|
@ -595,6 +605,8 @@ public class InvokeHTTP extends AbstractProcessor {
|
|||
|
||||
private volatile boolean useChunked = false;
|
||||
|
||||
private volatile Optional<OAuth2AccessTokenProvider> oauth2AccessTokenProviderOptional;
|
||||
|
||||
private final AtomicReference<OkHttpClient> okHttpClientAtomicReference = new AtomicReference<>();
|
||||
|
||||
@Override
|
||||
|
@ -728,6 +740,19 @@ public class InvokeHTTP extends AbstractProcessor {
|
|||
.build());
|
||||
}
|
||||
|
||||
boolean usingUserNamePasswordAuthorization = validationContext.getProperty(PROP_BASIC_AUTH_USERNAME).isSet()
|
||||
|| validationContext.getProperty(PROP_BASIC_AUTH_PASSWORD).isSet();
|
||||
|
||||
boolean usingOAuth2Authorization = validationContext.getProperty(OAUTH2_ACCESS_TOKEN_PROVIDER).isSet();
|
||||
|
||||
if (usingUserNamePasswordAuthorization && usingOAuth2Authorization) {
|
||||
results.add(new ValidationResult.Builder()
|
||||
.subject("Authorization properties")
|
||||
.valid(false)
|
||||
.explanation("OAuth2 Authorization cannot be configured together with Username and Password properties")
|
||||
.build());
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
|
@ -806,6 +831,19 @@ public class InvokeHTTP extends AbstractProcessor {
|
|||
okHttpClientAtomicReference.set(okHttpClientBuilder.build());
|
||||
}
|
||||
|
||||
@OnScheduled
|
||||
public void initOauth2AccessTokenProvider(final ProcessContext context) {
|
||||
if (context.getProperty(OAUTH2_ACCESS_TOKEN_PROVIDER).isSet()) {
|
||||
OAuth2AccessTokenProvider oauth2AccessTokenProvider = context.getProperty(OAUTH2_ACCESS_TOKEN_PROVIDER).asControllerService(OAuth2AccessTokenProvider.class);
|
||||
|
||||
oauth2AccessTokenProvider.getAccessDetails();
|
||||
|
||||
oauth2AccessTokenProviderOptional = Optional.of(oauth2AccessTokenProvider);
|
||||
} else {
|
||||
oauth2AccessTokenProviderOptional = Optional.empty();
|
||||
}
|
||||
}
|
||||
|
||||
private void setAuthenticator(OkHttpClient.Builder okHttpClientBuilder, ProcessContext context) {
|
||||
final String authUser = trimToEmpty(context.getProperty(PROP_BASIC_AUTH_USERNAME).getValue());
|
||||
|
||||
|
@ -1034,11 +1072,17 @@ public class InvokeHTTP extends AbstractProcessor {
|
|||
final String authUser = trimToEmpty(context.getProperty(PROP_BASIC_AUTH_USERNAME).getValue());
|
||||
|
||||
// If the username/password properties are set then check if digest auth is being used
|
||||
if (!authUser.isEmpty() && "false".equalsIgnoreCase(context.getProperty(PROP_DIGEST_AUTH).getValue())) {
|
||||
final String authPass = trimToEmpty(context.getProperty(PROP_BASIC_AUTH_PASSWORD).getValue());
|
||||
if ("false".equalsIgnoreCase(context.getProperty(PROP_DIGEST_AUTH).getValue())) {
|
||||
if (!authUser.isEmpty()) {
|
||||
final String authPass = trimToEmpty(context.getProperty(PROP_BASIC_AUTH_PASSWORD).getValue());
|
||||
|
||||
String credential = Credentials.basic(authUser, authPass);
|
||||
requestBuilder.header("Authorization", credential);
|
||||
String credential = Credentials.basic(authUser, authPass);
|
||||
requestBuilder.header("Authorization", credential);
|
||||
} else {
|
||||
oauth2AccessTokenProviderOptional.ifPresent(oauth2AccessTokenProvider ->
|
||||
requestBuilder.addHeader("Authorization", "Bearer " + oauth2AccessTokenProvider.getAccessDetails().getAccessToken())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// set the request method
|
||||
|
|
|
@ -21,6 +21,7 @@ import okhttp3.mockwebserver.MockWebServer;
|
|||
import okhttp3.mockwebserver.RecordedRequest;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.nifi.flowfile.attributes.CoreAttributes;
|
||||
import org.apache.nifi.oauth2.OAuth2AccessTokenProvider;
|
||||
import org.apache.nifi.processor.Relationship;
|
||||
import org.apache.nifi.processors.standard.http.FlowFileNamingStrategy;
|
||||
import org.apache.nifi.reporting.InitializationException;
|
||||
|
@ -34,6 +35,7 @@ import org.apache.nifi.util.MockFlowFile;
|
|||
import org.apache.nifi.util.TestRunner;
|
||||
import org.apache.nifi.util.TestRunners;
|
||||
import org.apache.nifi.web.util.ssl.SslContextUtils;
|
||||
import org.mockito.Answers;
|
||||
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.SSLSocketFactory;
|
||||
|
@ -774,6 +776,99 @@ public class InvokeHTTPTest {
|
|||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testValidWhenOAuth2Set() throws Exception {
|
||||
// GIVEN
|
||||
String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId";
|
||||
|
||||
OAuth2AccessTokenProvider oauth2AccessTokenProvider = mock(OAuth2AccessTokenProvider.class, Answers.RETURNS_DEEP_STUBS);
|
||||
when(oauth2AccessTokenProvider.getIdentifier()).thenReturn(oauth2AccessTokenProviderId);
|
||||
|
||||
runner.addControllerService(oauth2AccessTokenProviderId, oauth2AccessTokenProvider);
|
||||
runner.enableControllerService(oauth2AccessTokenProvider);
|
||||
|
||||
setUrlProperty();
|
||||
|
||||
// WHEN
|
||||
runner.setProperty(InvokeHTTP.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProviderId);
|
||||
|
||||
// THEN
|
||||
runner.assertValid();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInvalidWhenOAuth2AndUserNameSet() throws Exception {
|
||||
// GIVEN
|
||||
String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId";
|
||||
|
||||
OAuth2AccessTokenProvider oauth2AccessTokenProvider = mock(OAuth2AccessTokenProvider.class, Answers.RETURNS_DEEP_STUBS);
|
||||
when(oauth2AccessTokenProvider.getIdentifier()).thenReturn(oauth2AccessTokenProviderId);
|
||||
|
||||
runner.addControllerService(oauth2AccessTokenProviderId, oauth2AccessTokenProvider);
|
||||
runner.enableControllerService(oauth2AccessTokenProvider);
|
||||
|
||||
setUrlProperty();
|
||||
|
||||
// WHEN
|
||||
runner.setProperty(InvokeHTTP.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProviderId);
|
||||
runner.setProperty(InvokeHTTP.PROP_BASIC_AUTH_USERNAME, "userName");
|
||||
|
||||
// THEN
|
||||
runner.assertNotValid();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInvalidWhenOAuth2AndPasswordSet() throws Exception {
|
||||
// GIVEN
|
||||
String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId";
|
||||
|
||||
OAuth2AccessTokenProvider oauth2AccessTokenProvider = mock(OAuth2AccessTokenProvider.class, Answers.RETURNS_DEEP_STUBS);
|
||||
when(oauth2AccessTokenProvider.getIdentifier()).thenReturn(oauth2AccessTokenProviderId);
|
||||
|
||||
runner.addControllerService(oauth2AccessTokenProviderId, oauth2AccessTokenProvider);
|
||||
runner.enableControllerService(oauth2AccessTokenProvider);
|
||||
|
||||
setUrlProperty();
|
||||
|
||||
// WHEN
|
||||
runner.setProperty(InvokeHTTP.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProviderId);
|
||||
runner.setProperty(InvokeHTTP.PROP_BASIC_AUTH_PASSWORD, "password");
|
||||
|
||||
// THEN
|
||||
runner.assertNotValid();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOAuth2AuthorizationHeader() throws Exception {
|
||||
// GIVEN
|
||||
String accessToken = "access_token";
|
||||
|
||||
String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId";
|
||||
|
||||
OAuth2AccessTokenProvider oauth2AccessTokenProvider = mock(OAuth2AccessTokenProvider.class, Answers.RETURNS_DEEP_STUBS);
|
||||
when(oauth2AccessTokenProvider.getIdentifier()).thenReturn(oauth2AccessTokenProviderId);
|
||||
when(oauth2AccessTokenProvider.getAccessDetails().getAccessToken()).thenReturn(accessToken);
|
||||
|
||||
runner.addControllerService(oauth2AccessTokenProviderId, oauth2AccessTokenProvider);
|
||||
runner.enableControllerService(oauth2AccessTokenProvider);
|
||||
|
||||
setUrlProperty();
|
||||
|
||||
mockWebServer.enqueue(new MockResponse());
|
||||
|
||||
// WHEN
|
||||
runner.setProperty(InvokeHTTP.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProviderId);
|
||||
runner.enqueue("unimportant");
|
||||
runner.run();
|
||||
|
||||
// THEN
|
||||
RecordedRequest recordedRequest = mockWebServer.takeRequest();
|
||||
|
||||
String actualAuthorizationHeader = recordedRequest.getHeader("Authorization");
|
||||
assertEquals("Bearer " + accessToken, actualAuthorizationHeader);
|
||||
|
||||
}
|
||||
|
||||
private void setUrlProperty() {
|
||||
runner.setProperty(InvokeHTTP.PROP_URL, getMockWebServerUrl());
|
||||
}
|
||||
|
|
|
@ -17,53 +17,80 @@
|
|||
|
||||
package org.apache.nifi.oauth2;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.time.Instant;
|
||||
|
||||
public class AccessToken {
|
||||
private String accessToken;
|
||||
private String refreshToken;
|
||||
private String tokenType;
|
||||
private Integer expires;
|
||||
private String scope;
|
||||
private long expiresIn;
|
||||
private String scopes;
|
||||
|
||||
private Long fetchTime;
|
||||
private final Instant fetchTime;
|
||||
|
||||
public AccessToken(String accessToken,
|
||||
String refreshToken,
|
||||
String tokenType,
|
||||
Integer expires,
|
||||
String scope) {
|
||||
public static final int EXPIRY_MARGIN = 5000;
|
||||
|
||||
public AccessToken() {
|
||||
this.fetchTime = Instant.now();
|
||||
}
|
||||
|
||||
public AccessToken(String accessToken, String refreshToken, String tokenType, long expiresIn, String scopes) {
|
||||
this();
|
||||
this.accessToken = accessToken;
|
||||
this.tokenType = tokenType;
|
||||
this.refreshToken = refreshToken;
|
||||
this.expires = expires;
|
||||
this.scope = scope;
|
||||
this.fetchTime = System.currentTimeMillis();
|
||||
this.tokenType = tokenType;
|
||||
this.expiresIn = expiresIn;
|
||||
this.scopes = scopes;
|
||||
}
|
||||
|
||||
public String getAccessToken() {
|
||||
return accessToken;
|
||||
}
|
||||
|
||||
public void setAccessToken(String accessToken) {
|
||||
this.accessToken = accessToken;
|
||||
}
|
||||
|
||||
public String getRefreshToken() {
|
||||
return refreshToken;
|
||||
}
|
||||
|
||||
public void setRefreshToken(String refreshToken) {
|
||||
this.refreshToken = refreshToken;
|
||||
}
|
||||
|
||||
public String getTokenType() {
|
||||
return tokenType;
|
||||
}
|
||||
|
||||
public Integer getExpires() {
|
||||
return expires;
|
||||
public void setTokenType(String tokenType) {
|
||||
this.tokenType = tokenType;
|
||||
}
|
||||
|
||||
public String getScope() {
|
||||
return scope;
|
||||
public long getExpiresIn() {
|
||||
return expiresIn;
|
||||
}
|
||||
|
||||
public Long getFetchTime() {
|
||||
public void setExpiresIn(long expiresIn) {
|
||||
this.expiresIn = expiresIn;
|
||||
}
|
||||
|
||||
public String getScopes() {
|
||||
return scopes;
|
||||
}
|
||||
|
||||
public void setScopes(String scopes) {
|
||||
this.scopes = scopes;
|
||||
}
|
||||
|
||||
public Instant getFetchTime() {
|
||||
return fetchTime;
|
||||
}
|
||||
|
||||
public boolean isExpired() {
|
||||
return System.currentTimeMillis() >= ( fetchTime + (expires * 1000) );
|
||||
boolean expired = Duration.between(Instant.now(), fetchTime.plusSeconds(expiresIn - EXPIRY_MARGIN)).isNegative();
|
||||
|
||||
return expired;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.nifi.oauth2;
|
||||
|
||||
import org.apache.nifi.controller.ControllerService;
|
||||
|
||||
/**
|
||||
* Controller service that provides OAuth2 access details
|
||||
*/
|
||||
public interface OAuth2AccessTokenProvider extends ControllerService {
|
||||
|
||||
/**
|
||||
* @return A valid access token (refreshed automatically if needed) and additional metadata (provided by the OAuth2 access server)
|
||||
*/
|
||||
AccessToken getAccessDetails();
|
||||
}
|
|
@ -29,7 +29,10 @@ import java.util.List;
|
|||
|
||||
/**
|
||||
* Interface for defining a credential-providing controller service for oauth2 processes.
|
||||
*
|
||||
* @deprecated use {@link OAuth2AccessTokenProvider} instead
|
||||
*/
|
||||
@Deprecated
|
||||
public interface OAuth2TokenProvider extends ControllerService {
|
||||
PropertyDescriptor SSL_CONTEXT = new PropertyDescriptor.Builder()
|
||||
.name("oauth2-ssl-context")
|
||||
|
|
|
@ -30,6 +30,7 @@ import okhttp3.Request;
|
|||
import okhttp3.RequestBody;
|
||||
import okhttp3.Response;
|
||||
import org.apache.nifi.annotation.documentation.CapabilityDescription;
|
||||
import org.apache.nifi.annotation.documentation.DeprecationNotice;
|
||||
import org.apache.nifi.annotation.documentation.Tags;
|
||||
import org.apache.nifi.annotation.lifecycle.OnEnabled;
|
||||
import org.apache.nifi.components.PropertyDescriptor;
|
||||
|
@ -39,6 +40,8 @@ import org.apache.nifi.processor.exception.ProcessException;
|
|||
import org.apache.nifi.ssl.SSLContextService;
|
||||
import org.apache.nifi.util.StringUtils;
|
||||
|
||||
@Deprecated
|
||||
@DeprecationNotice(alternatives = {StandardOauth2AccessTokenProvider.class})
|
||||
@Tags({"oauth2", "provider", "authorization" })
|
||||
@CapabilityDescription("This controller service provides a way of working with access and refresh tokens via the " +
|
||||
"password and client_credential grant flows in the OAuth2 specification. It is meant to provide a way for components " +
|
||||
|
|
|
@ -0,0 +1,300 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.nifi.oauth2;
|
||||
|
||||
import com.fasterxml.jackson.databind.DeserializationFeature;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
|
||||
import okhttp3.FormBody;
|
||||
import okhttp3.OkHttpClient;
|
||||
import okhttp3.Request;
|
||||
import okhttp3.RequestBody;
|
||||
import okhttp3.Response;
|
||||
import org.apache.nifi.annotation.documentation.CapabilityDescription;
|
||||
import org.apache.nifi.annotation.documentation.Tags;
|
||||
import org.apache.nifi.annotation.lifecycle.OnDisabled;
|
||||
import org.apache.nifi.annotation.lifecycle.OnEnabled;
|
||||
import org.apache.nifi.components.AllowableValue;
|
||||
import org.apache.nifi.components.PropertyDescriptor;
|
||||
import org.apache.nifi.components.ValidationContext;
|
||||
import org.apache.nifi.components.ValidationResult;
|
||||
import org.apache.nifi.components.Validator;
|
||||
import org.apache.nifi.controller.AbstractControllerService;
|
||||
import org.apache.nifi.controller.ConfigurationContext;
|
||||
import org.apache.nifi.expression.ExpressionLanguageScope;
|
||||
import org.apache.nifi.processor.exception.ProcessException;
|
||||
import org.apache.nifi.processor.util.StandardValidators;
|
||||
import org.apache.nifi.ssl.SSLContextService;
|
||||
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.X509TrustManager;
|
||||
import java.io.IOException;
|
||||
import java.io.UncheckedIOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
@Tags({"oauth2", "provider", "authorization", "access token", "http"})
|
||||
@CapabilityDescription("Provides OAuth 2.0 access tokens that can be used as Bearer authorization header in HTTP requests." +
|
||||
" Uses Resource Owner Password Credentials Grant.")
|
||||
public class StandardOauth2AccessTokenProvider extends AbstractControllerService implements OAuth2AccessTokenProvider {
|
||||
public static final PropertyDescriptor AUTHORIZATION_SERVER_URL = new PropertyDescriptor.Builder()
|
||||
.name("authorization-server-url")
|
||||
.displayName("Authorization Server URL")
|
||||
.description("The URL of the authorization server that issues access tokens.")
|
||||
.required(true)
|
||||
.addValidator(StandardValidators.URL_VALIDATOR)
|
||||
.expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY)
|
||||
.build();
|
||||
|
||||
public static AllowableValue RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE = new AllowableValue(
|
||||
"password",
|
||||
"User Password",
|
||||
"Resource Owner Password Credentials Grant. Used to access resources available to users. Requires username and password and usually Client ID and Client Secret"
|
||||
);
|
||||
public static AllowableValue CLIENT_CREDENTIALS_GRANT_TYPE = new AllowableValue(
|
||||
"client_credentials",
|
||||
"Client Credentials",
|
||||
"Client Credentials Grant. Used to access resources available to clients. Requires Client ID and Client Secret"
|
||||
);
|
||||
|
||||
public static final PropertyDescriptor GRANT_TYPE = new PropertyDescriptor.Builder()
|
||||
.name("grant-type")
|
||||
.displayName("Grant Type")
|
||||
.description("The OAuth2 Grant Type to be used when acquiring an access token.")
|
||||
.required(true)
|
||||
.allowableValues(RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE, CLIENT_CREDENTIALS_GRANT_TYPE)
|
||||
.defaultValue(RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE.getValue())
|
||||
.build();
|
||||
|
||||
public static final PropertyDescriptor USERNAME = new PropertyDescriptor.Builder()
|
||||
.name("service-user-name")
|
||||
.displayName("Username")
|
||||
.description("Username on the service that is being accessed.")
|
||||
.dependsOn(GRANT_TYPE, RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE)
|
||||
.required(true)
|
||||
.addValidator(StandardValidators.NON_BLANK_VALIDATOR)
|
||||
.expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY)
|
||||
.build();
|
||||
|
||||
public static final PropertyDescriptor PASSWORD = new PropertyDescriptor.Builder()
|
||||
.name("service-password")
|
||||
.displayName("Password")
|
||||
.description("Password for the username on the service that is being accessed.")
|
||||
.dependsOn(GRANT_TYPE, RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE)
|
||||
.required(true)
|
||||
.sensitive(true)
|
||||
.addValidator(StandardValidators.NON_BLANK_VALIDATOR)
|
||||
.build();
|
||||
|
||||
public static final PropertyDescriptor CLIENT_ID = new PropertyDescriptor.Builder()
|
||||
.name("client-id")
|
||||
.displayName("Client ID")
|
||||
.required(false)
|
||||
.addValidator(StandardValidators.NON_BLANK_VALIDATOR)
|
||||
.expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY)
|
||||
.build();
|
||||
|
||||
public static final PropertyDescriptor CLIENT_SECRET = new PropertyDescriptor.Builder()
|
||||
.name("client-secret")
|
||||
.displayName("Client secret")
|
||||
.dependsOn(CLIENT_ID)
|
||||
.required(true)
|
||||
.sensitive(true)
|
||||
.addValidator(StandardValidators.NON_BLANK_VALIDATOR)
|
||||
.build();
|
||||
|
||||
public static final PropertyDescriptor SSL_CONTEXT = new PropertyDescriptor.Builder()
|
||||
.name("ssl-context-service")
|
||||
.displayName("SSL Context Servuce")
|
||||
.addValidator(Validator.VALID)
|
||||
.identifiesControllerService(SSLContextService.class)
|
||||
.required(false)
|
||||
.build();
|
||||
|
||||
private static final List<PropertyDescriptor> PROPERTIES = Collections.unmodifiableList(Arrays.asList(
|
||||
AUTHORIZATION_SERVER_URL,
|
||||
GRANT_TYPE,
|
||||
USERNAME,
|
||||
PASSWORD,
|
||||
CLIENT_ID,
|
||||
CLIENT_SECRET,
|
||||
SSL_CONTEXT
|
||||
));
|
||||
|
||||
public static final ObjectMapper ACCESS_DETAILS_MAPPER = new ObjectMapper()
|
||||
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
|
||||
.setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE);
|
||||
|
||||
private volatile String authorizationServerUrl;
|
||||
private volatile OkHttpClient httpClient;
|
||||
|
||||
private volatile String grantType;
|
||||
private volatile String username;
|
||||
private volatile String password;
|
||||
private volatile String clientId;
|
||||
private volatile String clientSecret;
|
||||
|
||||
private volatile AccessToken accessDetails;
|
||||
|
||||
@Override
|
||||
public List<PropertyDescriptor> getSupportedPropertyDescriptors() {
|
||||
return PROPERTIES;
|
||||
}
|
||||
|
||||
@OnEnabled
|
||||
public void onEnabled(ConfigurationContext context) {
|
||||
authorizationServerUrl = context.getProperty(AUTHORIZATION_SERVER_URL).evaluateAttributeExpressions().getValue();
|
||||
|
||||
httpClient = createHttpClient(context);
|
||||
|
||||
grantType = context.getProperty(GRANT_TYPE).getValue();
|
||||
username = context.getProperty(USERNAME).evaluateAttributeExpressions().getValue();
|
||||
password = context.getProperty(PASSWORD).getValue();
|
||||
clientId = context.getProperty(CLIENT_ID).evaluateAttributeExpressions().getValue();
|
||||
clientSecret = context.getProperty(CLIENT_SECRET).getValue();
|
||||
}
|
||||
|
||||
@OnDisabled
|
||||
public void onDisabled() {
|
||||
accessDetails = null;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Collection<ValidationResult> customValidate(ValidationContext validationContext) {
|
||||
final List<ValidationResult> validationResults = new ArrayList<>(super.customValidate(validationContext));
|
||||
|
||||
if (
|
||||
validationContext.getProperty(GRANT_TYPE).getValue().equals(CLIENT_CREDENTIALS_GRANT_TYPE.getValue())
|
||||
&& !validationContext.getProperty(CLIENT_ID).isSet()
|
||||
) {
|
||||
validationResults.add(new ValidationResult.Builder().subject(CLIENT_ID.getDisplayName())
|
||||
.valid(false)
|
||||
.explanation(String.format(
|
||||
"When '%s' is set to '%s', '%s' is required",
|
||||
GRANT_TYPE.getDisplayName(),
|
||||
CLIENT_CREDENTIALS_GRANT_TYPE.getDisplayName(),
|
||||
CLIENT_ID.getDisplayName())
|
||||
)
|
||||
.build());
|
||||
}
|
||||
|
||||
return validationResults;
|
||||
}
|
||||
|
||||
protected OkHttpClient createHttpClient(ConfigurationContext context) {
|
||||
OkHttpClient.Builder clientBuilder = new OkHttpClient.Builder();
|
||||
|
||||
SSLContextService sslService = context.getProperty(SSL_CONTEXT).asControllerService(SSLContextService.class);
|
||||
if (sslService != null) {
|
||||
final X509TrustManager trustManager = sslService.createTrustManager();
|
||||
SSLContext sslContext = sslService.createContext();
|
||||
clientBuilder.sslSocketFactory(sslContext.getSocketFactory(), trustManager);
|
||||
}
|
||||
|
||||
return clientBuilder.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public AccessToken getAccessDetails() {
|
||||
if (this.accessDetails == null) {
|
||||
acquireAccessDetails();
|
||||
} else if (this.accessDetails.isExpired()) {
|
||||
if (this.accessDetails.getRefreshToken() == null) {
|
||||
acquireAccessDetails();
|
||||
} else {
|
||||
try {
|
||||
refreshAccessDetails();
|
||||
} catch (Exception e) {
|
||||
getLogger().info("Couldn't refresh access token", e);
|
||||
acquireAccessDetails();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return accessDetails;
|
||||
}
|
||||
|
||||
private void acquireAccessDetails() {
|
||||
getLogger().debug("Getting a new access token");
|
||||
|
||||
FormBody.Builder acquireTokenBuilder = new FormBody.Builder();
|
||||
|
||||
if (grantType.equals(RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE.getValue())) {
|
||||
acquireTokenBuilder.add("grant_type", "password")
|
||||
.add("username", username)
|
||||
.add("password", password);
|
||||
} else if (grantType.equals(CLIENT_CREDENTIALS_GRANT_TYPE.getValue())) {
|
||||
acquireTokenBuilder.add("grant_type", "client_credentials");
|
||||
}
|
||||
|
||||
if (clientId != null) {
|
||||
acquireTokenBuilder.add("client_id", clientId);
|
||||
acquireTokenBuilder.add("client_secret", clientSecret);
|
||||
}
|
||||
|
||||
RequestBody acquireTokenRequestBody = acquireTokenBuilder.build();
|
||||
|
||||
Request acquireTokenRequest = new Request.Builder()
|
||||
.url(authorizationServerUrl)
|
||||
.post(acquireTokenRequestBody)
|
||||
.build();
|
||||
|
||||
this.accessDetails = getAccessDetails(acquireTokenRequest);
|
||||
}
|
||||
|
||||
private void refreshAccessDetails() {
|
||||
getLogger().debug("Refreshing access token");
|
||||
|
||||
FormBody.Builder refreshTokenBuilder = new FormBody.Builder()
|
||||
.add("grant_type", "refresh_token")
|
||||
.add("refresh_token", this.accessDetails.getRefreshToken());
|
||||
|
||||
if (clientId != null) {
|
||||
refreshTokenBuilder.add("client_id", clientId);
|
||||
refreshTokenBuilder.add("client_secret", clientSecret);
|
||||
}
|
||||
|
||||
RequestBody refreshTokenRequestBody = refreshTokenBuilder.build();
|
||||
|
||||
Request refreshRequest = new Request.Builder()
|
||||
.url(authorizationServerUrl)
|
||||
.post(refreshTokenRequestBody)
|
||||
.build();
|
||||
|
||||
this.accessDetails = getAccessDetails(refreshRequest);
|
||||
}
|
||||
|
||||
private AccessToken getAccessDetails(Request newRequest) {
|
||||
try {
|
||||
Response response = httpClient.newCall(newRequest).execute();
|
||||
String responseBody = response.body().string();
|
||||
if (response.code() != 200) {
|
||||
getLogger().error(String.format("OAuth2 access token request failed [HTTP %d], response:%n%s", response.code(), responseBody));
|
||||
throw new ProcessException(String.format("OAuth2 access token request failed [HTTP %d]", response.code()));
|
||||
}
|
||||
|
||||
AccessToken accessDetails = ACCESS_DETAILS_MAPPER.readValue(responseBody, AccessToken.class);
|
||||
|
||||
return accessDetails;
|
||||
} catch (IOException e) {
|
||||
throw new UncheckedIOException("OAuth2 access token request failed", e);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -13,3 +13,5 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
org.apache.nifi.oauth2.OAuth2TokenProviderImpl
|
||||
org.apache.nifi.oauth2.StandardOauth2AccessTokenProvider
|
||||
|
||||
|
|
|
@ -107,7 +107,7 @@ public class OAuth2TokenProviderImplTest {
|
|||
private void assertAccessTokenFound(final AccessToken accessToken) {
|
||||
assertNotNull(accessToken);
|
||||
assertEquals("access token", accessToken.getAccessToken());
|
||||
assertEquals(300, accessToken.getExpires().intValue());
|
||||
assertEquals(5300, accessToken.getExpiresIn());
|
||||
assertEquals("BEARER", accessToken.getTokenType());
|
||||
assertFalse(accessToken.isExpired());
|
||||
}
|
||||
|
@ -117,7 +117,7 @@ public class OAuth2TokenProviderImplTest {
|
|||
token.put("access_token", "access token");
|
||||
token.put("refresh_token", "refresh token");
|
||||
token.put("token_type", "BEARER");
|
||||
token.put("expires_in", 300);
|
||||
token.put("expires_in", 5300);
|
||||
token.put("scope", "test scope");
|
||||
final String accessToken = new ObjectMapper().writeValueAsString(token);
|
||||
mockWebServer.enqueue(new MockResponse().setResponseCode(200).addHeader("Content-Type", "application/json").setBody(accessToken));
|
||||
|
|
|
@ -0,0 +1,411 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.nifi.oauth2;
|
||||
|
||||
import okhttp3.MediaType;
|
||||
import okhttp3.OkHttpClient;
|
||||
import okhttp3.Protocol;
|
||||
import okhttp3.Request;
|
||||
import okhttp3.Response;
|
||||
import okhttp3.ResponseBody;
|
||||
import org.apache.nifi.controller.ConfigurationContext;
|
||||
import org.apache.nifi.logging.ComponentLog;
|
||||
import org.apache.nifi.processor.Processor;
|
||||
import org.apache.nifi.processor.exception.ProcessException;
|
||||
import org.apache.nifi.util.NoOpProcessor;
|
||||
import org.apache.nifi.util.TestRunner;
|
||||
import org.apache.nifi.util.TestRunners;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.mockito.Answers;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.Captor;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.MockitoAnnotations;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.UncheckedIOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertThrows;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class StandardOauth2AccessTokenProviderTest {
|
||||
private static final String AUTHORIZATION_SERVER_URL = "http://authorizationServerUrl";
|
||||
private static final String USERNAME = "username";
|
||||
private static final String PASSWORD = "password";
|
||||
private static final String CLIENT_ID = "clientId";
|
||||
private static final String CLIENT_SECRET = "clientSecret";
|
||||
|
||||
private StandardOauth2AccessTokenProvider testSubject;
|
||||
|
||||
@Mock(answer = Answers.RETURNS_DEEP_STUBS)
|
||||
private OkHttpClient mockHttpClient;
|
||||
|
||||
@Mock(answer = Answers.RETURNS_DEEP_STUBS)
|
||||
private ConfigurationContext mockContext;
|
||||
|
||||
@Mock
|
||||
private ComponentLog mockLogger;
|
||||
@Captor
|
||||
private ArgumentCaptor<String> debugCaptor;
|
||||
@Captor
|
||||
private ArgumentCaptor<String> errorCaptor;
|
||||
@Captor
|
||||
private ArgumentCaptor<Throwable> throwableCaptor;
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
MockitoAnnotations.initMocks(this);
|
||||
|
||||
testSubject = new StandardOauth2AccessTokenProvider() {
|
||||
@Override
|
||||
protected OkHttpClient createHttpClient(ConfigurationContext context) {
|
||||
return mockHttpClient;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ComponentLog getLogger() {
|
||||
return mockLogger;
|
||||
}
|
||||
};
|
||||
|
||||
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.GRANT_TYPE).getValue()).thenReturn(StandardOauth2AccessTokenProvider.RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE.getValue());
|
||||
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL).evaluateAttributeExpressions().getValue()).thenReturn(AUTHORIZATION_SERVER_URL);
|
||||
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.USERNAME).evaluateAttributeExpressions().getValue()).thenReturn(USERNAME);
|
||||
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.PASSWORD).getValue()).thenReturn(PASSWORD);
|
||||
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_ID).evaluateAttributeExpressions().getValue()).thenReturn(CLIENT_ID);
|
||||
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_SECRET).getValue()).thenReturn(CLIENT_SECRET);
|
||||
|
||||
testSubject.onEnabled(mockContext);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInvalidWhenClientCredentialsGrantTypeSetWithoutClientId() throws Exception {
|
||||
// GIVEN
|
||||
Processor processor = new NoOpProcessor();
|
||||
TestRunner runner = TestRunners.newTestRunner(processor);
|
||||
|
||||
runner.addControllerService("testSubject", testSubject);
|
||||
|
||||
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, "http://unimportant");
|
||||
|
||||
// WHEN
|
||||
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.GRANT_TYPE, StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE);
|
||||
|
||||
// THEN
|
||||
runner.assertNotValid(testSubject);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testValidWhenClientCredentialsGrantTypeSetWithClientId() throws Exception {
|
||||
// GIVEN
|
||||
Processor processor = new NoOpProcessor();
|
||||
TestRunner runner = TestRunners.newTestRunner(processor);
|
||||
|
||||
runner.addControllerService("testSubject", testSubject);
|
||||
|
||||
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, "http://unimportant");
|
||||
|
||||
// WHEN
|
||||
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.GRANT_TYPE, StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE);
|
||||
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_ID, "clientId");
|
||||
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_SECRET, "clientSecret");
|
||||
|
||||
// THEN
|
||||
runner.assertValid(testSubject);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAcquireNewToken() throws Exception {
|
||||
String accessTokenValue = "access_token_value";
|
||||
|
||||
// GIVEN
|
||||
Response response = buildResponse(
|
||||
200,
|
||||
"{ \"access_token\":\"" + accessTokenValue + "\" }"
|
||||
);
|
||||
|
||||
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response);
|
||||
|
||||
// WHEN
|
||||
String actual = testSubject.getAccessDetails().getAccessToken();
|
||||
|
||||
// THEN
|
||||
assertEquals(accessTokenValue, actual);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRefreshToken() throws Exception {
|
||||
// GIVEN
|
||||
String firstToken = "first_token";
|
||||
String expectedToken = "second_token";
|
||||
|
||||
Response response1 = buildResponse(
|
||||
200,
|
||||
"{ \"access_token\":\"" + firstToken + "\", \"expires_in\":\"0\", \"refresh_token\":\"not_checking_in_this_test\" }"
|
||||
);
|
||||
|
||||
Response response2 = buildResponse(
|
||||
200,
|
||||
"{ \"access_token\":\"" + expectedToken + "\" }"
|
||||
);
|
||||
|
||||
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response1, response2);
|
||||
|
||||
// WHEN
|
||||
testSubject.getAccessDetails();
|
||||
String actualToken = testSubject.getAccessDetails().getAccessToken();
|
||||
|
||||
// THEN
|
||||
assertEquals(expectedToken, actualToken);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIOExceptionDuringRefreshAndSubsequentAcquire() throws Exception {
|
||||
// GIVEN
|
||||
String refreshErrorMessage = "refresh_error";
|
||||
String acquireErrorMessage = "acquire_error";
|
||||
|
||||
AtomicInteger callCounter = new AtomicInteger(0);
|
||||
when(mockHttpClient.newCall(any(Request.class)).execute()).thenAnswer(invocation -> {
|
||||
callCounter.incrementAndGet();
|
||||
|
||||
if (callCounter.get() == 1) {
|
||||
return buildSuccessfulInitResponse();
|
||||
} else if (callCounter.get() == 2) {
|
||||
throw new IOException(refreshErrorMessage);
|
||||
} else if (callCounter.get() == 3) {
|
||||
throw new IOException(acquireErrorMessage);
|
||||
}
|
||||
|
||||
throw new IllegalStateException("Test improperly defined mock HTTP responses.");
|
||||
});
|
||||
|
||||
// Get a good accessDetails so we can have a refresh a second time
|
||||
testSubject.getAccessDetails();
|
||||
|
||||
// WHEN
|
||||
UncheckedIOException actualException = assertThrows(
|
||||
UncheckedIOException.class,
|
||||
() -> testSubject.getAccessDetails()
|
||||
);
|
||||
|
||||
// THEN
|
||||
checkLoggedDebugWhenRefreshFails();
|
||||
|
||||
checkLoggedRefreshError(new UncheckedIOException("OAuth2 access token request failed", new IOException(refreshErrorMessage)));
|
||||
|
||||
checkError(new UncheckedIOException("OAuth2 access token request failed", new IOException(acquireErrorMessage)), actualException);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIOExceptionDuringRefreshSuccessfulSubsequentAcquire() throws Exception {
|
||||
// GIVEN
|
||||
String refreshErrorMessage = "refresh_error";
|
||||
String expectedToken = "expected_token";
|
||||
|
||||
Response successfulAcquireResponse = buildResponse(
|
||||
200,
|
||||
"{ \"access_token\":\"" + expectedToken + "\", \"expires_in\":\"0\", \"refresh_token\":\"not_checking_in_this_test\" }"
|
||||
);
|
||||
|
||||
AtomicInteger callCounter = new AtomicInteger(0);
|
||||
when(mockHttpClient.newCall(any(Request.class)).execute()).thenAnswer(invocation -> {
|
||||
callCounter.incrementAndGet();
|
||||
|
||||
if (callCounter.get() == 1) {
|
||||
return buildSuccessfulInitResponse();
|
||||
} else if (callCounter.get() == 2) {
|
||||
throw new IOException(refreshErrorMessage);
|
||||
} else if (callCounter.get() == 3) {
|
||||
return successfulAcquireResponse;
|
||||
}
|
||||
|
||||
throw new IllegalStateException("Test improperly defined mock HTTP responses.");
|
||||
});
|
||||
|
||||
// Get a good accessDetails so we can have a refresh a second time
|
||||
testSubject.getAccessDetails();
|
||||
|
||||
// WHEN
|
||||
String actualToken = testSubject.getAccessDetails().getAccessToken();
|
||||
|
||||
// THEN
|
||||
checkLoggedDebugWhenRefreshFails();
|
||||
|
||||
checkLoggedRefreshError(new UncheckedIOException("OAuth2 access token request failed", new IOException(refreshErrorMessage)));
|
||||
|
||||
assertEquals(expectedToken, actualToken);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testHTTPErrorDuringRefreshAndSubsequentAcquire() throws Exception {
|
||||
// GIVEN
|
||||
String errorRefreshResponseBody = "{ \"error_response\":\"refresh_error\" }";
|
||||
String errorAcquireResponseBody = "{ \"error_response\":\"acquire_error\" }";
|
||||
|
||||
Response errorRefreshResponse = buildResponse(500, errorRefreshResponseBody);
|
||||
Response errorAcquireResponse = buildResponse(503, errorAcquireResponseBody);
|
||||
|
||||
AtomicInteger callCounter = new AtomicInteger(0);
|
||||
when(mockHttpClient.newCall(any(Request.class)).execute()).thenAnswer(invocation -> {
|
||||
callCounter.incrementAndGet();
|
||||
|
||||
if (callCounter.get() == 1) {
|
||||
return buildSuccessfulInitResponse();
|
||||
} else if (callCounter.get() == 2) {
|
||||
return errorRefreshResponse;
|
||||
} else if (callCounter.get() == 3) {
|
||||
return errorAcquireResponse;
|
||||
}
|
||||
|
||||
throw new IllegalStateException("Test improperly defined mock HTTP responses.");
|
||||
});
|
||||
|
||||
List<String> expectedLoggedError = Arrays.asList(
|
||||
String.format("OAuth2 access token request failed [HTTP %d], response:%n%s", 500, errorRefreshResponseBody),
|
||||
String.format("OAuth2 access token request failed [HTTP %d], response:%n%s", 503, errorAcquireResponseBody)
|
||||
);
|
||||
|
||||
// Get a good accessDetails so we can have a refresh a second time
|
||||
testSubject.getAccessDetails();
|
||||
|
||||
// WHEN
|
||||
ProcessException actualException = assertThrows(
|
||||
ProcessException.class,
|
||||
() -> testSubject.getAccessDetails()
|
||||
);
|
||||
|
||||
// THEN
|
||||
checkLoggedDebugWhenRefreshFails();
|
||||
|
||||
checkLoggedRefreshError(new ProcessException("OAuth2 access token request failed [HTTP 500]"));
|
||||
|
||||
checkedLoggedErrorWhenRefreshReturnsBadHTTPResponse(expectedLoggedError);
|
||||
|
||||
checkError(new ProcessException("OAuth2 access token request failed [HTTP 503]"), actualException);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testHTTPErrorDuringRefreshSuccessfulSubsequentAcquire() throws Exception {
|
||||
// GIVEN
|
||||
String expectedRefreshErrorResponse = "{ \"error_response\":\"refresh_error\" }";
|
||||
String expectedToken = "expected_token";
|
||||
|
||||
Response errorRefreshResponse = buildResponse(500, expectedRefreshErrorResponse);
|
||||
Response successfulAcquireResponse = buildResponse(
|
||||
200,
|
||||
"{ \"access_token\":\"" + expectedToken + "\", \"expires_in\":\"0\", \"refresh_token\":\"not_checking_in_this_test\" }"
|
||||
);
|
||||
|
||||
AtomicInteger callCounter = new AtomicInteger(0);
|
||||
when(mockHttpClient.newCall(any(Request.class)).execute()).thenAnswer(invocation -> {
|
||||
callCounter.incrementAndGet();
|
||||
|
||||
if (callCounter.get() == 1) {
|
||||
return buildSuccessfulInitResponse();
|
||||
} else if (callCounter.get() == 2) {
|
||||
return errorRefreshResponse;
|
||||
} else if (callCounter.get() == 3) {
|
||||
return successfulAcquireResponse;
|
||||
}
|
||||
|
||||
throw new IllegalStateException("Test improperly defined mock HTTP responses.");
|
||||
});
|
||||
|
||||
List<String> expectedLoggedError = Arrays.asList(String.format("OAuth2 access token request failed [HTTP %d], response:%n%s", 500, expectedRefreshErrorResponse));
|
||||
|
||||
// Get a good accessDetails so we can have a refresh a second time
|
||||
testSubject.getAccessDetails();
|
||||
|
||||
// WHEN
|
||||
String actualToken = testSubject.getAccessDetails().getAccessToken();
|
||||
|
||||
// THEN
|
||||
checkLoggedDebugWhenRefreshFails();
|
||||
|
||||
checkLoggedRefreshError(new ProcessException("OAuth2 access token request failed [HTTP 500]"));
|
||||
|
||||
checkedLoggedErrorWhenRefreshReturnsBadHTTPResponse(expectedLoggedError);
|
||||
|
||||
assertEquals(expectedToken, actualToken);
|
||||
}
|
||||
|
||||
private Response buildSuccessfulInitResponse() {
|
||||
return buildResponse(
|
||||
200,
|
||||
"{ \"access_token\":\"exists_but_value_irrelevant\", \"expires_in\":\"0\", \"refresh_token\":\"not_checking_in_this_test\" }"
|
||||
);
|
||||
}
|
||||
|
||||
private Response buildResponse(int code, String body) {
|
||||
return new Response.Builder()
|
||||
.request(new Request.Builder()
|
||||
.url("http://unimportant_but_required")
|
||||
.build()
|
||||
)
|
||||
.protocol(Protocol.HTTP_2)
|
||||
.message("unimportant_but_required")
|
||||
.code(code)
|
||||
.body(ResponseBody.create(
|
||||
body.getBytes(),
|
||||
MediaType.parse("application/json"))
|
||||
)
|
||||
.build();
|
||||
}
|
||||
|
||||
private void checkLoggedDebugWhenRefreshFails() {
|
||||
verify(mockLogger, times(3)).debug(debugCaptor.capture());
|
||||
List<String> actualDebugMessages = debugCaptor.getAllValues();
|
||||
|
||||
assertEquals(
|
||||
Arrays.asList("Getting a new access token", "Refreshing access token", "Getting a new access token"),
|
||||
actualDebugMessages
|
||||
);
|
||||
}
|
||||
|
||||
private void checkedLoggedErrorWhenRefreshReturnsBadHTTPResponse(List<String> expectedLoggedError) {
|
||||
verify(mockLogger, times(expectedLoggedError.size())).error(errorCaptor.capture());
|
||||
List<String> actualLoggedError = errorCaptor.getAllValues();
|
||||
|
||||
assertEquals(expectedLoggedError, actualLoggedError);
|
||||
}
|
||||
|
||||
private void checkLoggedRefreshError(Throwable expectedRefreshError) {
|
||||
verify(mockLogger).info(eq("Couldn't refresh access token"), throwableCaptor.capture());
|
||||
Throwable actualRefreshError = throwableCaptor.getValue();
|
||||
|
||||
checkError(expectedRefreshError, actualRefreshError);
|
||||
}
|
||||
|
||||
private void checkError(Throwable expectedError, Throwable actualError) {
|
||||
assertEquals(expectedError.getClass(), actualError.getClass());
|
||||
assertEquals(expectedError.getMessage(), actualError.getMessage());
|
||||
if (expectedError.getCause() != null || actualError.getCause() != null) {
|
||||
assertEquals(expectedError.getCause().getClass(), actualError.getCause().getClass());
|
||||
assertEquals(expectedError.getCause().getMessage(), actualError.getCause().getMessage());
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue