diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/ClientAuthenticationStrategy.java b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/ClientAuthenticationStrategy.java new file mode 100644 index 0000000000..66cddc1e07 --- /dev/null +++ b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/ClientAuthenticationStrategy.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.oauth2; + +import org.apache.nifi.components.DescribedValue; + +public enum ClientAuthenticationStrategy implements DescribedValue { + REQUEST_BODY("Send client authentication in request body. RFC 6749 Section 2.3.1 recommends Basic Authentication instead of request body."), + BASIC_AUTHENTICATION("Send client authentication using HTTP Basic authentication."); + + private final String description; + + ClientAuthenticationStrategy(final String description) { + this.description = description; + } + + @Override + public String getValue() { + return name(); + } + + @Override + public String getDisplayName() { + return name(); + } + + @Override + public String getDescription() { + return description; + } +} diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProvider.java b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProvider.java index 22dce3a033..39b74ff88c 100644 --- a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProvider.java +++ b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProvider.java @@ -19,6 +19,7 @@ package org.apache.nifi.oauth2; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import okhttp3.Credentials; import okhttp3.FormBody; import okhttp3.OkHttpClient; import okhttp3.Request; @@ -58,7 +59,8 @@ import java.util.concurrent.TimeUnit; @Tags({"oauth2", "provider", "authorization", "access token", "http"}) @CapabilityDescription("Provides OAuth 2.0 access tokens that can be used as Bearer authorization header in HTTP requests." + - " Uses Resource Owner Password Credentials Grant.") + " Can use either Resource Owner Password Credentials Grant or Client Credentials Grant." + + " Client authentication can be done with either HTTP Basic authentication or in the request body.") public class StandardOauth2AccessTokenProvider extends AbstractControllerService implements OAuth2AccessTokenProvider, VerifiableControllerService { public static final PropertyDescriptor AUTHORIZATION_SERVER_URL = new PropertyDescriptor.Builder() .name("authorization-server-url") @@ -69,6 +71,15 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService .expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY) .build(); + public static final PropertyDescriptor CLIENT_AUTHENTICATION_STRATEGY = new PropertyDescriptor.Builder() + .name("client-authentication-strategy") + .displayName("Client Authentication Strategy") + .description("Strategy for authenticating the client against the OAuth2 token provider service.") + .required(true) + .allowableValues(ClientAuthenticationStrategy.class) + .defaultValue(ClientAuthenticationStrategy.REQUEST_BODY.getValue()) + .build(); + public static AllowableValue RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE = new AllowableValue( "password", "User Password", @@ -136,13 +147,13 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService .build(); public static final PropertyDescriptor REFRESH_WINDOW = new PropertyDescriptor.Builder() - .name("refresh-window") - .displayName("Refresh Window") - .description("The service will attempt to refresh tokens expiring within the refresh window, subtracting the configured duration from the token expiration.") - .addValidator(StandardValidators.TIME_PERIOD_VALIDATOR) - .defaultValue("0 s") - .required(true) - .build(); + .name("refresh-window") + .displayName("Refresh Window") + .description("The service will attempt to refresh tokens expiring within the refresh window, subtracting the configured duration from the token expiration.") + .addValidator(StandardValidators.TIME_PERIOD_VALIDATOR) + .defaultValue("0 s") + .required(true) + .build(); public static final PropertyDescriptor SSL_CONTEXT = new PropertyDescriptor.Builder() .name("ssl-context-service") @@ -163,6 +174,7 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService private static final List PROPERTIES = Collections.unmodifiableList(Arrays.asList( AUTHORIZATION_SERVER_URL, + CLIENT_AUTHENTICATION_STRATEGY, GRANT_TYPE, USERNAME, PASSWORD, @@ -174,6 +186,8 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService HTTP_PROTOCOL_STRATEGY )); + private static final String AUTHORIZATION_HEADER = "Authorization"; + public static final ObjectMapper ACCESS_DETAILS_MAPPER = new ObjectMapper() .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) .setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE); @@ -181,6 +195,7 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService private volatile String authorizationServerUrl; private volatile OkHttpClient httpClient; + private volatile ClientAuthenticationStrategy clientAuthenticationStrategy; private volatile String grantType; private volatile String username; private volatile String password; @@ -202,6 +217,7 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService httpClient = createHttpClient(context); + clientAuthenticationStrategy = ClientAuthenticationStrategy.valueOf(context.getProperty(CLIENT_AUTHENTICATION_STRATEGY).getValue()); grantType = context.getProperty(GRANT_TYPE).getValue(); username = context.getProperty(USERNAME).evaluateAttributeExpressions().getValue(); password = context.getProperty(PASSWORD).getValue(); @@ -288,7 +304,7 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService acquireTokenBuilder.add("grant_type", "client_credentials"); } - if (clientId != null) { + if (ClientAuthenticationStrategy.REQUEST_BODY == clientAuthenticationStrategy && clientId != null) { acquireTokenBuilder.add("client_id", clientId); acquireTokenBuilder.add("client_secret", clientSecret); } @@ -298,11 +314,15 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService } RequestBody acquireTokenRequestBody = acquireTokenBuilder.build(); - - Request acquireTokenRequest = new Request.Builder() + Request.Builder acquireTokenRequestBuilder = new Request.Builder() .url(authorizationServerUrl) - .post(acquireTokenRequestBody) - .build(); + .post(acquireTokenRequestBody); + + if (ClientAuthenticationStrategy.BASIC_AUTHENTICATION == clientAuthenticationStrategy && clientId != null) { + acquireTokenRequestBuilder.addHeader(AUTHORIZATION_HEADER, Credentials.basic(clientId, clientSecret)); + } + + Request acquireTokenRequest = acquireTokenRequestBuilder.build(); this.accessDetails = getAccessDetails(acquireTokenRequest); } @@ -314,7 +334,7 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService .add("grant_type", "refresh_token") .add("refresh_token", this.accessDetails.getRefreshToken()); - if (clientId != null) { + if (ClientAuthenticationStrategy.REQUEST_BODY == clientAuthenticationStrategy && clientId != null) { refreshTokenBuilder.add("client_id", clientId); refreshTokenBuilder.add("client_secret", clientSecret); } @@ -325,10 +345,15 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService RequestBody refreshTokenRequestBody = refreshTokenBuilder.build(); - Request refreshRequest = new Request.Builder() + Request.Builder refreshRequestBuilder = new Request.Builder() .url(authorizationServerUrl) - .post(refreshTokenRequestBody) - .build(); + .post(refreshTokenRequestBody); + + if (ClientAuthenticationStrategy.BASIC_AUTHENTICATION == clientAuthenticationStrategy && clientId != null) { + refreshRequestBuilder.addHeader(AUTHORIZATION_HEADER, Credentials.basic(clientId, clientSecret)); + } + + Request refreshRequest = refreshRequestBuilder.build(); this.accessDetails = getAccessDetails(refreshRequest); } diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProviderTest.java b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProviderTest.java index 9aa1d5006d..e26fc6bf3f 100644 --- a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProviderTest.java +++ b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProviderTest.java @@ -22,6 +22,8 @@ import okhttp3.Protocol; import okhttp3.Request; import okhttp3.Response; import okhttp3.ResponseBody; +import okio.Buffer; + import org.apache.nifi.components.ConfigVerificationResult; import org.apache.nifi.components.PropertyDescriptor; import org.apache.nifi.controller.ConfigurationContext; @@ -46,7 +48,9 @@ import org.mockito.junit.jupiter.MockitoExtension; import java.io.IOException; import java.io.UncheckedIOException; +import java.nio.charset.Charset; import java.util.Arrays; +import java.util.Base64; import java.util.Collections; import java.util.List; import java.util.Map; @@ -61,6 +65,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -91,6 +96,8 @@ public class StandardOauth2AccessTokenProviderTest { private ArgumentCaptor errorCaptor; @Captor private ArgumentCaptor throwableCaptor; + @Captor + private ArgumentCaptor requestCaptor; @BeforeEach public void setUp() { @@ -113,6 +120,7 @@ public class StandardOauth2AccessTokenProviderTest { when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_ID).evaluateAttributeExpressions().getValue()).thenReturn(CLIENT_ID); when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_SECRET).getValue()).thenReturn(CLIENT_SECRET); when(mockContext.getProperty(StandardOauth2AccessTokenProvider.REFRESH_WINDOW).asTimePeriod(eq(TimeUnit.SECONDS))).thenReturn(FIVE_MINUTES); + when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_AUTHENTICATION_STRATEGY).getValue()).thenReturn(ClientAuthenticationStrategy.BASIC_AUTHENTICATION.getValue()); testSubject.onEnabled(mockContext); } @@ -125,7 +133,7 @@ public class StandardOauth2AccessTokenProviderTest { runner.addControllerService("testSubject", testSubject); - runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, "http://unimportant"); + runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, AUTHORIZATION_SERVER_URL); // WHEN runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.GRANT_TYPE, StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE); @@ -142,12 +150,50 @@ public class StandardOauth2AccessTokenProviderTest { runner.addControllerService("testSubject", testSubject); - runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, "http://unimportant"); + runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, AUTHORIZATION_SERVER_URL); // WHEN runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.GRANT_TYPE, StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE); - runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_ID, "clientId"); - runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_SECRET, "clientSecret"); + runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_ID, CLIENT_ID); + runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_SECRET, CLIENT_SECRET); + + // THEN + runner.assertValid(testSubject); + } + + @Test + public void testInvalidWhenClientAuthenticationStrategyIsInvalid() throws Exception { + // GIVEN + Processor processor = new NoOpProcessor(); + TestRunner runner = TestRunners.newTestRunner(processor); + + runner.addControllerService("testSubject", testSubject); + runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, AUTHORIZATION_SERVER_URL); + runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.GRANT_TYPE, StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE); + runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_ID, CLIENT_ID); + runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_SECRET, CLIENT_SECRET); + + // WHEN + runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_AUTHENTICATION_STRATEGY, "UNKNOWN"); + + // THEN + runner.assertNotValid(testSubject); + } + + @Test + public void testValidWhenClientAuthenticationStrategyIsValid() throws Exception { + // GIVEN + Processor processor = new NoOpProcessor(); + TestRunner runner = TestRunners.newTestRunner(processor); + + runner.addControllerService("testSubject", testSubject); + runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, AUTHORIZATION_SERVER_URL); + runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.GRANT_TYPE, StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE); + runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_ID, CLIENT_ID); + runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_SECRET, CLIENT_SECRET); + + // WHEN + runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_AUTHENTICATION_STRATEGY, ClientAuthenticationStrategy.REQUEST_BODY.getValue()); // THEN runner.assertValid(testSubject); @@ -250,6 +296,42 @@ public class StandardOauth2AccessTokenProviderTest { assertEquals(expectedToken, actualToken); } + @Test + public void testBasicAuthentication() throws Exception { + // GIVEN + Response response = buildResponse(HTTP_OK, "{\"access_token\":\"foobar\"}"); + when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response); + String expected = "Basic " + Base64.getEncoder().withoutPadding().encodeToString((CLIENT_ID + ":" + CLIENT_SECRET).getBytes()); + + // WHEN + testSubject.getAccessDetails(); + + // THEN + verify(mockHttpClient, atLeast(1)).newCall(requestCaptor.capture()); + assertEquals(expected, requestCaptor.getValue().header("Authorization")); + } + + @Test + public void testRequestBodyAuthentication() throws Exception { + when(mockContext.getProperty(StandardOauth2AccessTokenProvider.GRANT_TYPE).getValue()).thenReturn(StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE.getValue()); + when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_AUTHENTICATION_STRATEGY).getValue()).thenReturn(ClientAuthenticationStrategy.REQUEST_BODY.getValue()); + testSubject.onEnabled(mockContext); + + // GIVEN + Response response = buildResponse(HTTP_OK, "{\"access_token\":\"foobar\"}"); + when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response); + String expected = "grant_type=client_credentials&client_id=" + CLIENT_ID + "&client_secret=" + CLIENT_SECRET; + + // WHEN + testSubject.getAccessDetails(); + + // THEN + Buffer buffer = new Buffer(); + verify(mockHttpClient, atLeast(1)).newCall(requestCaptor.capture()); + requestCaptor.getValue().body().writeTo(buffer); + assertEquals(expected, buffer.readString(Charset.defaultCharset())); + } + @Test public void testIOExceptionDuringRefreshAndSubsequentAcquire() throws Exception { // GIVEN