mirror of https://github.com/apache/nifi.git
NIFI-9065 Add support for OAuth2AccessTokenProvider in InvokeHTTP
Signed-off-by: Pierre Villard <pierre.villard.fr@gmail.com> This closes #5319.
This commit is contained in:
parent
e603b0179b
commit
aa61494fc3
|
@ -374,6 +374,11 @@
|
||||||
<artifactId>nifi-database-utils</artifactId>
|
<artifactId>nifi-database-utils</artifactId>
|
||||||
<version>1.16.0-SNAPSHOT</version>
|
<version>1.16.0-SNAPSHOT</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.nifi</groupId>
|
||||||
|
<artifactId>nifi-oauth2-provider-api</artifactId>
|
||||||
|
<version>1.16.0-SNAPSHOT</version>
|
||||||
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.apache.sshd</groupId>
|
<groupId>org.apache.sshd</groupId>
|
||||||
<artifactId>sshd-core</artifactId>
|
<artifactId>sshd-core</artifactId>
|
||||||
|
|
|
@ -42,6 +42,7 @@ import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Locale;
|
import java.util.Locale;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
@ -90,6 +91,7 @@ import org.apache.nifi.expression.ExpressionLanguageScope;
|
||||||
import org.apache.nifi.flowfile.FlowFile;
|
import org.apache.nifi.flowfile.FlowFile;
|
||||||
import org.apache.nifi.flowfile.attributes.CoreAttributes;
|
import org.apache.nifi.flowfile.attributes.CoreAttributes;
|
||||||
import org.apache.nifi.logging.ComponentLog;
|
import org.apache.nifi.logging.ComponentLog;
|
||||||
|
import org.apache.nifi.oauth2.OAuth2AccessTokenProvider;
|
||||||
import org.apache.nifi.processor.AbstractProcessor;
|
import org.apache.nifi.processor.AbstractProcessor;
|
||||||
import org.apache.nifi.processor.DataUnit;
|
import org.apache.nifi.processor.DataUnit;
|
||||||
import org.apache.nifi.processor.ProcessContext;
|
import org.apache.nifi.processor.ProcessContext;
|
||||||
|
@ -494,6 +496,13 @@ public class InvokeHTTP extends AbstractProcessor {
|
||||||
.allowableValues("True", "False")
|
.allowableValues("True", "False")
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
public static final PropertyDescriptor OAUTH2_ACCESS_TOKEN_PROVIDER = new PropertyDescriptor.Builder()
|
||||||
|
.name("oauth2-access-token-provider")
|
||||||
|
.displayName("OAuth2 Access Token provider")
|
||||||
|
.identifiesControllerService(OAuth2AccessTokenProvider.class)
|
||||||
|
.required(false)
|
||||||
|
.build();
|
||||||
|
|
||||||
public static final PropertyDescriptor FLOW_FILE_NAMING_STRATEGY = new PropertyDescriptor.Builder()
|
public static final PropertyDescriptor FLOW_FILE_NAMING_STRATEGY = new PropertyDescriptor.Builder()
|
||||||
.name("flow-file-naming-strategy")
|
.name("flow-file-naming-strategy")
|
||||||
.description("Determines the strategy used for setting the filename attribute of the FlowFile.")
|
.description("Determines the strategy used for setting the filename attribute of the FlowFile.")
|
||||||
|
@ -527,6 +536,7 @@ public class InvokeHTTP extends AbstractProcessor {
|
||||||
PROP_USERAGENT,
|
PROP_USERAGENT,
|
||||||
PROP_BASIC_AUTH_USERNAME,
|
PROP_BASIC_AUTH_USERNAME,
|
||||||
PROP_BASIC_AUTH_PASSWORD,
|
PROP_BASIC_AUTH_PASSWORD,
|
||||||
|
OAUTH2_ACCESS_TOKEN_PROVIDER,
|
||||||
PROXY_CONFIGURATION_SERVICE,
|
PROXY_CONFIGURATION_SERVICE,
|
||||||
PROP_PROXY_HOST,
|
PROP_PROXY_HOST,
|
||||||
PROP_PROXY_PORT,
|
PROP_PROXY_PORT,
|
||||||
|
@ -595,6 +605,8 @@ public class InvokeHTTP extends AbstractProcessor {
|
||||||
|
|
||||||
private volatile boolean useChunked = false;
|
private volatile boolean useChunked = false;
|
||||||
|
|
||||||
|
private volatile Optional<OAuth2AccessTokenProvider> oauth2AccessTokenProviderOptional;
|
||||||
|
|
||||||
private final AtomicReference<OkHttpClient> okHttpClientAtomicReference = new AtomicReference<>();
|
private final AtomicReference<OkHttpClient> okHttpClientAtomicReference = new AtomicReference<>();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -728,6 +740,19 @@ public class InvokeHTTP extends AbstractProcessor {
|
||||||
.build());
|
.build());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
boolean usingUserNamePasswordAuthorization = validationContext.getProperty(PROP_BASIC_AUTH_USERNAME).isSet()
|
||||||
|
|| validationContext.getProperty(PROP_BASIC_AUTH_PASSWORD).isSet();
|
||||||
|
|
||||||
|
boolean usingOAuth2Authorization = validationContext.getProperty(OAUTH2_ACCESS_TOKEN_PROVIDER).isSet();
|
||||||
|
|
||||||
|
if (usingUserNamePasswordAuthorization && usingOAuth2Authorization) {
|
||||||
|
results.add(new ValidationResult.Builder()
|
||||||
|
.subject("Authorization properties")
|
||||||
|
.valid(false)
|
||||||
|
.explanation("OAuth2 Authorization cannot be configured together with Username and Password properties")
|
||||||
|
.build());
|
||||||
|
}
|
||||||
|
|
||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -806,6 +831,19 @@ public class InvokeHTTP extends AbstractProcessor {
|
||||||
okHttpClientAtomicReference.set(okHttpClientBuilder.build());
|
okHttpClientAtomicReference.set(okHttpClientBuilder.build());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@OnScheduled
|
||||||
|
public void initOauth2AccessTokenProvider(final ProcessContext context) {
|
||||||
|
if (context.getProperty(OAUTH2_ACCESS_TOKEN_PROVIDER).isSet()) {
|
||||||
|
OAuth2AccessTokenProvider oauth2AccessTokenProvider = context.getProperty(OAUTH2_ACCESS_TOKEN_PROVIDER).asControllerService(OAuth2AccessTokenProvider.class);
|
||||||
|
|
||||||
|
oauth2AccessTokenProvider.getAccessDetails();
|
||||||
|
|
||||||
|
oauth2AccessTokenProviderOptional = Optional.of(oauth2AccessTokenProvider);
|
||||||
|
} else {
|
||||||
|
oauth2AccessTokenProviderOptional = Optional.empty();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private void setAuthenticator(OkHttpClient.Builder okHttpClientBuilder, ProcessContext context) {
|
private void setAuthenticator(OkHttpClient.Builder okHttpClientBuilder, ProcessContext context) {
|
||||||
final String authUser = trimToEmpty(context.getProperty(PROP_BASIC_AUTH_USERNAME).getValue());
|
final String authUser = trimToEmpty(context.getProperty(PROP_BASIC_AUTH_USERNAME).getValue());
|
||||||
|
|
||||||
|
@ -1034,11 +1072,17 @@ public class InvokeHTTP extends AbstractProcessor {
|
||||||
final String authUser = trimToEmpty(context.getProperty(PROP_BASIC_AUTH_USERNAME).getValue());
|
final String authUser = trimToEmpty(context.getProperty(PROP_BASIC_AUTH_USERNAME).getValue());
|
||||||
|
|
||||||
// If the username/password properties are set then check if digest auth is being used
|
// If the username/password properties are set then check if digest auth is being used
|
||||||
if (!authUser.isEmpty() && "false".equalsIgnoreCase(context.getProperty(PROP_DIGEST_AUTH).getValue())) {
|
if ("false".equalsIgnoreCase(context.getProperty(PROP_DIGEST_AUTH).getValue())) {
|
||||||
|
if (!authUser.isEmpty()) {
|
||||||
final String authPass = trimToEmpty(context.getProperty(PROP_BASIC_AUTH_PASSWORD).getValue());
|
final String authPass = trimToEmpty(context.getProperty(PROP_BASIC_AUTH_PASSWORD).getValue());
|
||||||
|
|
||||||
String credential = Credentials.basic(authUser, authPass);
|
String credential = Credentials.basic(authUser, authPass);
|
||||||
requestBuilder.header("Authorization", credential);
|
requestBuilder.header("Authorization", credential);
|
||||||
|
} else {
|
||||||
|
oauth2AccessTokenProviderOptional.ifPresent(oauth2AccessTokenProvider ->
|
||||||
|
requestBuilder.addHeader("Authorization", "Bearer " + oauth2AccessTokenProvider.getAccessDetails().getAccessToken())
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// set the request method
|
// set the request method
|
||||||
|
|
|
@ -21,6 +21,7 @@ import okhttp3.mockwebserver.MockWebServer;
|
||||||
import okhttp3.mockwebserver.RecordedRequest;
|
import okhttp3.mockwebserver.RecordedRequest;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.apache.nifi.flowfile.attributes.CoreAttributes;
|
import org.apache.nifi.flowfile.attributes.CoreAttributes;
|
||||||
|
import org.apache.nifi.oauth2.OAuth2AccessTokenProvider;
|
||||||
import org.apache.nifi.processor.Relationship;
|
import org.apache.nifi.processor.Relationship;
|
||||||
import org.apache.nifi.processors.standard.http.FlowFileNamingStrategy;
|
import org.apache.nifi.processors.standard.http.FlowFileNamingStrategy;
|
||||||
import org.apache.nifi.reporting.InitializationException;
|
import org.apache.nifi.reporting.InitializationException;
|
||||||
|
@ -34,6 +35,7 @@ import org.apache.nifi.util.MockFlowFile;
|
||||||
import org.apache.nifi.util.TestRunner;
|
import org.apache.nifi.util.TestRunner;
|
||||||
import org.apache.nifi.util.TestRunners;
|
import org.apache.nifi.util.TestRunners;
|
||||||
import org.apache.nifi.web.util.ssl.SslContextUtils;
|
import org.apache.nifi.web.util.ssl.SslContextUtils;
|
||||||
|
import org.mockito.Answers;
|
||||||
|
|
||||||
import javax.net.ssl.SSLContext;
|
import javax.net.ssl.SSLContext;
|
||||||
import javax.net.ssl.SSLSocketFactory;
|
import javax.net.ssl.SSLSocketFactory;
|
||||||
|
@ -774,6 +776,99 @@ public class InvokeHTTPTest {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testValidWhenOAuth2Set() throws Exception {
|
||||||
|
// GIVEN
|
||||||
|
String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId";
|
||||||
|
|
||||||
|
OAuth2AccessTokenProvider oauth2AccessTokenProvider = mock(OAuth2AccessTokenProvider.class, Answers.RETURNS_DEEP_STUBS);
|
||||||
|
when(oauth2AccessTokenProvider.getIdentifier()).thenReturn(oauth2AccessTokenProviderId);
|
||||||
|
|
||||||
|
runner.addControllerService(oauth2AccessTokenProviderId, oauth2AccessTokenProvider);
|
||||||
|
runner.enableControllerService(oauth2AccessTokenProvider);
|
||||||
|
|
||||||
|
setUrlProperty();
|
||||||
|
|
||||||
|
// WHEN
|
||||||
|
runner.setProperty(InvokeHTTP.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProviderId);
|
||||||
|
|
||||||
|
// THEN
|
||||||
|
runner.assertValid();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testInvalidWhenOAuth2AndUserNameSet() throws Exception {
|
||||||
|
// GIVEN
|
||||||
|
String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId";
|
||||||
|
|
||||||
|
OAuth2AccessTokenProvider oauth2AccessTokenProvider = mock(OAuth2AccessTokenProvider.class, Answers.RETURNS_DEEP_STUBS);
|
||||||
|
when(oauth2AccessTokenProvider.getIdentifier()).thenReturn(oauth2AccessTokenProviderId);
|
||||||
|
|
||||||
|
runner.addControllerService(oauth2AccessTokenProviderId, oauth2AccessTokenProvider);
|
||||||
|
runner.enableControllerService(oauth2AccessTokenProvider);
|
||||||
|
|
||||||
|
setUrlProperty();
|
||||||
|
|
||||||
|
// WHEN
|
||||||
|
runner.setProperty(InvokeHTTP.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProviderId);
|
||||||
|
runner.setProperty(InvokeHTTP.PROP_BASIC_AUTH_USERNAME, "userName");
|
||||||
|
|
||||||
|
// THEN
|
||||||
|
runner.assertNotValid();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testInvalidWhenOAuth2AndPasswordSet() throws Exception {
|
||||||
|
// GIVEN
|
||||||
|
String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId";
|
||||||
|
|
||||||
|
OAuth2AccessTokenProvider oauth2AccessTokenProvider = mock(OAuth2AccessTokenProvider.class, Answers.RETURNS_DEEP_STUBS);
|
||||||
|
when(oauth2AccessTokenProvider.getIdentifier()).thenReturn(oauth2AccessTokenProviderId);
|
||||||
|
|
||||||
|
runner.addControllerService(oauth2AccessTokenProviderId, oauth2AccessTokenProvider);
|
||||||
|
runner.enableControllerService(oauth2AccessTokenProvider);
|
||||||
|
|
||||||
|
setUrlProperty();
|
||||||
|
|
||||||
|
// WHEN
|
||||||
|
runner.setProperty(InvokeHTTP.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProviderId);
|
||||||
|
runner.setProperty(InvokeHTTP.PROP_BASIC_AUTH_PASSWORD, "password");
|
||||||
|
|
||||||
|
// THEN
|
||||||
|
runner.assertNotValid();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testOAuth2AuthorizationHeader() throws Exception {
|
||||||
|
// GIVEN
|
||||||
|
String accessToken = "access_token";
|
||||||
|
|
||||||
|
String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId";
|
||||||
|
|
||||||
|
OAuth2AccessTokenProvider oauth2AccessTokenProvider = mock(OAuth2AccessTokenProvider.class, Answers.RETURNS_DEEP_STUBS);
|
||||||
|
when(oauth2AccessTokenProvider.getIdentifier()).thenReturn(oauth2AccessTokenProviderId);
|
||||||
|
when(oauth2AccessTokenProvider.getAccessDetails().getAccessToken()).thenReturn(accessToken);
|
||||||
|
|
||||||
|
runner.addControllerService(oauth2AccessTokenProviderId, oauth2AccessTokenProvider);
|
||||||
|
runner.enableControllerService(oauth2AccessTokenProvider);
|
||||||
|
|
||||||
|
setUrlProperty();
|
||||||
|
|
||||||
|
mockWebServer.enqueue(new MockResponse());
|
||||||
|
|
||||||
|
// WHEN
|
||||||
|
runner.setProperty(InvokeHTTP.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProviderId);
|
||||||
|
runner.enqueue("unimportant");
|
||||||
|
runner.run();
|
||||||
|
|
||||||
|
// THEN
|
||||||
|
RecordedRequest recordedRequest = mockWebServer.takeRequest();
|
||||||
|
|
||||||
|
String actualAuthorizationHeader = recordedRequest.getHeader("Authorization");
|
||||||
|
assertEquals("Bearer " + accessToken, actualAuthorizationHeader);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
private void setUrlProperty() {
|
private void setUrlProperty() {
|
||||||
runner.setProperty(InvokeHTTP.PROP_URL, getMockWebServerUrl());
|
runner.setProperty(InvokeHTTP.PROP_URL, getMockWebServerUrl());
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,53 +17,80 @@
|
||||||
|
|
||||||
package org.apache.nifi.oauth2;
|
package org.apache.nifi.oauth2;
|
||||||
|
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.time.Instant;
|
||||||
|
|
||||||
public class AccessToken {
|
public class AccessToken {
|
||||||
private String accessToken;
|
private String accessToken;
|
||||||
private String refreshToken;
|
private String refreshToken;
|
||||||
private String tokenType;
|
private String tokenType;
|
||||||
private Integer expires;
|
private long expiresIn;
|
||||||
private String scope;
|
private String scopes;
|
||||||
|
|
||||||
private Long fetchTime;
|
private final Instant fetchTime;
|
||||||
|
|
||||||
public AccessToken(String accessToken,
|
public static final int EXPIRY_MARGIN = 5000;
|
||||||
String refreshToken,
|
|
||||||
String tokenType,
|
public AccessToken() {
|
||||||
Integer expires,
|
this.fetchTime = Instant.now();
|
||||||
String scope) {
|
}
|
||||||
|
|
||||||
|
public AccessToken(String accessToken, String refreshToken, String tokenType, long expiresIn, String scopes) {
|
||||||
|
this();
|
||||||
this.accessToken = accessToken;
|
this.accessToken = accessToken;
|
||||||
this.tokenType = tokenType;
|
|
||||||
this.refreshToken = refreshToken;
|
this.refreshToken = refreshToken;
|
||||||
this.expires = expires;
|
this.tokenType = tokenType;
|
||||||
this.scope = scope;
|
this.expiresIn = expiresIn;
|
||||||
this.fetchTime = System.currentTimeMillis();
|
this.scopes = scopes;
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getAccessToken() {
|
public String getAccessToken() {
|
||||||
return accessToken;
|
return accessToken;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void setAccessToken(String accessToken) {
|
||||||
|
this.accessToken = accessToken;
|
||||||
|
}
|
||||||
|
|
||||||
public String getRefreshToken() {
|
public String getRefreshToken() {
|
||||||
return refreshToken;
|
return refreshToken;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void setRefreshToken(String refreshToken) {
|
||||||
|
this.refreshToken = refreshToken;
|
||||||
|
}
|
||||||
|
|
||||||
public String getTokenType() {
|
public String getTokenType() {
|
||||||
return tokenType;
|
return tokenType;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Integer getExpires() {
|
public void setTokenType(String tokenType) {
|
||||||
return expires;
|
this.tokenType = tokenType;
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getScope() {
|
public long getExpiresIn() {
|
||||||
return scope;
|
return expiresIn;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Long getFetchTime() {
|
public void setExpiresIn(long expiresIn) {
|
||||||
|
this.expiresIn = expiresIn;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getScopes() {
|
||||||
|
return scopes;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setScopes(String scopes) {
|
||||||
|
this.scopes = scopes;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Instant getFetchTime() {
|
||||||
return fetchTime;
|
return fetchTime;
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean isExpired() {
|
public boolean isExpired() {
|
||||||
return System.currentTimeMillis() >= ( fetchTime + (expires * 1000) );
|
boolean expired = Duration.between(Instant.now(), fetchTime.plusSeconds(expiresIn - EXPIRY_MARGIN)).isNegative();
|
||||||
|
|
||||||
|
return expired;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.nifi.oauth2;
|
||||||
|
|
||||||
|
import org.apache.nifi.controller.ControllerService;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Controller service that provides OAuth2 access details
|
||||||
|
*/
|
||||||
|
public interface OAuth2AccessTokenProvider extends ControllerService {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return A valid access token (refreshed automatically if needed) and additional metadata (provided by the OAuth2 access server)
|
||||||
|
*/
|
||||||
|
AccessToken getAccessDetails();
|
||||||
|
}
|
|
@ -29,7 +29,10 @@ import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Interface for defining a credential-providing controller service for oauth2 processes.
|
* Interface for defining a credential-providing controller service for oauth2 processes.
|
||||||
|
*
|
||||||
|
* @deprecated use {@link OAuth2AccessTokenProvider} instead
|
||||||
*/
|
*/
|
||||||
|
@Deprecated
|
||||||
public interface OAuth2TokenProvider extends ControllerService {
|
public interface OAuth2TokenProvider extends ControllerService {
|
||||||
PropertyDescriptor SSL_CONTEXT = new PropertyDescriptor.Builder()
|
PropertyDescriptor SSL_CONTEXT = new PropertyDescriptor.Builder()
|
||||||
.name("oauth2-ssl-context")
|
.name("oauth2-ssl-context")
|
||||||
|
|
|
@ -30,6 +30,7 @@ import okhttp3.Request;
|
||||||
import okhttp3.RequestBody;
|
import okhttp3.RequestBody;
|
||||||
import okhttp3.Response;
|
import okhttp3.Response;
|
||||||
import org.apache.nifi.annotation.documentation.CapabilityDescription;
|
import org.apache.nifi.annotation.documentation.CapabilityDescription;
|
||||||
|
import org.apache.nifi.annotation.documentation.DeprecationNotice;
|
||||||
import org.apache.nifi.annotation.documentation.Tags;
|
import org.apache.nifi.annotation.documentation.Tags;
|
||||||
import org.apache.nifi.annotation.lifecycle.OnEnabled;
|
import org.apache.nifi.annotation.lifecycle.OnEnabled;
|
||||||
import org.apache.nifi.components.PropertyDescriptor;
|
import org.apache.nifi.components.PropertyDescriptor;
|
||||||
|
@ -39,6 +40,8 @@ import org.apache.nifi.processor.exception.ProcessException;
|
||||||
import org.apache.nifi.ssl.SSLContextService;
|
import org.apache.nifi.ssl.SSLContextService;
|
||||||
import org.apache.nifi.util.StringUtils;
|
import org.apache.nifi.util.StringUtils;
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
|
@DeprecationNotice(alternatives = {StandardOauth2AccessTokenProvider.class})
|
||||||
@Tags({"oauth2", "provider", "authorization" })
|
@Tags({"oauth2", "provider", "authorization" })
|
||||||
@CapabilityDescription("This controller service provides a way of working with access and refresh tokens via the " +
|
@CapabilityDescription("This controller service provides a way of working with access and refresh tokens via the " +
|
||||||
"password and client_credential grant flows in the OAuth2 specification. It is meant to provide a way for components " +
|
"password and client_credential grant flows in the OAuth2 specification. It is meant to provide a way for components " +
|
||||||
|
|
|
@ -0,0 +1,300 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.nifi.oauth2;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.DeserializationFeature;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
|
||||||
|
import okhttp3.FormBody;
|
||||||
|
import okhttp3.OkHttpClient;
|
||||||
|
import okhttp3.Request;
|
||||||
|
import okhttp3.RequestBody;
|
||||||
|
import okhttp3.Response;
|
||||||
|
import org.apache.nifi.annotation.documentation.CapabilityDescription;
|
||||||
|
import org.apache.nifi.annotation.documentation.Tags;
|
||||||
|
import org.apache.nifi.annotation.lifecycle.OnDisabled;
|
||||||
|
import org.apache.nifi.annotation.lifecycle.OnEnabled;
|
||||||
|
import org.apache.nifi.components.AllowableValue;
|
||||||
|
import org.apache.nifi.components.PropertyDescriptor;
|
||||||
|
import org.apache.nifi.components.ValidationContext;
|
||||||
|
import org.apache.nifi.components.ValidationResult;
|
||||||
|
import org.apache.nifi.components.Validator;
|
||||||
|
import org.apache.nifi.controller.AbstractControllerService;
|
||||||
|
import org.apache.nifi.controller.ConfigurationContext;
|
||||||
|
import org.apache.nifi.expression.ExpressionLanguageScope;
|
||||||
|
import org.apache.nifi.processor.exception.ProcessException;
|
||||||
|
import org.apache.nifi.processor.util.StandardValidators;
|
||||||
|
import org.apache.nifi.ssl.SSLContextService;
|
||||||
|
|
||||||
|
import javax.net.ssl.SSLContext;
|
||||||
|
import javax.net.ssl.X509TrustManager;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.UncheckedIOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Tags({"oauth2", "provider", "authorization", "access token", "http"})
|
||||||
|
@CapabilityDescription("Provides OAuth 2.0 access tokens that can be used as Bearer authorization header in HTTP requests." +
|
||||||
|
" Uses Resource Owner Password Credentials Grant.")
|
||||||
|
public class StandardOauth2AccessTokenProvider extends AbstractControllerService implements OAuth2AccessTokenProvider {
|
||||||
|
public static final PropertyDescriptor AUTHORIZATION_SERVER_URL = new PropertyDescriptor.Builder()
|
||||||
|
.name("authorization-server-url")
|
||||||
|
.displayName("Authorization Server URL")
|
||||||
|
.description("The URL of the authorization server that issues access tokens.")
|
||||||
|
.required(true)
|
||||||
|
.addValidator(StandardValidators.URL_VALIDATOR)
|
||||||
|
.expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
public static AllowableValue RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE = new AllowableValue(
|
||||||
|
"password",
|
||||||
|
"User Password",
|
||||||
|
"Resource Owner Password Credentials Grant. Used to access resources available to users. Requires username and password and usually Client ID and Client Secret"
|
||||||
|
);
|
||||||
|
public static AllowableValue CLIENT_CREDENTIALS_GRANT_TYPE = new AllowableValue(
|
||||||
|
"client_credentials",
|
||||||
|
"Client Credentials",
|
||||||
|
"Client Credentials Grant. Used to access resources available to clients. Requires Client ID and Client Secret"
|
||||||
|
);
|
||||||
|
|
||||||
|
public static final PropertyDescriptor GRANT_TYPE = new PropertyDescriptor.Builder()
|
||||||
|
.name("grant-type")
|
||||||
|
.displayName("Grant Type")
|
||||||
|
.description("The OAuth2 Grant Type to be used when acquiring an access token.")
|
||||||
|
.required(true)
|
||||||
|
.allowableValues(RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE, CLIENT_CREDENTIALS_GRANT_TYPE)
|
||||||
|
.defaultValue(RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE.getValue())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
public static final PropertyDescriptor USERNAME = new PropertyDescriptor.Builder()
|
||||||
|
.name("service-user-name")
|
||||||
|
.displayName("Username")
|
||||||
|
.description("Username on the service that is being accessed.")
|
||||||
|
.dependsOn(GRANT_TYPE, RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE)
|
||||||
|
.required(true)
|
||||||
|
.addValidator(StandardValidators.NON_BLANK_VALIDATOR)
|
||||||
|
.expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
public static final PropertyDescriptor PASSWORD = new PropertyDescriptor.Builder()
|
||||||
|
.name("service-password")
|
||||||
|
.displayName("Password")
|
||||||
|
.description("Password for the username on the service that is being accessed.")
|
||||||
|
.dependsOn(GRANT_TYPE, RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE)
|
||||||
|
.required(true)
|
||||||
|
.sensitive(true)
|
||||||
|
.addValidator(StandardValidators.NON_BLANK_VALIDATOR)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
public static final PropertyDescriptor CLIENT_ID = new PropertyDescriptor.Builder()
|
||||||
|
.name("client-id")
|
||||||
|
.displayName("Client ID")
|
||||||
|
.required(false)
|
||||||
|
.addValidator(StandardValidators.NON_BLANK_VALIDATOR)
|
||||||
|
.expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
public static final PropertyDescriptor CLIENT_SECRET = new PropertyDescriptor.Builder()
|
||||||
|
.name("client-secret")
|
||||||
|
.displayName("Client secret")
|
||||||
|
.dependsOn(CLIENT_ID)
|
||||||
|
.required(true)
|
||||||
|
.sensitive(true)
|
||||||
|
.addValidator(StandardValidators.NON_BLANK_VALIDATOR)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
public static final PropertyDescriptor SSL_CONTEXT = new PropertyDescriptor.Builder()
|
||||||
|
.name("ssl-context-service")
|
||||||
|
.displayName("SSL Context Servuce")
|
||||||
|
.addValidator(Validator.VALID)
|
||||||
|
.identifiesControllerService(SSLContextService.class)
|
||||||
|
.required(false)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
private static final List<PropertyDescriptor> PROPERTIES = Collections.unmodifiableList(Arrays.asList(
|
||||||
|
AUTHORIZATION_SERVER_URL,
|
||||||
|
GRANT_TYPE,
|
||||||
|
USERNAME,
|
||||||
|
PASSWORD,
|
||||||
|
CLIENT_ID,
|
||||||
|
CLIENT_SECRET,
|
||||||
|
SSL_CONTEXT
|
||||||
|
));
|
||||||
|
|
||||||
|
public static final ObjectMapper ACCESS_DETAILS_MAPPER = new ObjectMapper()
|
||||||
|
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
|
||||||
|
.setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE);
|
||||||
|
|
||||||
|
private volatile String authorizationServerUrl;
|
||||||
|
private volatile OkHttpClient httpClient;
|
||||||
|
|
||||||
|
private volatile String grantType;
|
||||||
|
private volatile String username;
|
||||||
|
private volatile String password;
|
||||||
|
private volatile String clientId;
|
||||||
|
private volatile String clientSecret;
|
||||||
|
|
||||||
|
private volatile AccessToken accessDetails;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<PropertyDescriptor> getSupportedPropertyDescriptors() {
|
||||||
|
return PROPERTIES;
|
||||||
|
}
|
||||||
|
|
||||||
|
@OnEnabled
|
||||||
|
public void onEnabled(ConfigurationContext context) {
|
||||||
|
authorizationServerUrl = context.getProperty(AUTHORIZATION_SERVER_URL).evaluateAttributeExpressions().getValue();
|
||||||
|
|
||||||
|
httpClient = createHttpClient(context);
|
||||||
|
|
||||||
|
grantType = context.getProperty(GRANT_TYPE).getValue();
|
||||||
|
username = context.getProperty(USERNAME).evaluateAttributeExpressions().getValue();
|
||||||
|
password = context.getProperty(PASSWORD).getValue();
|
||||||
|
clientId = context.getProperty(CLIENT_ID).evaluateAttributeExpressions().getValue();
|
||||||
|
clientSecret = context.getProperty(CLIENT_SECRET).getValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
@OnDisabled
|
||||||
|
public void onDisabled() {
|
||||||
|
accessDetails = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Collection<ValidationResult> customValidate(ValidationContext validationContext) {
|
||||||
|
final List<ValidationResult> validationResults = new ArrayList<>(super.customValidate(validationContext));
|
||||||
|
|
||||||
|
if (
|
||||||
|
validationContext.getProperty(GRANT_TYPE).getValue().equals(CLIENT_CREDENTIALS_GRANT_TYPE.getValue())
|
||||||
|
&& !validationContext.getProperty(CLIENT_ID).isSet()
|
||||||
|
) {
|
||||||
|
validationResults.add(new ValidationResult.Builder().subject(CLIENT_ID.getDisplayName())
|
||||||
|
.valid(false)
|
||||||
|
.explanation(String.format(
|
||||||
|
"When '%s' is set to '%s', '%s' is required",
|
||||||
|
GRANT_TYPE.getDisplayName(),
|
||||||
|
CLIENT_CREDENTIALS_GRANT_TYPE.getDisplayName(),
|
||||||
|
CLIENT_ID.getDisplayName())
|
||||||
|
)
|
||||||
|
.build());
|
||||||
|
}
|
||||||
|
|
||||||
|
return validationResults;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected OkHttpClient createHttpClient(ConfigurationContext context) {
|
||||||
|
OkHttpClient.Builder clientBuilder = new OkHttpClient.Builder();
|
||||||
|
|
||||||
|
SSLContextService sslService = context.getProperty(SSL_CONTEXT).asControllerService(SSLContextService.class);
|
||||||
|
if (sslService != null) {
|
||||||
|
final X509TrustManager trustManager = sslService.createTrustManager();
|
||||||
|
SSLContext sslContext = sslService.createContext();
|
||||||
|
clientBuilder.sslSocketFactory(sslContext.getSocketFactory(), trustManager);
|
||||||
|
}
|
||||||
|
|
||||||
|
return clientBuilder.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public AccessToken getAccessDetails() {
|
||||||
|
if (this.accessDetails == null) {
|
||||||
|
acquireAccessDetails();
|
||||||
|
} else if (this.accessDetails.isExpired()) {
|
||||||
|
if (this.accessDetails.getRefreshToken() == null) {
|
||||||
|
acquireAccessDetails();
|
||||||
|
} else {
|
||||||
|
try {
|
||||||
|
refreshAccessDetails();
|
||||||
|
} catch (Exception e) {
|
||||||
|
getLogger().info("Couldn't refresh access token", e);
|
||||||
|
acquireAccessDetails();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return accessDetails;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void acquireAccessDetails() {
|
||||||
|
getLogger().debug("Getting a new access token");
|
||||||
|
|
||||||
|
FormBody.Builder acquireTokenBuilder = new FormBody.Builder();
|
||||||
|
|
||||||
|
if (grantType.equals(RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE.getValue())) {
|
||||||
|
acquireTokenBuilder.add("grant_type", "password")
|
||||||
|
.add("username", username)
|
||||||
|
.add("password", password);
|
||||||
|
} else if (grantType.equals(CLIENT_CREDENTIALS_GRANT_TYPE.getValue())) {
|
||||||
|
acquireTokenBuilder.add("grant_type", "client_credentials");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (clientId != null) {
|
||||||
|
acquireTokenBuilder.add("client_id", clientId);
|
||||||
|
acquireTokenBuilder.add("client_secret", clientSecret);
|
||||||
|
}
|
||||||
|
|
||||||
|
RequestBody acquireTokenRequestBody = acquireTokenBuilder.build();
|
||||||
|
|
||||||
|
Request acquireTokenRequest = new Request.Builder()
|
||||||
|
.url(authorizationServerUrl)
|
||||||
|
.post(acquireTokenRequestBody)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
this.accessDetails = getAccessDetails(acquireTokenRequest);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void refreshAccessDetails() {
|
||||||
|
getLogger().debug("Refreshing access token");
|
||||||
|
|
||||||
|
FormBody.Builder refreshTokenBuilder = new FormBody.Builder()
|
||||||
|
.add("grant_type", "refresh_token")
|
||||||
|
.add("refresh_token", this.accessDetails.getRefreshToken());
|
||||||
|
|
||||||
|
if (clientId != null) {
|
||||||
|
refreshTokenBuilder.add("client_id", clientId);
|
||||||
|
refreshTokenBuilder.add("client_secret", clientSecret);
|
||||||
|
}
|
||||||
|
|
||||||
|
RequestBody refreshTokenRequestBody = refreshTokenBuilder.build();
|
||||||
|
|
||||||
|
Request refreshRequest = new Request.Builder()
|
||||||
|
.url(authorizationServerUrl)
|
||||||
|
.post(refreshTokenRequestBody)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
this.accessDetails = getAccessDetails(refreshRequest);
|
||||||
|
}
|
||||||
|
|
||||||
|
private AccessToken getAccessDetails(Request newRequest) {
|
||||||
|
try {
|
||||||
|
Response response = httpClient.newCall(newRequest).execute();
|
||||||
|
String responseBody = response.body().string();
|
||||||
|
if (response.code() != 200) {
|
||||||
|
getLogger().error(String.format("OAuth2 access token request failed [HTTP %d], response:%n%s", response.code(), responseBody));
|
||||||
|
throw new ProcessException(String.format("OAuth2 access token request failed [HTTP %d]", response.code()));
|
||||||
|
}
|
||||||
|
|
||||||
|
AccessToken accessDetails = ACCESS_DETAILS_MAPPER.readValue(responseBody, AccessToken.class);
|
||||||
|
|
||||||
|
return accessDetails;
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new UncheckedIOException("OAuth2 access token request failed", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -13,3 +13,5 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
org.apache.nifi.oauth2.OAuth2TokenProviderImpl
|
org.apache.nifi.oauth2.OAuth2TokenProviderImpl
|
||||||
|
org.apache.nifi.oauth2.StandardOauth2AccessTokenProvider
|
||||||
|
|
||||||
|
|
|
@ -107,7 +107,7 @@ public class OAuth2TokenProviderImplTest {
|
||||||
private void assertAccessTokenFound(final AccessToken accessToken) {
|
private void assertAccessTokenFound(final AccessToken accessToken) {
|
||||||
assertNotNull(accessToken);
|
assertNotNull(accessToken);
|
||||||
assertEquals("access token", accessToken.getAccessToken());
|
assertEquals("access token", accessToken.getAccessToken());
|
||||||
assertEquals(300, accessToken.getExpires().intValue());
|
assertEquals(5300, accessToken.getExpiresIn());
|
||||||
assertEquals("BEARER", accessToken.getTokenType());
|
assertEquals("BEARER", accessToken.getTokenType());
|
||||||
assertFalse(accessToken.isExpired());
|
assertFalse(accessToken.isExpired());
|
||||||
}
|
}
|
||||||
|
@ -117,7 +117,7 @@ public class OAuth2TokenProviderImplTest {
|
||||||
token.put("access_token", "access token");
|
token.put("access_token", "access token");
|
||||||
token.put("refresh_token", "refresh token");
|
token.put("refresh_token", "refresh token");
|
||||||
token.put("token_type", "BEARER");
|
token.put("token_type", "BEARER");
|
||||||
token.put("expires_in", 300);
|
token.put("expires_in", 5300);
|
||||||
token.put("scope", "test scope");
|
token.put("scope", "test scope");
|
||||||
final String accessToken = new ObjectMapper().writeValueAsString(token);
|
final String accessToken = new ObjectMapper().writeValueAsString(token);
|
||||||
mockWebServer.enqueue(new MockResponse().setResponseCode(200).addHeader("Content-Type", "application/json").setBody(accessToken));
|
mockWebServer.enqueue(new MockResponse().setResponseCode(200).addHeader("Content-Type", "application/json").setBody(accessToken));
|
||||||
|
|
|
@ -0,0 +1,411 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.nifi.oauth2;
|
||||||
|
|
||||||
|
import okhttp3.MediaType;
|
||||||
|
import okhttp3.OkHttpClient;
|
||||||
|
import okhttp3.Protocol;
|
||||||
|
import okhttp3.Request;
|
||||||
|
import okhttp3.Response;
|
||||||
|
import okhttp3.ResponseBody;
|
||||||
|
import org.apache.nifi.controller.ConfigurationContext;
|
||||||
|
import org.apache.nifi.logging.ComponentLog;
|
||||||
|
import org.apache.nifi.processor.Processor;
|
||||||
|
import org.apache.nifi.processor.exception.ProcessException;
|
||||||
|
import org.apache.nifi.util.NoOpProcessor;
|
||||||
|
import org.apache.nifi.util.TestRunner;
|
||||||
|
import org.apache.nifi.util.TestRunners;
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.mockito.Answers;
|
||||||
|
import org.mockito.ArgumentCaptor;
|
||||||
|
import org.mockito.Captor;
|
||||||
|
import org.mockito.Mock;
|
||||||
|
import org.mockito.MockitoAnnotations;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.UncheckedIOException;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertThrows;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.ArgumentMatchers.eq;
|
||||||
|
import static org.mockito.Mockito.times;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
public class StandardOauth2AccessTokenProviderTest {
|
||||||
|
private static final String AUTHORIZATION_SERVER_URL = "http://authorizationServerUrl";
|
||||||
|
private static final String USERNAME = "username";
|
||||||
|
private static final String PASSWORD = "password";
|
||||||
|
private static final String CLIENT_ID = "clientId";
|
||||||
|
private static final String CLIENT_SECRET = "clientSecret";
|
||||||
|
|
||||||
|
private StandardOauth2AccessTokenProvider testSubject;
|
||||||
|
|
||||||
|
@Mock(answer = Answers.RETURNS_DEEP_STUBS)
|
||||||
|
private OkHttpClient mockHttpClient;
|
||||||
|
|
||||||
|
@Mock(answer = Answers.RETURNS_DEEP_STUBS)
|
||||||
|
private ConfigurationContext mockContext;
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
private ComponentLog mockLogger;
|
||||||
|
@Captor
|
||||||
|
private ArgumentCaptor<String> debugCaptor;
|
||||||
|
@Captor
|
||||||
|
private ArgumentCaptor<String> errorCaptor;
|
||||||
|
@Captor
|
||||||
|
private ArgumentCaptor<Throwable> throwableCaptor;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setUp() throws Exception {
|
||||||
|
MockitoAnnotations.initMocks(this);
|
||||||
|
|
||||||
|
testSubject = new StandardOauth2AccessTokenProvider() {
|
||||||
|
@Override
|
||||||
|
protected OkHttpClient createHttpClient(ConfigurationContext context) {
|
||||||
|
return mockHttpClient;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected ComponentLog getLogger() {
|
||||||
|
return mockLogger;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.GRANT_TYPE).getValue()).thenReturn(StandardOauth2AccessTokenProvider.RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE.getValue());
|
||||||
|
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL).evaluateAttributeExpressions().getValue()).thenReturn(AUTHORIZATION_SERVER_URL);
|
||||||
|
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.USERNAME).evaluateAttributeExpressions().getValue()).thenReturn(USERNAME);
|
||||||
|
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.PASSWORD).getValue()).thenReturn(PASSWORD);
|
||||||
|
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_ID).evaluateAttributeExpressions().getValue()).thenReturn(CLIENT_ID);
|
||||||
|
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_SECRET).getValue()).thenReturn(CLIENT_SECRET);
|
||||||
|
|
||||||
|
testSubject.onEnabled(mockContext);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testInvalidWhenClientCredentialsGrantTypeSetWithoutClientId() throws Exception {
|
||||||
|
// GIVEN
|
||||||
|
Processor processor = new NoOpProcessor();
|
||||||
|
TestRunner runner = TestRunners.newTestRunner(processor);
|
||||||
|
|
||||||
|
runner.addControllerService("testSubject", testSubject);
|
||||||
|
|
||||||
|
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, "http://unimportant");
|
||||||
|
|
||||||
|
// WHEN
|
||||||
|
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.GRANT_TYPE, StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE);
|
||||||
|
|
||||||
|
// THEN
|
||||||
|
runner.assertNotValid(testSubject);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testValidWhenClientCredentialsGrantTypeSetWithClientId() throws Exception {
|
||||||
|
// GIVEN
|
||||||
|
Processor processor = new NoOpProcessor();
|
||||||
|
TestRunner runner = TestRunners.newTestRunner(processor);
|
||||||
|
|
||||||
|
runner.addControllerService("testSubject", testSubject);
|
||||||
|
|
||||||
|
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, "http://unimportant");
|
||||||
|
|
||||||
|
// WHEN
|
||||||
|
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.GRANT_TYPE, StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE);
|
||||||
|
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_ID, "clientId");
|
||||||
|
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_SECRET, "clientSecret");
|
||||||
|
|
||||||
|
// THEN
|
||||||
|
runner.assertValid(testSubject);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testAcquireNewToken() throws Exception {
|
||||||
|
String accessTokenValue = "access_token_value";
|
||||||
|
|
||||||
|
// GIVEN
|
||||||
|
Response response = buildResponse(
|
||||||
|
200,
|
||||||
|
"{ \"access_token\":\"" + accessTokenValue + "\" }"
|
||||||
|
);
|
||||||
|
|
||||||
|
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response);
|
||||||
|
|
||||||
|
// WHEN
|
||||||
|
String actual = testSubject.getAccessDetails().getAccessToken();
|
||||||
|
|
||||||
|
// THEN
|
||||||
|
assertEquals(accessTokenValue, actual);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testRefreshToken() throws Exception {
|
||||||
|
// GIVEN
|
||||||
|
String firstToken = "first_token";
|
||||||
|
String expectedToken = "second_token";
|
||||||
|
|
||||||
|
Response response1 = buildResponse(
|
||||||
|
200,
|
||||||
|
"{ \"access_token\":\"" + firstToken + "\", \"expires_in\":\"0\", \"refresh_token\":\"not_checking_in_this_test\" }"
|
||||||
|
);
|
||||||
|
|
||||||
|
Response response2 = buildResponse(
|
||||||
|
200,
|
||||||
|
"{ \"access_token\":\"" + expectedToken + "\" }"
|
||||||
|
);
|
||||||
|
|
||||||
|
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response1, response2);
|
||||||
|
|
||||||
|
// WHEN
|
||||||
|
testSubject.getAccessDetails();
|
||||||
|
String actualToken = testSubject.getAccessDetails().getAccessToken();
|
||||||
|
|
||||||
|
// THEN
|
||||||
|
assertEquals(expectedToken, actualToken);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testIOExceptionDuringRefreshAndSubsequentAcquire() throws Exception {
|
||||||
|
// GIVEN
|
||||||
|
String refreshErrorMessage = "refresh_error";
|
||||||
|
String acquireErrorMessage = "acquire_error";
|
||||||
|
|
||||||
|
AtomicInteger callCounter = new AtomicInteger(0);
|
||||||
|
when(mockHttpClient.newCall(any(Request.class)).execute()).thenAnswer(invocation -> {
|
||||||
|
callCounter.incrementAndGet();
|
||||||
|
|
||||||
|
if (callCounter.get() == 1) {
|
||||||
|
return buildSuccessfulInitResponse();
|
||||||
|
} else if (callCounter.get() == 2) {
|
||||||
|
throw new IOException(refreshErrorMessage);
|
||||||
|
} else if (callCounter.get() == 3) {
|
||||||
|
throw new IOException(acquireErrorMessage);
|
||||||
|
}
|
||||||
|
|
||||||
|
throw new IllegalStateException("Test improperly defined mock HTTP responses.");
|
||||||
|
});
|
||||||
|
|
||||||
|
// Get a good accessDetails so we can have a refresh a second time
|
||||||
|
testSubject.getAccessDetails();
|
||||||
|
|
||||||
|
// WHEN
|
||||||
|
UncheckedIOException actualException = assertThrows(
|
||||||
|
UncheckedIOException.class,
|
||||||
|
() -> testSubject.getAccessDetails()
|
||||||
|
);
|
||||||
|
|
||||||
|
// THEN
|
||||||
|
checkLoggedDebugWhenRefreshFails();
|
||||||
|
|
||||||
|
checkLoggedRefreshError(new UncheckedIOException("OAuth2 access token request failed", new IOException(refreshErrorMessage)));
|
||||||
|
|
||||||
|
checkError(new UncheckedIOException("OAuth2 access token request failed", new IOException(acquireErrorMessage)), actualException);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testIOExceptionDuringRefreshSuccessfulSubsequentAcquire() throws Exception {
|
||||||
|
// GIVEN
|
||||||
|
String refreshErrorMessage = "refresh_error";
|
||||||
|
String expectedToken = "expected_token";
|
||||||
|
|
||||||
|
Response successfulAcquireResponse = buildResponse(
|
||||||
|
200,
|
||||||
|
"{ \"access_token\":\"" + expectedToken + "\", \"expires_in\":\"0\", \"refresh_token\":\"not_checking_in_this_test\" }"
|
||||||
|
);
|
||||||
|
|
||||||
|
AtomicInteger callCounter = new AtomicInteger(0);
|
||||||
|
when(mockHttpClient.newCall(any(Request.class)).execute()).thenAnswer(invocation -> {
|
||||||
|
callCounter.incrementAndGet();
|
||||||
|
|
||||||
|
if (callCounter.get() == 1) {
|
||||||
|
return buildSuccessfulInitResponse();
|
||||||
|
} else if (callCounter.get() == 2) {
|
||||||
|
throw new IOException(refreshErrorMessage);
|
||||||
|
} else if (callCounter.get() == 3) {
|
||||||
|
return successfulAcquireResponse;
|
||||||
|
}
|
||||||
|
|
||||||
|
throw new IllegalStateException("Test improperly defined mock HTTP responses.");
|
||||||
|
});
|
||||||
|
|
||||||
|
// Get a good accessDetails so we can have a refresh a second time
|
||||||
|
testSubject.getAccessDetails();
|
||||||
|
|
||||||
|
// WHEN
|
||||||
|
String actualToken = testSubject.getAccessDetails().getAccessToken();
|
||||||
|
|
||||||
|
// THEN
|
||||||
|
checkLoggedDebugWhenRefreshFails();
|
||||||
|
|
||||||
|
checkLoggedRefreshError(new UncheckedIOException("OAuth2 access token request failed", new IOException(refreshErrorMessage)));
|
||||||
|
|
||||||
|
assertEquals(expectedToken, actualToken);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testHTTPErrorDuringRefreshAndSubsequentAcquire() throws Exception {
|
||||||
|
// GIVEN
|
||||||
|
String errorRefreshResponseBody = "{ \"error_response\":\"refresh_error\" }";
|
||||||
|
String errorAcquireResponseBody = "{ \"error_response\":\"acquire_error\" }";
|
||||||
|
|
||||||
|
Response errorRefreshResponse = buildResponse(500, errorRefreshResponseBody);
|
||||||
|
Response errorAcquireResponse = buildResponse(503, errorAcquireResponseBody);
|
||||||
|
|
||||||
|
AtomicInteger callCounter = new AtomicInteger(0);
|
||||||
|
when(mockHttpClient.newCall(any(Request.class)).execute()).thenAnswer(invocation -> {
|
||||||
|
callCounter.incrementAndGet();
|
||||||
|
|
||||||
|
if (callCounter.get() == 1) {
|
||||||
|
return buildSuccessfulInitResponse();
|
||||||
|
} else if (callCounter.get() == 2) {
|
||||||
|
return errorRefreshResponse;
|
||||||
|
} else if (callCounter.get() == 3) {
|
||||||
|
return errorAcquireResponse;
|
||||||
|
}
|
||||||
|
|
||||||
|
throw new IllegalStateException("Test improperly defined mock HTTP responses.");
|
||||||
|
});
|
||||||
|
|
||||||
|
List<String> expectedLoggedError = Arrays.asList(
|
||||||
|
String.format("OAuth2 access token request failed [HTTP %d], response:%n%s", 500, errorRefreshResponseBody),
|
||||||
|
String.format("OAuth2 access token request failed [HTTP %d], response:%n%s", 503, errorAcquireResponseBody)
|
||||||
|
);
|
||||||
|
|
||||||
|
// Get a good accessDetails so we can have a refresh a second time
|
||||||
|
testSubject.getAccessDetails();
|
||||||
|
|
||||||
|
// WHEN
|
||||||
|
ProcessException actualException = assertThrows(
|
||||||
|
ProcessException.class,
|
||||||
|
() -> testSubject.getAccessDetails()
|
||||||
|
);
|
||||||
|
|
||||||
|
// THEN
|
||||||
|
checkLoggedDebugWhenRefreshFails();
|
||||||
|
|
||||||
|
checkLoggedRefreshError(new ProcessException("OAuth2 access token request failed [HTTP 500]"));
|
||||||
|
|
||||||
|
checkedLoggedErrorWhenRefreshReturnsBadHTTPResponse(expectedLoggedError);
|
||||||
|
|
||||||
|
checkError(new ProcessException("OAuth2 access token request failed [HTTP 503]"), actualException);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testHTTPErrorDuringRefreshSuccessfulSubsequentAcquire() throws Exception {
|
||||||
|
// GIVEN
|
||||||
|
String expectedRefreshErrorResponse = "{ \"error_response\":\"refresh_error\" }";
|
||||||
|
String expectedToken = "expected_token";
|
||||||
|
|
||||||
|
Response errorRefreshResponse = buildResponse(500, expectedRefreshErrorResponse);
|
||||||
|
Response successfulAcquireResponse = buildResponse(
|
||||||
|
200,
|
||||||
|
"{ \"access_token\":\"" + expectedToken + "\", \"expires_in\":\"0\", \"refresh_token\":\"not_checking_in_this_test\" }"
|
||||||
|
);
|
||||||
|
|
||||||
|
AtomicInteger callCounter = new AtomicInteger(0);
|
||||||
|
when(mockHttpClient.newCall(any(Request.class)).execute()).thenAnswer(invocation -> {
|
||||||
|
callCounter.incrementAndGet();
|
||||||
|
|
||||||
|
if (callCounter.get() == 1) {
|
||||||
|
return buildSuccessfulInitResponse();
|
||||||
|
} else if (callCounter.get() == 2) {
|
||||||
|
return errorRefreshResponse;
|
||||||
|
} else if (callCounter.get() == 3) {
|
||||||
|
return successfulAcquireResponse;
|
||||||
|
}
|
||||||
|
|
||||||
|
throw new IllegalStateException("Test improperly defined mock HTTP responses.");
|
||||||
|
});
|
||||||
|
|
||||||
|
List<String> expectedLoggedError = Arrays.asList(String.format("OAuth2 access token request failed [HTTP %d], response:%n%s", 500, expectedRefreshErrorResponse));
|
||||||
|
|
||||||
|
// Get a good accessDetails so we can have a refresh a second time
|
||||||
|
testSubject.getAccessDetails();
|
||||||
|
|
||||||
|
// WHEN
|
||||||
|
String actualToken = testSubject.getAccessDetails().getAccessToken();
|
||||||
|
|
||||||
|
// THEN
|
||||||
|
checkLoggedDebugWhenRefreshFails();
|
||||||
|
|
||||||
|
checkLoggedRefreshError(new ProcessException("OAuth2 access token request failed [HTTP 500]"));
|
||||||
|
|
||||||
|
checkedLoggedErrorWhenRefreshReturnsBadHTTPResponse(expectedLoggedError);
|
||||||
|
|
||||||
|
assertEquals(expectedToken, actualToken);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Response buildSuccessfulInitResponse() {
|
||||||
|
return buildResponse(
|
||||||
|
200,
|
||||||
|
"{ \"access_token\":\"exists_but_value_irrelevant\", \"expires_in\":\"0\", \"refresh_token\":\"not_checking_in_this_test\" }"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Response buildResponse(int code, String body) {
|
||||||
|
return new Response.Builder()
|
||||||
|
.request(new Request.Builder()
|
||||||
|
.url("http://unimportant_but_required")
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
.protocol(Protocol.HTTP_2)
|
||||||
|
.message("unimportant_but_required")
|
||||||
|
.code(code)
|
||||||
|
.body(ResponseBody.create(
|
||||||
|
body.getBytes(),
|
||||||
|
MediaType.parse("application/json"))
|
||||||
|
)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void checkLoggedDebugWhenRefreshFails() {
|
||||||
|
verify(mockLogger, times(3)).debug(debugCaptor.capture());
|
||||||
|
List<String> actualDebugMessages = debugCaptor.getAllValues();
|
||||||
|
|
||||||
|
assertEquals(
|
||||||
|
Arrays.asList("Getting a new access token", "Refreshing access token", "Getting a new access token"),
|
||||||
|
actualDebugMessages
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void checkedLoggedErrorWhenRefreshReturnsBadHTTPResponse(List<String> expectedLoggedError) {
|
||||||
|
verify(mockLogger, times(expectedLoggedError.size())).error(errorCaptor.capture());
|
||||||
|
List<String> actualLoggedError = errorCaptor.getAllValues();
|
||||||
|
|
||||||
|
assertEquals(expectedLoggedError, actualLoggedError);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void checkLoggedRefreshError(Throwable expectedRefreshError) {
|
||||||
|
verify(mockLogger).info(eq("Couldn't refresh access token"), throwableCaptor.capture());
|
||||||
|
Throwable actualRefreshError = throwableCaptor.getValue();
|
||||||
|
|
||||||
|
checkError(expectedRefreshError, actualRefreshError);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void checkError(Throwable expectedError, Throwable actualError) {
|
||||||
|
assertEquals(expectedError.getClass(), actualError.getClass());
|
||||||
|
assertEquals(expectedError.getMessage(), actualError.getMessage());
|
||||||
|
if (expectedError.getCause() != null || actualError.getCause() != null) {
|
||||||
|
assertEquals(expectedError.getCause().getClass(), actualError.getCause().getClass());
|
||||||
|
assertEquals(expectedError.getCause().getMessage(), actualError.getCause().getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue