NIFI-11018 Upgraded nifi-web-security to JUnit 5

This closes #6814

Signed-off-by: David Handermann <exceptionfactory@apache.org>
This commit is contained in:
dan-s1 2022-12-30 20:13:29 +00:00 committed by exceptionfactory
parent 1404a151a1
commit 481cdaf3db
No known key found for this signature in database
GPG Key ID: 29B6A52D2AAE8DBA
23 changed files with 338 additions and 468 deletions

View File

@ -17,12 +17,8 @@
package org.apache.nifi.web.filter
import org.junit.After
import org.junit.Before
import org.junit.BeforeClass
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import org.slf4j.Logger
import org.slf4j.LoggerFactory
@ -35,27 +31,16 @@ import javax.servlet.ServletResponse
import javax.servlet.http.HttpServletRequest
import javax.servlet.http.HttpServletResponse
@RunWith(JUnit4.class)
class CatchAllFilterTest extends GroovyTestCase {
class CatchAllFilterTest {
private static final Logger logger = LoggerFactory.getLogger(CatchAllFilterTest.class)
@BeforeClass
@BeforeAll
static void setUpOnce() throws Exception {
logger.metaClass.methodMissing = { String name, args ->
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
}
}
@Before
void setUp() throws Exception {
}
@After
void tearDown() throws Exception {
}
private static String getValue(String parameterName, Map<String, String> params = [:]) {
params.containsKey(parameterName) ? params[parameterName] : ""
}

View File

@ -17,24 +17,20 @@
package org.apache.nifi.web.security
import org.apache.nifi.authorization.user.NiFiUser
import org.junit.After
import org.junit.Before
import org.junit.BeforeClass
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.nio.charset.StandardCharsets
@RunWith(JUnit4.class)
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertNotEquals
class ProxiedEntitiesUtilsTest {
private static final Logger logger = LoggerFactory.getLogger(ProxiedEntitiesUtils.class)
private static final String SAFE_USER_NAME_ANDY = "alopresto"
private static final String SAFE_USER_DN_ANDY = "CN=${SAFE_USER_NAME_ANDY}, OU=Apache NiFi"
private static final String SAFE_USER_NAME_JOHN = "jdoe"
private static final String SAFE_USER_DN_JOHN = "CN=${SAFE_USER_NAME_JOHN}, OU=Apache NiFi"
@ -50,7 +46,6 @@ class ProxiedEntitiesUtilsTest {
private static
final String MALICIOUS_USER_NAME_JOHN_ESCAPED = sanitizeDn(MALICIOUS_USER_NAME_JOHN)
private static final String MALICIOUS_USER_DN_JOHN_ESCAPED = sanitizeDn(MALICIOUS_USER_DN_JOHN)
private static final String UNICODE_DN_1 = "CN=Алйс, OU=Apache NiFi"
private static final String UNICODE_DN_1_ENCODED = "<" + base64Encode(UNICODE_DN_1) + ">"
@ -58,21 +53,13 @@ class ProxiedEntitiesUtilsTest {
private static final String UNICODE_DN_2 = "CN=Боб, OU=Apache NiFi"
private static final String UNICODE_DN_2_ENCODED = "<" + base64Encode(UNICODE_DN_2) + ">"
@BeforeClass
@BeforeAll
static void setUpOnce() throws Exception {
logger.metaClass.methodMissing = { String name, args ->
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
}
}
@Before
void setUp() {
}
@After
void tearDown() {
}
private static String sanitizeDn(String dn = "") {
dn.replaceAll(/>/, '\\\\>').replaceAll('<', '\\\\<')
}
@ -82,7 +69,7 @@ class ProxiedEntitiesUtilsTest {
}
private static String printUnicodeString(final String raw) {
StringBuilder sb = new StringBuilder();
StringBuilder sb = new StringBuilder()
for (int i = 0; i < raw.size(); i++) {
int codePoint = Character.codePointAt(raw, i)
int charCount = Character.charCount(codePoint)
@ -120,7 +107,7 @@ class ProxiedEntitiesUtilsTest {
logger.info("Sanitized name: ${sanitizedName} | ${printUnicodeString(sanitizedName)}")
// Assert
assert sanitizedName != DESIRED_NAME
assertNotEquals(DESIRED_NAME, sanitizedName)
}
}
@ -138,7 +125,7 @@ class ProxiedEntitiesUtilsTest {
logger.info("Forjohned proxy DN: ${forjohnedProxyDn}")
// Assert
assert forjohnedProxyDn == EXPECTED_PROXY_DN
assertEquals(EXPECTED_PROXY_DN, forjohnedProxyDn)
}
@Test
@ -156,7 +143,7 @@ class ProxiedEntitiesUtilsTest {
logger.info("Forjohned proxy DN: ${forjohnedProxyDn}")
// Assert
assert forjohnedProxyDn == EXPECTED_PROXY_DN
assertEquals(EXPECTED_PROXY_DN, forjohnedProxyDn)
}
@Test
@ -169,7 +156,7 @@ class ProxiedEntitiesUtilsTest {
def output = ProxiedEntitiesUtils.getProxiedEntitiesChain(input)
// Assert
assert output == expectedOutput
assertEquals(expectedOutput, output)
}
@Test
@ -182,7 +169,7 @@ class ProxiedEntitiesUtilsTest {
def output = ProxiedEntitiesUtils.getProxiedEntitiesChain(input)
// Assert
assert output == expectedOutput
assertEquals(expectedOutput, output)
}
@Test
@ -195,7 +182,7 @@ class ProxiedEntitiesUtilsTest {
def output = ProxiedEntitiesUtils.getProxiedEntitiesChain(input)
// Assert
assert output == expectedOutput
assertEquals(expectedOutput, output)
}
@Test
@ -210,7 +197,7 @@ class ProxiedEntitiesUtilsTest {
logger.info("Formatted DN: ${formattedDn}")
// Assert
assert formattedDn == expectedFormattedDn
assertEquals(expectedFormattedDn, formattedDn)
}
@Test
@ -224,7 +211,7 @@ class ProxiedEntitiesUtilsTest {
logger.info("Proxied entities chain: ${proxiedEntitiesChain}")
// Assert
assert proxiedEntitiesChain == "<${SAFE_USER_NAME_JOHN}><${SAFE_USER_NAME_PROXY_1}>" as String
assertEquals("<${SAFE_USER_NAME_JOHN}><${SAFE_USER_NAME_PROXY_1}>" as String, proxiedEntitiesChain)
}
@Test
@ -236,7 +223,7 @@ class ProxiedEntitiesUtilsTest {
logger.info("Proxied entities chain: ${proxiedEntitiesChain}")
// Assert
assert proxiedEntitiesChain == "<>"
assertEquals("<>", proxiedEntitiesChain)
}
@Test
@ -250,7 +237,7 @@ class ProxiedEntitiesUtilsTest {
logger.info("Proxied entities chain: ${proxiedEntitiesChain}")
// Assert
assert proxiedEntitiesChain == "<><${SAFE_USER_NAME_PROXY_1}>" as String
assertEquals("<><${SAFE_USER_NAME_PROXY_1}>" as String, proxiedEntitiesChain)
}
@Test
@ -264,7 +251,7 @@ class ProxiedEntitiesUtilsTest {
logger.info("Proxied entities chain: ${proxiedEntitiesChain}")
// Assert
assert proxiedEntitiesChain == "<${SAFE_USER_NAME_JOHN}><${UNICODE_DN_1_ENCODED}>" as String
assertEquals("<${SAFE_USER_NAME_JOHN}><${UNICODE_DN_1_ENCODED}>" as String, proxiedEntitiesChain)
}
@Test
@ -278,7 +265,7 @@ class ProxiedEntitiesUtilsTest {
logger.info("Proxied entities chain: ${proxiedEntitiesChain}")
// Assert
assert proxiedEntitiesChain == "<${MALICIOUS_USER_NAME_JOHN_ESCAPED}><${SAFE_USER_NAME_PROXY_1}>" as String
assertEquals("<${MALICIOUS_USER_NAME_JOHN_ESCAPED}><${SAFE_USER_NAME_PROXY_1}>" as String, proxiedEntitiesChain)
}
@Test
@ -293,7 +280,7 @@ class ProxiedEntitiesUtilsTest {
logger.info("Tokenized proxy chain: ${tokenizedNames}")
// Assert
assert tokenizedNames == NAMES
assertEquals(NAMES, tokenizedNames)
}
@Test
@ -308,7 +295,7 @@ class ProxiedEntitiesUtilsTest {
logger.info("Tokenized proxy chain: ${tokenizedNames}")
// Assert
assert tokenizedNames == NAMES
assertEquals(NAMES, tokenizedNames)
}
@Test
@ -323,7 +310,7 @@ class ProxiedEntitiesUtilsTest {
logger.info("Tokenized proxy chain: ${tokenizedNames}")
// Assert
assert tokenizedNames == NAMES
assertEquals(NAMES, tokenizedNames)
}
@Test
@ -338,7 +325,7 @@ class ProxiedEntitiesUtilsTest {
logger.info("Tokenized proxy chain: ${tokenizedNames}")
// Assert
assert tokenizedNames == NAMES
assertEquals(NAMES, tokenizedNames)
}
@Test
@ -353,7 +340,7 @@ class ProxiedEntitiesUtilsTest {
logger.info("Tokenized proxy chain: ${tokenizedDns.collect { "\"${it}\"" }}")
// Assert
assert tokenizedDns == DNS
assertEquals(DNS, tokenizedDns)
}
@Test
@ -368,7 +355,7 @@ class ProxiedEntitiesUtilsTest {
logger.info("Tokenized proxy chain: ${tokenizedNames}")
// Assert
assert tokenizedNames == NAMES
assertEquals(NAMES, tokenizedNames)
}
@Test
@ -383,9 +370,9 @@ class ProxiedEntitiesUtilsTest {
logger.info("Tokenized proxy chain: ${tokenizedNames.collect { "\"${it}\"" }}")
// Assert
assert tokenizedNames == NAMES
assert tokenizedNames.size() == NAMES.size()
assert !tokenizedNames.contains(SAFE_USER_NAME_JOHN)
assertEquals(NAMES, tokenizedNames)
assertEquals(NAMES.size(), tokenizedNames.size())
assertFalse(tokenizedNames.contains(SAFE_USER_NAME_JOHN))
}
@Test
@ -400,7 +387,7 @@ class ProxiedEntitiesUtilsTest {
logger.info("Tokenized proxy chain: ${tokenizedNames.collect { "\"${it}\"" }}")
// Assert
assert tokenizedNames == TOKENIZED_NAMES
assert tokenizedNames.size() == TOKENIZED_NAMES.size()
assertEquals(TOKENIZED_NAMES, tokenizedNames)
assertEquals(TOKENIZED_NAMES.size(), tokenizedNames.size())
}
}

View File

@ -16,25 +16,23 @@
*/
package org.apache.nifi.web.security.oidc
import com.nimbusds.oauth2.sdk.auth.ClientAuthenticationMethod
import com.nimbusds.oauth2.sdk.id.Issuer
import com.nimbusds.openid.connect.sdk.SubjectType
import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata
import org.apache.nifi.util.NiFiProperties
import org.junit.After
import org.junit.Before
import org.junit.BeforeClass
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.util.concurrent.TimeUnit
@RunWith(JUnit4.class)
class OidcServiceGroovyTest extends GroovyTestCase {
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertNull
class OidcServiceGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(OidcServiceGroovyTest.class)
private static final Map<String, Object> DEFAULT_NIFI_PROPERTIES = [
@ -61,23 +59,19 @@ class OidcServiceGroovyTest extends GroovyTestCase {
"CIsImVtYWlsIjoib2lkY190ZXN0QG5pZmkuYXBhY2hlLm9yZyJ9" +
".b4NIl0RONKdVLOH0D1eObdwAEX8qX-ExqB8KuKSZFLw"
@BeforeClass
@BeforeAll
static void setUpOnce() throws Exception {
logger.metaClass.methodMissing = { String name, args ->
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
}
}
@Before
@BeforeEach
void setUp() throws Exception {
mockNiFiProperties = buildNiFiProperties()
soip = new StandardOidcIdentityProvider(mockNiFiProperties)
}
@After
void teardown() throws Exception {
}
private static NiFiProperties buildNiFiProperties(Map<String, Object> props = [:]) {
def combinedProps = DEFAULT_NIFI_PROPERTIES + props
new NiFiProperties(combinedProps)
@ -100,7 +94,7 @@ class OidcServiceGroovyTest extends GroovyTestCase {
final String cachedJwt = service.getJwt(MOCK_REQUEST_IDENTIFIER)
logger.info("Cached JWT: ${cachedJwt}")
assert cachedJwt == MOCK_JWT
assertEquals(MOCK_JWT, cachedJwt)
}
@Test
@ -121,7 +115,7 @@ class OidcServiceGroovyTest extends GroovyTestCase {
logger.info("Retrieved JWT: ${retrievedJwt}")
// Assert
assert retrievedJwt == MOCK_JWT
assertEquals(MOCK_JWT, retrievedJwt)
}
@Test
@ -149,7 +143,7 @@ class OidcServiceGroovyTest extends GroovyTestCase {
logger.info("Retrieved JWT: ${retrievedJwt}")
// Assert
assert retrievedJwt == null
assertNull(retrievedJwt)
}
private static StandardOidcIdentityProvider buildIdentityProviderWithMockInitializedProvider(Map<String, String> additionalProperties = [:]) {

View File

@ -47,17 +47,19 @@ import net.minidev.json.JSONObject
import org.apache.commons.codec.binary.Base64
import org.apache.nifi.util.NiFiProperties
import org.apache.nifi.util.StringUtils
import org.junit.After
import org.junit.Before
import org.junit.BeforeClass
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.slf4j.Logger
import org.slf4j.LoggerFactory
@RunWith(JUnit4.class)
class StandardOidcIdentityProviderGroovyTest extends GroovyTestCase {
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertNotNull
import static org.junit.jupiter.api.Assertions.assertNull
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class StandardOidcIdentityProviderGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(StandardOidcIdentityProviderGroovyTest.class)
private static final Map<String, Object> DEFAULT_NIFI_PROPERTIES = [
@ -80,22 +82,18 @@ class StandardOidcIdentityProviderGroovyTest extends GroovyTestCase {
"X3VuaXRfdGVzdF9hdXRob3JpdHkiLCJhdWQiOiJhbGwiLCJ1c2VybmFtZSI6Im9pZGNfdGVzdCIsImVtYWlsIjoib2lkY19" +
"0ZXN0QG5pZmkuYXBhY2hlLm9yZyJ9.b4NIl0RONKdVLOH0D1eObdwAEX8qX-ExqB8KuKSZFLw"
@BeforeClass
@BeforeAll
static void setUpOnce() throws Exception {
logger.metaClass.methodMissing = { String name, args ->
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
}
}
@Before
@BeforeEach
void setUp() throws Exception {
mockNiFiProperties = buildNiFiProperties()
}
@After
void teardown() throws Exception {
}
private static NiFiProperties buildNiFiProperties(Map<String, Object> props = [:]) {
def combinedProps = DEFAULT_NIFI_PROPERTIES + props
new NiFiProperties(combinedProps)
@ -126,7 +124,7 @@ class StandardOidcIdentityProviderGroovyTest extends GroovyTestCase {
logger.info("Defined claims: ${definedClaims}")
// Assert
assert definedClaims == POPULATED_CLAIM_NAMES
assertEquals(POPULATED_CLAIM_NAMES, definedClaims)
}
@Test
@ -160,9 +158,9 @@ class StandardOidcIdentityProviderGroovyTest extends GroovyTestCase {
logger.info("Client Auth properties: ${clientAuthentication.getProperties()}")
// Assert
assert clientAuthentication.getClientID() == EXPECTED_CLIENT_AUTHENTICATION.getClientID()
assertEquals(EXPECTED_CLIENT_AUTHENTICATION.getClientID(), clientAuthentication.getClientID())
logger.info("Client secret: ${(clientAuthentication as ClientSecretPost).clientSecret.value}")
assert ((ClientSecretPost) clientAuthentication).getClientSecret() == ((ClientSecretPost) EXPECTED_CLIENT_AUTHENTICATION).getClientSecret()
assertEquals(((ClientSecretPost) EXPECTED_CLIENT_AUTHENTICATION).getClientSecret(), ((ClientSecretPost) clientAuthentication).getClientSecret())
}
@Test
@ -196,10 +194,10 @@ class StandardOidcIdentityProviderGroovyTest extends GroovyTestCase {
logger.info("Client authentication properties: ${clientAuthentication.properties}")
// Assert
assert clientAuthentication.getClientID() == EXPECTED_CLIENT_AUTHENTICATION.getClientID()
assert clientAuthentication.getMethod() == EXPECTED_CLIENT_AUTHENTICATION.getMethod()
assertEquals(EXPECTED_CLIENT_AUTHENTICATION.getClientID(), clientAuthentication.getClientID())
assertEquals(EXPECTED_CLIENT_AUTHENTICATION.getMethod(), clientAuthentication.getMethod())
logger.info("Client secret: ${(clientAuthentication as ClientSecretBasic).clientSecret.value}")
assert (clientAuthentication as ClientSecretBasic).getClientSecret() == EXPECTED_CLIENT_AUTHENTICATION.clientSecret
assertEquals(EXPECTED_CLIENT_AUTHENTICATION.clientSecret, (clientAuthentication as ClientSecretBasic).getClientSecret())
}
@Test
@ -235,10 +233,10 @@ class StandardOidcIdentityProviderGroovyTest extends GroovyTestCase {
logger.info("Query: ${URLDecoder.decode(httpRequest.query, "UTF-8")}")
// Assert
assert httpRequest.getMethod().name() == "POST"
assert httpRequest.query =~ "code=${mockCode.value}"
assertEquals("POST", httpRequest.getMethod().name())
assertTrue(httpRequest.query.contains("code=" + mockCode.value))
String encodedUri = URLEncoder.encode("https://localhost/oidc", "UTF-8")
assert httpRequest.query =~ "redirect_uri=${encodedUri}&grant_type=authorization_code"
assertTrue(httpRequest.query.contains("redirect_uri="+encodedUri+"&grant_type=authorization_code"))
}
@Test
@ -262,7 +260,7 @@ class StandardOidcIdentityProviderGroovyTest extends GroovyTestCase {
logger.info("Identity: ${identity}")
// Assert
assert identity == EXPECTED_IDENTITY
assertEquals(EXPECTED_IDENTITY, identity)
}
@Test
@ -280,14 +278,8 @@ class StandardOidcIdentityProviderGroovyTest extends GroovyTestCase {
HTTPRequest mockUserInfoRequest = mockHttpRequest(responseBody, 200, "HTTP NO USER")
// Act
def msg = shouldFail(IllegalStateException) {
String identity = soip.lookupIdentityInUserInfo(mockUserInfoRequest)
logger.info("Identity: ${identity}")
}
logger.expected(msg)
// Assert
assert msg =~ "Unable to extract identity from the UserInfo token using the claim 'username'."
IllegalStateException ise = assertThrows(IllegalStateException.class, () -> soip.lookupIdentityInUserInfo(mockUserInfoRequest))
assertTrue(ise.getMessage().contains("Unable to extract identity from the UserInfo token using the claim 'username'."))
}
@Test
@ -306,16 +298,9 @@ class StandardOidcIdentityProviderGroovyTest extends GroovyTestCase {
error_uri : "https://localhost/oidc/error"]
HTTPRequest mockUserInfoRequest = mockHttpRequest(errorBody, 500, "HTTP ERROR")
// Act
def msg = shouldFail(RuntimeException) {
String identity = soip.lookupIdentityInUserInfo(mockUserInfoRequest)
logger.info("Identity: ${identity}")
}
logger.expected(msg)
// Assert
assert msg =~ "An error occurred while invoking the UserInfo endpoint: The provided username and password " +
"were not correct"
RuntimeException re = assertThrows(RuntimeException.class, () -> soip.lookupIdentityInUserInfo(mockUserInfoRequest))
assertTrue(re.getMessage().contains("An error occurred while invoking the UserInfo endpoint: The provided username and password " +
"were not correct"))
}
@Test
@ -335,12 +320,12 @@ class StandardOidcIdentityProviderGroovyTest extends GroovyTestCase {
def (String contents, String expiration) = loginToken.tokenize("\\[\\]")
logger.info("Token contents: ${contents} | Expiration: ${expiration}")
assert contents =~ "LoginAuthenticationToken for person@nifi\\.apache\\.org issued by https://accounts\\.issuer\\.com expiring at"
assertTrue(contents.contains("LoginAuthenticationToken for person@nifi.apache.org issued by https://accounts.issuer.com expiring at"))
// Assert expiration
final String[] expList = expiration.split("\\s")
final Long expLong = Long.parseLong(expList[0])
assert expLong <= System.currentTimeMillis() + 10_000
assertTrue(expLong <= System.currentTimeMillis() + 10_000)
}
@Test
@ -360,14 +345,14 @@ class StandardOidcIdentityProviderGroovyTest extends GroovyTestCase {
def (String contents, String expiration) = loginToken.tokenize("\\[\\]")
logger.info("Token contents: ${contents} | Expiration: ${expiration}")
assert contents =~ "LoginAuthenticationToken for person@nifi\\.apache\\.org issued by https://accounts\\.issuer\\.com expiring at"
assertTrue(contents.contains("LoginAuthenticationToken for person@nifi.apache.org issued by https://accounts.issuer.com expiring at"))
// Get the expiration
final ArrayList<String> expires = expiration.split("[\\D*]")
final long exp = Long.parseLong(expires[0])
logger.info("exp: ${exp} ms")
assert exp <= System.currentTimeMillis() + 10_000
assertTrue(exp <= System.currentTimeMillis() + 10_000)
}
@Test
@ -376,7 +361,7 @@ class StandardOidcIdentityProviderGroovyTest extends GroovyTestCase {
StandardOidcIdentityProvider soip = buildIdentityProviderWithMockTokenValidator(
["nifi.security.user.oidc.claim.identifying.user": "email",
"nifi.security.user.oidc.fallback.claims.identifying.user": "upn" ])
String expectedUpn = "xxx@aaddomain";
String expectedUpn = "xxx@aaddomain"
OIDCTokenResponse mockResponse = mockOIDCTokenResponse(["email": null, "upn": expectedUpn])
logger.info("OIDC Token Response with no email and upn: ${mockResponse.dump()}")
@ -387,7 +372,22 @@ class StandardOidcIdentityProviderGroovyTest extends GroovyTestCase {
// Split JWT into components and decode Base64 to JSON
def (String contents, String expiration) = loginToken.tokenize("\\[\\]")
logger.info("Token contents: ${contents} | Expiration: ${expiration}")
assert contents =~ "LoginAuthenticationToken for ${expectedUpn} issued by https://accounts\\.issuer\\.com expiring at"
assertTrue(contents.contains("LoginAuthenticationToken for " + expectedUpn + " issued by https://accounts.issuer.com expiring at"))
}
@Test
void testAuthorizeClientRequestShouldHandleError() {
// Arrange
// Build ID Provider with mock token endpoint URI to make a connection
StandardOidcIdentityProvider soip = buildIdentityProviderWithMockTokenValidator([:])
def responseBody = [id_token: MOCK_JWT, access_token: "some.access.token", refresh_token: "some.refresh.token", token_type: "bearer"]
HTTPRequest mockTokenRequest = mockHttpRequest(responseBody, 500, "HTTP SERVER ERROR")
// Act
RuntimeException re = assertThrows(RuntimeException.class, () -> soip.authorizeClientRequest(mockTokenRequest))
// Assert
assertTrue(re.getMessage().contains("An error occurred while invoking the Token endpoint: null"))
}
@Test
@ -399,11 +399,7 @@ class StandardOidcIdentityProviderGroovyTest extends GroovyTestCase {
logger.info("OIDC Token Response: ${mockResponse.dump()}")
// Act
def msg = shouldFail(IOException) {
String loginAuthenticationToken = soip.convertOIDCTokenToLoginAuthenticationToken(mockResponse)
logger.info("Login authentication token: ${loginAuthenticationToken}")
}
logger.expected(msg)
assertThrows(IOException.class, () -> soip.convertOIDCTokenToLoginAuthenticationToken(mockResponse))
}
@Test
@ -420,27 +416,7 @@ class StandardOidcIdentityProviderGroovyTest extends GroovyTestCase {
logger.info("Token Response: ${tokenResponse.dump()}")
// Assert
assert tokenResponse
}
@Test
void testAuthorizeClientRequestShouldHandleError() {
// Arrange
// Build ID Provider with mock token endpoint URI to make a connection
StandardOidcIdentityProvider soip = buildIdentityProviderWithMockTokenValidator([:])
def responseBody = [id_token: MOCK_JWT, access_token: "some.access.token", refresh_token: "some.refresh.token", token_type: "bearer"]
HTTPRequest mockTokenRequest = mockHttpRequest(responseBody, 500, "HTTP SERVER ERROR")
// Act
def msg = shouldFail(RuntimeException) {
def nifiToken = soip.authorizeClientRequest(mockTokenRequest)
logger.info("NiFi token: ${nifiToken}")
}
logger.expected(msg)
// Assert
assert msg =~ "An error occurred while invoking the Token endpoint: null"
assertNotNull(tokenResponse)
}
@Test
@ -477,7 +453,7 @@ class StandardOidcIdentityProviderGroovyTest extends GroovyTestCase {
logger.info("Access token: ${accessTokenString}")
// Assert
assert accessTokenString
assertNotNull(accessTokenString)
}
@Test
@ -513,7 +489,7 @@ class StandardOidcIdentityProviderGroovyTest extends GroovyTestCase {
logger.info("Access Token: ${accessTokenString}")
// Assert
assert accessTokenString == null
assertNull(accessTokenString)
}
@Test
@ -543,13 +519,9 @@ class StandardOidcIdentityProviderGroovyTest extends GroovyTestCase {
OIDCTokens mockOidcTokens = new OIDCTokens(mockJwt, mockAccessToken, mockRefreshToken)
// Act
def msg = shouldFail(Exception) {
soip.validateAccessToken(mockOidcTokens)
}
logger.expected(msg)
Exception e = assertThrows(Exception.class, () -> soip.validateAccessToken(mockOidcTokens))
// Assert
assert msg =~ "Unable to validate the Access Token: Access token hash \\(at_hash\\) mismatch"
assertTrue(e.getMessage().contains("Unable to validate the Access Token: Access token hash (at_hash) mismatch"))
}
@Test
@ -584,8 +556,8 @@ class StandardOidcIdentityProviderGroovyTest extends GroovyTestCase {
logger.info("ID Token: ${idTokenString}")
// Assert
assert idTokenString
assert idTokenString == EXPECTED_ID_TOKEN
assertNotNull(idTokenString)
assertEquals(EXPECTED_ID_TOKEN, idTokenString)
// Assert that 'email:person@nifi.apache.org' is present
def (String header, String payload) = idTokenString.split("\\.")
@ -593,7 +565,7 @@ class StandardOidcIdentityProviderGroovyTest extends GroovyTestCase {
final String payloadString = new String(idTokenBytes, "UTF-8")
logger.info(payloadString)
assert payloadString =~ "\"email\":\"person@nifi\\.apache\\.org\""
assertTrue(payloadString.contains("\"email\":\"person@nifi.apache.org\""))
}
@Test
@ -616,8 +588,8 @@ class StandardOidcIdentityProviderGroovyTest extends GroovyTestCase {
logger.info("ID Token Claims Set: ${claimsSetString}")
// Assert
assert claimsSet
assert claimsSetString =~ "\"email\":\"person@nifi\\.apache\\.org\""
assertNotNull(claimsSet)
assertTrue(claimsSetString.contains("\"email\":\"person@nifi.apache.org\""))
}
@Test
@ -652,13 +624,13 @@ class StandardOidcIdentityProviderGroovyTest extends GroovyTestCase {
logger.info("OIDC Tokens: ${oidcTokens.toJSONObject()}")
// Assert
assert oidcTokens
assertNotNull(oidcTokens)
// Assert ID Tokens match
final JSONObject oidcJson = oidcTokens.toJSONObject()
final String idToken = oidcJson["id_token"]
logger.info("ID Token String: ${idToken}")
assert idToken == EXPECTED_ID_TOKEN
assertEquals(EXPECTED_ID_TOKEN, idToken)
}
private StandardOidcIdentityProvider buildIdentityProviderWithMockTokenValidator(Map<String, String> additionalProperties = [:]) {

View File

@ -23,13 +23,10 @@ import org.eclipse.jetty.server.Server
import org.eclipse.jetty.servlet.FilterHolder
import org.eclipse.jetty.servlet.ServletContextHandler
import org.eclipse.jetty.servlet.ServletHolder
import org.junit.After
import org.junit.Assert
import org.junit.Before
import org.junit.BeforeClass
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.slf4j.Logger
import org.slf4j.LoggerFactory
@ -41,8 +38,10 @@ import javax.servlet.http.HttpServletRequest
import javax.servlet.http.HttpServletResponse
import java.util.concurrent.TimeUnit
@RunWith(JUnit4.class)
class ContentLengthFilterTest extends GroovyTestCase {
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertTrue
class ContentLengthFilterTest {
private static final Logger logger = LoggerFactory.getLogger(ContentLengthFilterTest.class)
private static final int MAX_CONTENT_LENGTH = 1000
@ -70,12 +69,12 @@ class ContentLengthFilterTest extends GroovyTestCase {
}
}
@Before
@BeforeEach
void setUp() {
createSimpleReadServer()
}
@After
@AfterEach
void tearDown() {
stopServer()
}
@ -135,11 +134,9 @@ class ContentLengthFilterTest extends GroovyTestCase {
@Test
void testRequestsWithMissingContentLengthHeader() throws Exception {
createSimpleReadServer()
// This shows that the ContentLengthFilter allows a request that does not have a content-length header.
String response = localConnector.getResponse("POST / HTTP/1.0\r\n\r\n")
Assert.assertFalse(StringUtils.containsIgnoreCase(response, "411 Length Required"))
assertFalse(StringUtils.containsIgnoreCase(response, "411 Length Required"))
}
/**
@ -149,7 +146,6 @@ class ContentLengthFilterTest extends GroovyTestCase {
@Test
void testShouldRejectRequestWithLongContentLengthHeader() throws Exception {
// Arrange
createSimpleReadServer()
final String requestBody = String.format(POST_REQUEST, LARGE_CLAIM_SIZE_BYTES, LARGE_PAYLOAD)
logger.info("Making request with CL: ${LARGE_CLAIM_SIZE_BYTES} and actual length: ${LARGE_PAYLOAD.length()}")
@ -158,7 +154,7 @@ class ContentLengthFilterTest extends GroovyTestCase {
logResponse(response)
// Assert
assert response =~ "413 Payload Too Large"
assertTrue(response.contains("413 Payload Too Large"))
}
/**
@ -168,8 +164,6 @@ class ContentLengthFilterTest extends GroovyTestCase {
@Test
void testShouldRejectRequestWithLongContentLengthHeaderAndSmallPayload() throws Exception {
// Arrange
createSimpleReadServer()
String incompletePayload = "1" * (SMALL_CLAIM_SIZE_BYTES / 2)
final String requestBody = String.format(POST_REQUEST, LARGE_CLAIM_SIZE_BYTES, incompletePayload)
logger.info("Making request with CL: ${LARGE_CLAIM_SIZE_BYTES} and actual length: ${incompletePayload.length()}")
@ -179,7 +173,7 @@ class ContentLengthFilterTest extends GroovyTestCase {
logResponse(response)
// Assert
assert response =~ "413 Payload Too Large"
assertTrue(response.contains("413 Payload Too Large"))
}
/**
@ -190,7 +184,6 @@ class ContentLengthFilterTest extends GroovyTestCase {
@Test
void testShouldRejectRequestWithSmallContentLengthHeaderAndLargePayload() throws Exception {
// Arrange
createSimpleReadServer()
final String requestBody = String.format(POST_REQUEST, SMALL_CLAIM_SIZE_BYTES, LARGE_PAYLOAD)
logger.info("Making request with CL: ${SMALL_CLAIM_SIZE_BYTES} and actual length: ${LARGE_PAYLOAD.length()}")
@ -199,10 +192,9 @@ class ContentLengthFilterTest extends GroovyTestCase {
logResponse(response)
// Assert
assert response =~ "200"
assert response =~ "Bytes-Read: ${SMALL_CLAIM_SIZE_BYTES}"
assert response =~ "Read ${SMALL_CLAIM_SIZE_BYTES} bytes"
assertTrue(response.contains("200"))
assertTrue(response.contains("Bytes-Read: " + SMALL_CLAIM_SIZE_BYTES))
assertTrue(response.contains("Read " + SMALL_CLAIM_SIZE_BYTES + " bytes"))
}
/**
@ -212,8 +204,6 @@ class ContentLengthFilterTest extends GroovyTestCase {
@Test
void testShouldTimeoutRequestWithSmallContentLengthHeaderAndSmallerPayload() throws Exception {
// Arrange
createSimpleReadServer()
String smallerPayload = SMALL_PAYLOAD[0..(SMALL_PAYLOAD.length() / 2)]
final String requestBody = String.format(POST_REQUEST, SMALL_CLAIM_SIZE_BYTES, smallerPayload)
logger.info("Making request with CL: ${SMALL_CLAIM_SIZE_BYTES} and actual length: ${smallerPayload.length()}")
@ -223,15 +213,13 @@ class ContentLengthFilterTest extends GroovyTestCase {
logResponse(response)
// Assert
assert response =~ "500 Server Error"
assert response =~ "Timeout"
assertTrue(response.contains("500 Server Error"))
assertTrue(response.contains("Timeout"))
}
@Test
void testFilterShouldAllowSiteToSiteTransfer() throws Exception {
// Arrange
createSimpleReadServer()
final String SITE_TO_SITE_POST_REQUEST = "POST /nifi-api/data-transfer/input-ports HTTP/1.1\r\nContent-Length: %d\r\nHost: h\r\n\r\n%s"
final String siteToSiteRequest = String.format(SITE_TO_SITE_POST_REQUEST, LARGE_CLAIM_SIZE_BYTES, LARGE_PAYLOAD)
@ -242,7 +230,7 @@ class ContentLengthFilterTest extends GroovyTestCase {
logResponse(response)
// Assert
assert response =~ "200 OK"
assertTrue(response.contains("200 OK"))
}
@Test
@ -287,11 +275,11 @@ class ContentLengthFilterTest extends GroovyTestCase {
String form = "a=" + "1" * FORM_CONTENT_SIZE
String response = localConnector.getResponse(String.format(FORM_REQUEST, form.length(), form))
logResponse(response)
assert response =~ "413 Payload Too Large"
assertTrue(response.contains("413 Payload Too Large"))
// But it does not catch requests like this:
response = localConnector.getResponse(String.format(POST_REQUEST, form.length(), form + form))
assert response =~ "417 Read Too Many Bytes"
assertTrue(response.contains("417 Read Too Many Bytes"))
}
}

View File

@ -19,8 +19,7 @@ package org.apache.nifi.web.security;
import org.apache.nifi.authorization.Authorizer;
import org.apache.nifi.authorization.util.IdentityMapping;
import org.apache.nifi.util.NiFiProperties;
import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;
import org.junit.jupiter.api.Test;
import org.mockito.stubbing.Answer;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
@ -28,7 +27,7 @@ import org.springframework.security.core.AuthenticationException;
import java.util.List;
import java.util.Properties;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ -173,12 +172,7 @@ public class NiFiAuthenticationProviderTest {
final NiFiProperties nifiProperties = mock(NiFiProperties.class);
when(nifiProperties.getPropertyKeys()).thenReturn(properties.stringPropertyNames());
when(nifiProperties.getProperty(anyString())).then(new Answer<String>() {
@Override
public String answer(InvocationOnMock invocationOnMock) throws Throwable {
return properties.getProperty((String)invocationOnMock.getArguments()[0]);
}
});
when(nifiProperties.getProperty(anyString())).then((Answer<String>) invocationOnMock -> properties.getProperty((String)invocationOnMock.getArguments()[0]));
return nifiProperties;
}

View File

@ -16,16 +16,16 @@
*/
package org.apache.nifi.web.security;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.Set;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestProxiedEntitiesUtils {

View File

@ -22,21 +22,18 @@ import org.apache.nifi.util.NiFiProperties;
import org.apache.nifi.util.StringUtils;
import org.apache.nifi.web.security.InvalidAuthenticationException;
import org.apache.nifi.web.security.token.NiFiAuthenticationToken;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class NiFiAnonymousAuthenticationProviderTest {
private static final Logger logger = LoggerFactory.getLogger(NiFiAnonymousAuthenticationProviderTest.class);
@Test
public void testAnonymousDisabledNotSecure() throws Exception {
public void testAnonymousDisabledNotSecure() {
final NiFiProperties nifiProperties = Mockito.mock(NiFiProperties.class);
when(nifiProperties.isAnonymousAuthenticationAllowed()).thenReturn(false);
@ -50,7 +47,7 @@ public class NiFiAnonymousAuthenticationProviderTest {
}
@Test
public void testAnonymousEnabledNotSecure() throws Exception {
public void testAnonymousEnabledNotSecure() {
final NiFiProperties nifiProperties = Mockito.mock(NiFiProperties.class);
when(nifiProperties.isAnonymousAuthenticationAllowed()).thenReturn(true);
@ -63,8 +60,8 @@ public class NiFiAnonymousAuthenticationProviderTest {
assertTrue(userDetails.getNiFiUser().isAnonymous());
}
@Test(expected = InvalidAuthenticationException.class)
public void testAnonymousDisabledSecure() throws Exception {
@Test
public void testAnonymousDisabledSecure() {
final NiFiProperties nifiProperties = Mockito.mock(NiFiProperties.class);
when(nifiProperties.isAnonymousAuthenticationAllowed()).thenReturn(false);
@ -72,11 +69,11 @@ public class NiFiAnonymousAuthenticationProviderTest {
final NiFiAnonymousAuthenticationRequestToken authenticationRequest = new NiFiAnonymousAuthenticationRequestToken(true, StringUtils.EMPTY);
anonymousAuthenticationProvider.authenticate(authenticationRequest);
assertThrows(InvalidAuthenticationException.class, () -> anonymousAuthenticationProvider.authenticate(authenticationRequest));
}
@Test
public void testAnonymousEnabledSecure() throws Exception {
public void testAnonymousEnabledSecure() {
final NiFiProperties nifiProperties = Mockito.mock(NiFiProperties.class);
when(nifiProperties.isAnonymousAuthenticationAllowed()).thenReturn(true);

View File

@ -30,11 +30,11 @@ import org.apache.nifi.idp.IdpUserGroup;
import org.apache.nifi.util.NiFiProperties;
import org.apache.nifi.util.StringUtils;
import org.apache.nifi.web.security.token.NiFiAuthenticationToken;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.security.oauth2.jwt.Jwt;
import java.util.Collections;
@ -42,12 +42,12 @@ import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
@ExtendWith(MockitoExtension.class)
public class StandardJwtAuthenticationConverterTest {
private static final String USERNAME = "NiFi";
@ -76,7 +76,7 @@ public class StandardJwtAuthenticationConverterTest {
private StandardJwtAuthenticationConverter converter;
@Before
@BeforeEach
public void setConverter() {
final Map<String, String> properties = new HashMap<>();
final NiFiProperties niFiProperties = NiFiProperties.createBasicNiFiProperties(StringUtils.EMPTY, properties);

View File

@ -16,22 +16,22 @@
*/
package org.apache.nifi.web.security.jwt.jws;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.junit.jupiter.MockitoExtension;
import java.time.Instant;
import java.util.UUID;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
@ExtendWith(MockitoExtension.class)
public class StandardJwsSignerProviderTest {
private static final String KEY_IDENTIFIER = UUID.randomUUID().toString();
@ -49,7 +49,7 @@ public class StandardJwsSignerProviderTest {
private StandardJwsSignerProvider provider;
@Before
@BeforeEach
public void setProvider() {
provider = new StandardJwsSignerProvider(signingKeyListener);
when(jwsSignerContainer.getKeyIdentifier()).thenReturn(KEY_IDENTIFIER);
@ -61,7 +61,7 @@ public class StandardJwsSignerProviderTest {
final Instant expiration = Instant.now();
final JwsSignerContainer container = provider.getJwsSignerContainer(expiration);
assertEquals("JWS Signer Container not matched", jwsSignerContainer, container);
assertEquals(jwsSignerContainer, container,"JWS Signer Container not matched");
verify(signingKeyListener).onSigningKeyUsed(keyIdentifierCaptor.capture(), expirationCaptor.capture());
assertEquals(KEY_IDENTIFIER, keyIdentifierCaptor.getValue());

View File

@ -20,20 +20,20 @@ import com.nimbusds.jose.JWSAlgorithm;
import org.apache.nifi.web.security.jwt.jws.JwsSignerContainer;
import org.apache.nifi.web.security.jwt.jws.SignerListener;
import org.apache.nifi.web.security.jwt.key.VerificationKeyListener;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.junit.jupiter.MockitoExtension;
import java.security.Key;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.verify;
@RunWith(MockitoJUnitRunner.class)
@ExtendWith(MockitoExtension.class)
public class KeyGenerationCommandTest {
private static final String KEY_ALGORITHM = "RSA";
@ -56,7 +56,7 @@ public class KeyGenerationCommandTest {
private KeyGenerationCommand command;
@Before
@BeforeEach
public void setCommand() {
command = new KeyGenerationCommand(signerListener, verificationKeyListener);
}

View File

@ -22,13 +22,13 @@ import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import org.apache.nifi.components.state.Scope;
import org.apache.nifi.components.state.StateManager;
import org.apache.nifi.components.state.StateMap;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.junit.jupiter.MockitoExtension;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
@ -38,13 +38,13 @@ import java.util.Collections;
import java.util.Map;
import java.util.UUID;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
@ExtendWith(MockitoExtension.class)
public class StandardVerificationKeyServiceTest {
private static final String ID = UUID.randomUUID().toString();
@ -72,11 +72,9 @@ public class StandardVerificationKeyServiceTest {
private StandardVerificationKeyService service;
@Before
@BeforeEach
public void setService() {
service = new StandardVerificationKeyService(stateManager);
when(key.getAlgorithm()).thenReturn(ALGORITHM);
when(key.getEncoded()).thenReturn(ENCODED);
}
@Test
@ -90,11 +88,13 @@ public class StandardVerificationKeyServiceTest {
verify(stateManager).setState(stateCaptor.capture(), eq(SCOPE));
final Map<String, String> stateSaved = stateCaptor.getValue();
assertTrue("Expired Key not deleted", stateSaved.isEmpty());
assertTrue(stateSaved.isEmpty(), "Expired Key not deleted");
}
@Test
public void testSave() throws IOException {
when(key.getAlgorithm()).thenReturn(ALGORITHM);
when(key.getEncoded()).thenReturn(ENCODED);
when(stateManager.getState(eq(SCOPE))).thenReturn(stateMap);
when(stateMap.toMap()).thenReturn(Collections.emptyMap());
@ -104,7 +104,7 @@ public class StandardVerificationKeyServiceTest {
verify(stateManager).setState(stateCaptor.capture(), eq(SCOPE));
final Map<String, String> stateSaved = stateCaptor.getValue();
final String serialized = stateSaved.get(ID);
assertNotNull("Serialized Key not found", serialized);
assertNotNull(serialized,"Serialized Key not found");
}
@Test
@ -121,7 +121,7 @@ public class StandardVerificationKeyServiceTest {
verify(stateManager).setState(stateCaptor.capture(), eq(SCOPE));
final Map<String, String> stateSaved = stateCaptor.getValue();
final String saved = stateSaved.get(ID);
assertNotNull("Serialized Key not found", saved);
assertNotNull(saved, "Serialized Key not found");
}
private String getSerializedVerificationKey(final Instant expiration) throws JsonProcessingException {

View File

@ -18,21 +18,21 @@ package org.apache.nifi.web.security.jwt.resolver;
import org.apache.nifi.web.security.http.SecurityCookieName;
import org.apache.nifi.web.security.http.SecurityHeader;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.junit.jupiter.MockitoExtension;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
@ExtendWith(MockitoExtension.class)
public class StandardBearerTokenResolverTest {
private static final String BEARER_TOKEN = "TOKEN";
@ -41,7 +41,7 @@ public class StandardBearerTokenResolverTest {
@Mock
private HttpServletRequest request;
@Before
@BeforeEach
public void setResolver() {
resolver = new StandardBearerTokenResolver();
}

View File

@ -16,22 +16,22 @@
*/
package org.apache.nifi.web.security.jwt.revocation;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
import org.springframework.security.oauth2.jwt.Jwt;
import java.util.UUID;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
@ExtendWith(MockitoExtension.class)
public class JwtRevocationValidatorTest {
private static final String ID = UUID.randomUUID().toString();
@ -48,7 +48,7 @@ public class JwtRevocationValidatorTest {
private JwtRevocationValidator validator;
@Before
@BeforeEach
public void setValidator() {
validator = new JwtRevocationValidator(jwtRevocationService);
jwt = Jwt.withTokenValue(TOKEN).header(TYPE_FIELD, JWT_TYPE).jti(ID).build();

View File

@ -16,11 +16,11 @@
*/
package org.apache.nifi.web.security.jwt.revocation;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
@ -29,10 +29,10 @@ import java.util.UUID;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
@ExtendWith(MockitoExtension.class)
public class StandardJwtLogoutListenerTest {
private static final String ID = UUID.randomUUID().toString();
@ -54,7 +54,7 @@ public class StandardJwtLogoutListenerTest {
private StandardJwtLogoutListener listener;
@Before
@BeforeEach
public void setListener() {
listener = new StandardJwtLogoutListener(jwtDecoder, jwtRevocationService);
jwt = Jwt.withTokenValue(TOKEN).header(TYPE_FIELD, JWT_TYPE).jti(ID).expiresAt(EXPIRES).build();
@ -63,8 +63,8 @@ public class StandardJwtLogoutListenerTest {
@Test
public void testLogoutBearerTokenNullZeroInteractions() {
listener.logout(null);
verifyZeroInteractions(jwtDecoder);
verifyZeroInteractions(jwtRevocationService);
verifyNoInteractions(jwtDecoder);
verifyNoInteractions(jwtRevocationService);
}
@Test

View File

@ -19,13 +19,13 @@ package org.apache.nifi.web.security.jwt.revocation;
import org.apache.nifi.components.state.Scope;
import org.apache.nifi.components.state.StateManager;
import org.apache.nifi.components.state.StateMap;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.junit.jupiter.MockitoExtension;
import java.io.IOException;
import java.time.Instant;
@ -33,14 +33,14 @@ import java.util.Collections;
import java.util.Map;
import java.util.UUID;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
@ExtendWith(MockitoExtension.class)
public class StandardJwtRevocationServiceTest {
private static final String ID = UUID.randomUUID().toString();
@ -59,7 +59,7 @@ public class StandardJwtRevocationServiceTest {
private StandardJwtRevocationService service;
@Before
@BeforeEach
public void setService() {
service = new StandardJwtRevocationService(stateManager);
}
@ -73,7 +73,7 @@ public class StandardJwtRevocationServiceTest {
verify(stateManager).setState(stateCaptor.capture(), eq(SCOPE));
final Map<String, String> stateSaved = stateCaptor.getValue();
assertTrue("Expired Key not deleted", stateSaved.isEmpty());
assertTrue(stateSaved.isEmpty(), "Expired Key not deleted");
}
@Test
@ -103,6 +103,6 @@ public class StandardJwtRevocationServiceTest {
verify(stateManager).setState(stateCaptor.capture(), eq(SCOPE));
final Map<String, String> stateSaved = stateCaptor.getValue();
final String saved = stateSaved.get(ID);
assertEquals("Expiration not matched", expiration.toString(), saved);
assertEquals(expiration.toString(), saved, "Expiration not matched");
}
}

View File

@ -17,16 +17,16 @@
package org.apache.nifi.web.security.knox;
import org.apache.nifi.util.NiFiProperties;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ -36,7 +36,7 @@ public class KnoxAuthenticationFilterTest {
private KnoxAuthenticationFilter knoxAuthenticationFilter;
@Before
@BeforeEach
public void setUp() throws Exception {
final NiFiProperties nifiProperties = Mockito.mock(NiFiProperties.class);
when(nifiProperties.isKnoxSsoEnabled()).thenReturn(true);
@ -47,14 +47,14 @@ public class KnoxAuthenticationFilterTest {
}
@Test
public void testInsecureHttp() throws Exception {
public void testInsecureHttp() {
final HttpServletRequest request = mock(HttpServletRequest.class);
when(request.isSecure()).thenReturn(false);
assertNull(knoxAuthenticationFilter.attemptAuthentication(request));
}
@Test
public void testNullCookies() throws Exception {
public void testNullCookies() {
final HttpServletRequest request = mock(HttpServletRequest.class);
when(request.isSecure()).thenReturn(true);
when(request.getCookies()).thenReturn(null);
@ -62,7 +62,7 @@ public class KnoxAuthenticationFilterTest {
}
@Test
public void testNoCookies() throws Exception {
public void testNoCookies() {
final HttpServletRequest request = mock(HttpServletRequest.class);
when(request.isSecure()).thenReturn(true);
when(request.getCookies()).thenReturn(new Cookie[] {});
@ -70,7 +70,7 @@ public class KnoxAuthenticationFilterTest {
}
@Test
public void testWrongCookieName() throws Exception {
public void testWrongCookieName() {
final String jwt = "my-jwt";
final Cookie knoxCookie = mock(Cookie.class);
@ -86,7 +86,7 @@ public class KnoxAuthenticationFilterTest {
}
@Test
public void testKnoxCookie() throws Exception {
public void testKnoxCookie() {
final String jwt = "my-jwt";
final Cookie knoxCookie = mock(Cookie.class);

View File

@ -24,12 +24,10 @@ import com.nimbusds.oauth2.sdk.auth.PrivateKeyJWT;
import com.nimbusds.oauth2.sdk.id.Audience;
import com.nimbusds.oauth2.sdk.id.ClientID;
import com.nimbusds.oauth2.sdk.id.JWTID;
import org.apache.commons.lang3.SystemUtils;
import org.apache.nifi.web.security.InvalidAuthenticationException;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.DisabledOnOs;
import org.junit.jupiter.api.condition.OS;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
@ -41,40 +39,38 @@ import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.junit.Assert.assertFalse;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@DisabledOnOs({OS.WINDOWS})
public class KnoxServiceTest {
private static final String AUDIENCE = "https://apache-knox/token";
private static final String AUDIENCE_2 = "https://apache-knox-2/token";
@BeforeClass
public static void setupClass() {
Assume.assumeTrue("Test only runs on *nix", !SystemUtils.IS_OS_WINDOWS);
}
@Test(expected = IllegalStateException.class)
public void testKnoxSsoNotEnabledGetKnoxUrl() throws Exception {
@Test
public void testKnoxSsoNotEnabledGetKnoxUrl() {
final KnoxConfiguration configuration = mock(KnoxConfiguration.class);
when(configuration.isKnoxEnabled()).thenReturn(false);
final KnoxService service = new KnoxService(configuration);
assertFalse(service.isKnoxEnabled());
service.getKnoxUrl();
assertThrows(IllegalStateException.class, service::getKnoxUrl);
}
@Test(expected = IllegalStateException.class)
public void testKnoxSsoNotEnabledGetAuthenticatedFromToken() throws Exception {
@Test
public void testKnoxSsoNotEnabledGetAuthenticatedFromToken() {
final KnoxConfiguration configuration = mock(KnoxConfiguration.class);
when(configuration.isKnoxEnabled()).thenReturn(false);
final KnoxService service = new KnoxService(configuration);
assertFalse(service.isKnoxEnabled());
service.getAuthenticationFromToken("jwt-token-value");
assertThrows(IllegalStateException.class, () -> service.getAuthenticationFromToken("jwt-token-value"));
}
private JWTAuthenticationClaimsSet getAuthenticationClaimsSet(final String subject, final String audience, final Date expiration) {
@ -103,10 +99,10 @@ public class KnoxServiceTest {
final KnoxConfiguration configuration = getConfiguration(publicKey);
final KnoxService service = new KnoxService(configuration);
Assert.assertEquals(subject, service.getAuthenticationFromToken(privateKeyJWT.getClientAssertion().serialize()));
assertEquals(subject, service.getAuthenticationFromToken(privateKeyJWT.getClientAssertion().serialize()));
}
@Test(expected = InvalidAuthenticationException.class)
@Test
public void testBadSignedJwt() throws Exception {
final String subject = "user-1";
final Date expiration = new Date(System.currentTimeMillis() + TimeUnit.MILLISECONDS.convert(5, TimeUnit.SECONDS));
@ -127,10 +123,10 @@ public class KnoxServiceTest {
final KnoxConfiguration configuration = getConfiguration(publicKey2);
final KnoxService service = new KnoxService(configuration);
service.getAuthenticationFromToken(privateKeyJWT.getClientAssertion().serialize());
assertThrows(InvalidAuthenticationException.class, () -> service.getAuthenticationFromToken(privateKeyJWT.getClientAssertion().serialize()));
}
@Test(expected = ParseException.class)
@Test
public void testPlainJwt() throws Exception {
final KeyPairGenerator keyGen = KeyPairGenerator.getInstance("RSA");
final KeyPair pair = keyGen.generateKeyPair();
@ -147,10 +143,10 @@ public class KnoxServiceTest {
final KnoxConfiguration configuration = getConfiguration(publicKey);
final KnoxService service = new KnoxService(configuration);
service.getAuthenticationFromToken(plainJWT.serialize());
assertThrows(ParseException.class, () -> service.getAuthenticationFromToken(plainJWT.serialize()));
}
@Test(expected = InvalidAuthenticationException.class)
@Test
public void testExpiredJwt() throws Exception {
final String subject = "user-1";
@ -171,7 +167,7 @@ public class KnoxServiceTest {
final KnoxConfiguration configuration = getConfiguration(publicKey);
final KnoxService service = new KnoxService(configuration);
service.getAuthenticationFromToken(privateKeyJWT.getClientAssertion().serialize());
assertThrows(InvalidAuthenticationException.class, () -> service.getAuthenticationFromToken(privateKeyJWT.getClientAssertion().serialize()));
}
@Test
@ -191,10 +187,10 @@ public class KnoxServiceTest {
when(configuration.getAudiences()).thenReturn(null);
final KnoxService service = new KnoxService(configuration);
Assert.assertEquals(subject, service.getAuthenticationFromToken(privateKeyJWT.getClientAssertion().serialize()));
assertEquals(subject, service.getAuthenticationFromToken(privateKeyJWT.getClientAssertion().serialize()));
}
@Test(expected = InvalidAuthenticationException.class)
@Test
public void testInvalidAudience() throws Exception {
final String subject = "user-1";
final Date expiration = new Date(System.currentTimeMillis() + TimeUnit.MILLISECONDS.convert(5, TimeUnit.SECONDS));
@ -209,11 +205,10 @@ public class KnoxServiceTest {
final KnoxConfiguration configuration = getConfiguration(publicKey);
final KnoxService service = new KnoxService(configuration);
Assert.assertEquals(subject, service.getAuthenticationFromToken(privateKeyJWT.getClientAssertion().serialize()));
assertThrows(InvalidAuthenticationException.class, () -> service.getAuthenticationFromToken(privateKeyJWT.getClientAssertion().serialize()));
}
private KnoxConfiguration getConfiguration(final RSAPublicKey publicKey) throws Exception {
private KnoxConfiguration getConfiguration(final RSAPublicKey publicKey) {
final KnoxConfiguration configuration = mock(KnoxConfiguration.class);
when(configuration.isKnoxEnabled()).thenReturn(true);
when(configuration.getKnoxUrl()).thenReturn("knox-sso-url");

View File

@ -16,18 +16,18 @@
*/
package org.apache.nifi.web.security.logout;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
public class TestLogoutRequestManager {
private LogoutRequestManager logoutRequestManager;
@Before
@BeforeEach
public void setup() {
logoutRequestManager = new LogoutRequestManager();
}

View File

@ -20,14 +20,14 @@ import com.nimbusds.oauth2.sdk.AuthorizationCode;
import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant;
import com.nimbusds.oauth2.sdk.AuthorizationGrant;
import com.nimbusds.oauth2.sdk.id.State;
import java.io.IOException;
import java.net.URI;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ -37,23 +37,23 @@ public class OidcServiceTest {
public static final String TEST_REQUEST_IDENTIFIER = "test-request-identifier";
public static final String TEST_STATE = "test-state";
@Test(expected = IllegalStateException.class)
@Test
public void testOidcNotEnabledCreateState() {
final OidcService service = getServiceWithNoOidcSupport();
service.createState(TEST_REQUEST_IDENTIFIER);
assertThrows(IllegalStateException.class, () -> service.createState(TEST_REQUEST_IDENTIFIER));
}
@Test(expected = IllegalStateException.class)
@Test
public void testCreateStateMultipleInvocations() {
final OidcService service = getServiceWithOidcSupport();
service.createState(TEST_REQUEST_IDENTIFIER);
service.createState(TEST_REQUEST_IDENTIFIER);
assertThrows(IllegalStateException.class, () -> service.createState(TEST_REQUEST_IDENTIFIER));
}
@Test(expected = IllegalStateException.class)
@Test
public void testOidcNotEnabledValidateState() {
final OidcService service = getServiceWithNoOidcSupport();
service.isStateValid(TEST_REQUEST_IDENTIFIER, new State(TEST_STATE));
assertThrows(IllegalStateException.class, () -> service.isStateValid(TEST_REQUEST_IDENTIFIER, new State(TEST_STATE)));
}
@Test
@ -79,7 +79,7 @@ public class OidcServiceTest {
assertFalse(service.isStateValid(TEST_REQUEST_IDENTIFIER, state));
}
@Test(expected = IllegalStateException.class)
@Test
public void testStoreJwtMultipleInvocation() {
final OidcService service = getServiceWithOidcSupport();
@ -93,31 +93,31 @@ public class OidcServiceTest {
"uYXBhY2hlLm9yZyJ9.nlYhplDLXeGAwW62rJ_ZnEaG7nxEB4TbaJNK-_pC4WQ";
service.storeJwt(TEST_REQUEST_IDENTIFIER, TEST_JWT1);
service.storeJwt(TEST_REQUEST_IDENTIFIER, TEST_JWT2);
assertThrows(IllegalStateException.class, () -> service.storeJwt(TEST_REQUEST_IDENTIFIER, TEST_JWT2));
}
@Test(expected = IllegalStateException.class)
public void testOidcNotEnabledExchangeCodeForLoginAuthenticationToken() throws Exception {
@Test
public void testOidcNotEnabledExchangeCodeForLoginAuthenticationToken() {
final OidcService service = getServiceWithNoOidcSupport();
service.exchangeAuthorizationCodeForLoginAuthenticationToken(getAuthorizationGrant());
assertThrows(IllegalStateException.class, () -> service.exchangeAuthorizationCodeForLoginAuthenticationToken(getAuthorizationGrant()));
}
@Test(expected = IllegalStateException.class)
public void testOidcNotEnabledExchangeCodeForAccessToken() throws Exception {
@Test
public void testOidcNotEnabledExchangeCodeForAccessToken() {
final OidcService service = getServiceWithNoOidcSupport();
service.exchangeAuthorizationCodeForAccessToken(getAuthorizationGrant());
assertThrows(IllegalStateException.class, () ->service.exchangeAuthorizationCodeForAccessToken(getAuthorizationGrant()));
}
@Test(expected = IllegalStateException.class)
public void testOidcNotEnabledExchangeCodeForIdToken() throws IOException {
@Test
public void testOidcNotEnabledExchangeCodeForIdToken() {
final OidcService service = getServiceWithNoOidcSupport();
service.exchangeAuthorizationCodeForIdToken(getAuthorizationGrant());
assertThrows(IllegalStateException.class, () -> service.exchangeAuthorizationCodeForIdToken(getAuthorizationGrant()));
}
@Test(expected = IllegalStateException.class)
@Test
public void testOidcNotEnabledGetJwt() {
final OidcService service = getServiceWithNoOidcSupport();
service.getJwt(TEST_REQUEST_IDENTIFIER);
assertThrows(IllegalStateException.class, () -> service.getJwt(TEST_REQUEST_IDENTIFIER));
}
private OidcService getServiceWithNoOidcSupport() {

View File

@ -19,14 +19,14 @@ package org.apache.nifi.web.security.oidc;
import com.nimbusds.oauth2.sdk.Scope;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.apache.nifi.util.NiFiProperties;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ -42,7 +42,7 @@ public class StandardOidcIdentityProviderTest {
Scope scope = provider.getScope();
// two additional scopes are set, two (openid, email) are hard-coded
assertEquals(scope.toArray().length, 4);
assertEquals(4, scope.toArray().length);
assertTrue(scope.contains("openid"));
assertTrue(scope.contains("email"));
assertTrue(scope.contains(additionalScope_profile));
@ -59,7 +59,7 @@ public class StandardOidcIdentityProviderTest {
// three additional scopes are set but one is duplicated and mustn't be returned; note that there is
// another one inserted in between the duplicated; two (openid, email) are hard-coded
assertEquals(scope.toArray().length, 4);
assertEquals(4, scope.toArray().length);
}
private StandardOidcIdentityProvider createOidcProviderWithAdditionalScopes(String... additionalScopes) throws IllegalAccessException {

View File

@ -28,8 +28,8 @@ import org.apache.nifi.util.NiFiProperties;
import org.apache.nifi.web.security.InvalidAuthenticationException;
import org.apache.nifi.web.security.UntrustedProxyException;
import org.apache.nifi.web.security.token.NiFiAuthenticationToken;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.security.Principal;
import java.security.cert.X509Certificate;
@ -40,10 +40,12 @@ import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ -68,7 +70,7 @@ public class X509AuthenticationProviderTest {
private SubjectDnX509PrincipalExtractor extractor;
private Authorizer authorizer;
@Before
@BeforeEach
public void setup() {
System.clearProperty(NiFiProperties.PROPERTIES_FILE_PATH);
@ -100,9 +102,9 @@ public class X509AuthenticationProviderTest {
x509AuthenticationProvider = new X509AuthenticationProvider(certificateIdentityProvider, authorizer, NiFiProperties.createBasicNiFiProperties(null));
}
@Test(expected = InvalidAuthenticationException.class)
@Test
public void testInvalidCertificate() {
x509AuthenticationProvider.authenticate(getX509Request("", INVALID_CERTIFICATE));
assertThrows(InvalidAuthenticationException.class, () -> x509AuthenticationProvider.authenticate(getX509Request("", INVALID_CERTIFICATE)));
}
@Test
@ -115,9 +117,9 @@ public class X509AuthenticationProviderTest {
assertFalse(user.isAnonymous());
}
@Test(expected = UntrustedProxyException.class)
@Test
public void testUntrustedProxy() {
x509AuthenticationProvider.authenticate(getX509Request(buildProxyChain(IDENTITY_1), UNTRUSTED_PROXY));
assertThrows(UntrustedProxyException.class, () -> x509AuthenticationProvider.authenticate(getX509Request(buildProxyChain(IDENTITY_1), UNTRUSTED_PROXY)));
}
@Test
@ -155,18 +157,9 @@ public class X509AuthenticationProviderTest {
assertFalse(user.getChain().isAnonymous());
}
@Test(expected = InvalidAuthenticationException.class)
@Test
public void testAnonymousWithOneProxyWhileAnonymousAuthenticationPrevented() {
final NiFiAuthenticationToken auth = (NiFiAuthenticationToken) x509AuthenticationProvider.authenticate(getX509Request(buildProxyChain(ANONYMOUS), PROXY_1));
final NiFiUser user = ((NiFiUserDetails) auth.getDetails()).getNiFiUser();
assertNotNull(user);
assertEquals(StandardNiFiUser.ANONYMOUS_IDENTITY, user.getIdentity());
assertTrue(user.isAnonymous());
assertNotNull(user.getChain());
assertEquals(PROXY_1, user.getChain().getIdentity());
assertFalse(user.getChain().isAnonymous());
assertThrows(InvalidAuthenticationException.class, () -> x509AuthenticationProvider.authenticate(getX509Request(buildProxyChain(ANONYMOUS), PROXY_1)));
}
@Test
@ -187,9 +180,9 @@ public class X509AuthenticationProviderTest {
assertFalse(user.getChain().getChain().isAnonymous());
}
@Test(expected = UntrustedProxyException.class)
@Test
public void testUntrustedProxyInChain() {
x509AuthenticationProvider.authenticate(getX509Request(buildProxyChain(IDENTITY_1, UNTRUSTED_PROXY), PROXY_1));
assertThrows(UntrustedProxyException.class, () -> x509AuthenticationProvider.authenticate(getX509Request(buildProxyChain(IDENTITY_1, UNTRUSTED_PROXY), PROXY_1)));
}
@Test
@ -217,22 +210,9 @@ public class X509AuthenticationProviderTest {
assertFalse(user.getChain().getChain().isAnonymous());
}
@Test(expected = InvalidAuthenticationException.class)
@Test
public void testAnonymousProxyInChainWhileAnonymousAuthenticationPrevented() {
final NiFiAuthenticationToken auth = (NiFiAuthenticationToken) x509AuthenticationProvider.authenticate(getX509Request(buildProxyChain(IDENTITY_1, ANONYMOUS), PROXY_1));
final NiFiUser user = ((NiFiUserDetails) auth.getDetails()).getNiFiUser();
assertNotNull(user);
assertEquals(IDENTITY_1, user.getIdentity());
assertFalse(user.isAnonymous());
assertNotNull(user.getChain());
assertEquals(StandardNiFiUser.ANONYMOUS_IDENTITY, user.getChain().getIdentity());
assertTrue(user.getChain().isAnonymous());
assertNotNull(user.getChain().getChain());
assertEquals(PROXY_1, user.getChain().getChain().getIdentity());
assertFalse(user.getChain().getChain().isAnonymous());
assertThrows(InvalidAuthenticationException.class, () -> x509AuthenticationProvider.authenticate(getX509Request(buildProxyChain(IDENTITY_1, ANONYMOUS), PROXY_1)));
}
@Test
@ -244,10 +224,9 @@ public class X509AuthenticationProviderTest {
NiFiUser user = X509AuthenticationProvider.createUser(identity, null, null, null, null, true);
// Assert
assert user != null;
assert user instanceof StandardNiFiUser;
assert user.getIdentity().equals(StandardNiFiUser.ANONYMOUS_IDENTITY);
assert user.isAnonymous();
assertInstanceOf(StandardNiFiUser.class, user);
assertEquals(StandardNiFiUser.ANONYMOUS_IDENTITY, user.getIdentity());
assertTrue(user.isAnonymous());
}
@Test
@ -259,10 +238,9 @@ public class X509AuthenticationProviderTest {
NiFiUser user = X509AuthenticationProvider.createUser(identity, null, null, null, null, false);
// Assert
assert user != null;
assert user instanceof StandardNiFiUser;
assert user.getIdentity().equals(identity);
assert !user.isAnonymous();
assertInstanceOf(StandardNiFiUser.class, user);
assertEquals(identity, user.getIdentity());
assertFalse(user.isAnonymous());
}
private String buildProxyChain(final String... identities) {

View File

@ -29,19 +29,16 @@ import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.operator.ContentSigner;
import org.bouncycastle.operator.OperatorCreationException;
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.math.BigInteger;
import java.security.InvalidKeyException;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.Security;
@ -51,18 +48,22 @@ import java.security.cert.X509Certificate;
import java.util.Date;
import java.util.Vector;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class OcspCertificateValidatorTest {
private static final Logger logger = LoggerFactory.getLogger(OcspCertificateValidatorTest.class);
private static final int KEY_SIZE = 2048;
private static final long YESTERDAY = System.currentTimeMillis() - 24 * 60 * 60 * 1000;
private static final long ONE_YEAR_FROM_NOW = System.currentTimeMillis() + 365 * 24 * 60 * 60 * 1000;
private static final long ONE_YEAR_FROM_NOW = System.currentTimeMillis() + 365L * 24 * 60 * 60 * 1000;
private static final String SIGNATURE_ALGORITHM = "SHA256withRSA";
private static final String PROVIDER = "BC";
@BeforeClass
public static void setUpOnce() throws Exception {
@BeforeAll
public static void setUpOnce() {
Security.addProvider(new BouncyCastleProvider());
}
@ -86,13 +87,10 @@ public class OcspCertificateValidatorTest {
* @throws IOException if an exception occurs
* @throws NoSuchAlgorithmException if an exception occurs
* @throws CertificateException if an exception occurs
* @throws NoSuchProviderException if an exception occurs
* @throws SignatureException if an exception occurs
* @throws InvalidKeyException if an exception occurs
* @throws OperatorCreationException if an exception occurs
*/
private static X509Certificate generateCertificate(String dn) throws IOException, NoSuchAlgorithmException, CertificateException, NoSuchProviderException, SignatureException,
InvalidKeyException, OperatorCreationException {
private static X509Certificate generateCertificate(String dn) throws IOException, NoSuchAlgorithmException, CertificateException,
OperatorCreationException {
KeyPair keyPair = generateKeyPair();
return generateCertificate(dn, keyPair);
}
@ -104,15 +102,11 @@ public class OcspCertificateValidatorTest {
* @param keyPair the public key will be included in the certificate and the the private key is used to sign the certificate
* @return the certificate
* @throws IOException if an exception occurs
* @throws NoSuchAlgorithmException if an exception occurs
* @throws CertificateException if an exception occurs
* @throws NoSuchProviderException if an exception occurs
* @throws SignatureException if an exception occurs
* @throws InvalidKeyException if an exception occurs
* @throws OperatorCreationException if an exception occurs
*/
private static X509Certificate generateCertificate(String dn, KeyPair keyPair) throws IOException, NoSuchAlgorithmException, CertificateException, NoSuchProviderException, SignatureException,
InvalidKeyException, OperatorCreationException {
private static X509Certificate generateCertificate(String dn, KeyPair keyPair) throws IOException, CertificateException,
OperatorCreationException {
PrivateKey privateKey = keyPair.getPrivate();
ContentSigner sigGen = new JcaContentSignerBuilder(SIGNATURE_ALGORITHM).setProvider(PROVIDER).build(privateKey);
SubjectPublicKeyInfo subPubKeyInfo = SubjectPublicKeyInfo.getInstance(keyPair.getPublic().getEncoded());
@ -150,16 +144,12 @@ public class OcspCertificateValidatorTest {
* @param issuerDn the issuer DN
* @param issuerKey the issuer private key
* @return the certificate
* @throws IOException if an exception occurs
* @throws NoSuchAlgorithmException if an exception occurs
* @throws CertificateException if an exception occurs
* @throws NoSuchProviderException if an exception occurs
* @throws SignatureException if an exception occurs
* @throws InvalidKeyException if an exception occurs
* @throws OperatorCreationException if an exception occurs
*/
private static X509Certificate generateIssuedCertificate(String dn, String issuerDn, PrivateKey issuerKey) throws IOException, NoSuchAlgorithmException, CertificateException,
NoSuchProviderException, SignatureException, InvalidKeyException, OperatorCreationException {
private static X509Certificate generateIssuedCertificate(String dn, String issuerDn, PrivateKey issuerKey) throws NoSuchAlgorithmException, CertificateException,
OperatorCreationException {
KeyPair keyPair = generateKeyPair();
return generateIssuedCertificate(dn, keyPair.getPublic(), issuerDn, issuerKey);
}
@ -172,16 +162,11 @@ public class OcspCertificateValidatorTest {
* @param issuerDn the issuer DN
* @param issuerKey the issuer private key
* @return the certificate
* @throws IOException if an exception occurs
* @throws NoSuchAlgorithmException if an exception occurs
* @throws CertificateException if an exception occurs
* @throws NoSuchProviderException if an exception occurs
* @throws SignatureException if an exception occurs
* @throws InvalidKeyException if an exception occurs
* @throws OperatorCreationException if an exception occurs
*/
private static X509Certificate generateIssuedCertificate(String dn, PublicKey publicKey, String issuerDn, PrivateKey issuerKey) throws IOException, NoSuchAlgorithmException,
CertificateException, NoSuchProviderException, SignatureException, InvalidKeyException, OperatorCreationException {
private static X509Certificate generateIssuedCertificate(String dn, PublicKey publicKey, String issuerDn, PrivateKey issuerKey) throws
CertificateException, OperatorCreationException {
ContentSigner sigGen = new JcaContentSignerBuilder(SIGNATURE_ALGORITHM).setProvider(PROVIDER).build(issuerKey);
SubjectPublicKeyInfo subPubKeyInfo = SubjectPublicKeyInfo.getInstance(publicKey.getEncoded());
Date startDate = new Date(YESTERDAY);
@ -209,8 +194,8 @@ public class OcspCertificateValidatorTest {
logger.info("Generated certificate: \n{}", certificate);
// Assert
assert certificate.getSubjectDN().getName().equals(testDn);
assert certificate.getIssuerDN().getName().equals(testDn);
assertEquals(testDn, certificate.getSubjectDN().getName());
assertEquals(testDn, certificate.getIssuerDN().getName());
certificate.verify(certificate.getPublicKey());
}
@ -225,9 +210,9 @@ public class OcspCertificateValidatorTest {
logger.info("Generated certificate: \n{}", certificate);
// Assert
assert certificate.getPublicKey().equals(keyPair.getPublic());
assert certificate.getSubjectDN().getName().equals(testDn);
assert certificate.getIssuerDN().getName().equals(testDn);
assertEquals(keyPair.getPublic(), certificate.getPublicKey());
assertEquals(testDn, certificate.getSubjectDN().getName());
assertEquals(testDn, certificate.getIssuerDN().getName());
certificate.verify(certificate.getPublicKey());
}
@ -247,17 +232,12 @@ public class OcspCertificateValidatorTest {
logger.info("Generated signed certificate: \n{}", certificate);
// Assert
assert issuerCertificate.getPublicKey().equals(issuerKeyPair.getPublic());
assert certificate.getSubjectX500Principal().getName().equals(testDn);
assert certificate.getIssuerX500Principal().getName().equals(issuerDn);
assertEquals(issuerKeyPair.getPublic(), issuerCertificate.getPublicKey());
assertEquals(testDn, certificate.getSubjectX500Principal().getName());
assertEquals(issuerDn, certificate.getIssuerX500Principal().getName());
certificate.verify(issuerCertificate.getPublicKey());
try {
certificate.verify(certificate.getPublicKey());
Assert.fail("Should have thrown exception");
} catch (Exception e) {
assert e instanceof SignatureException;
assert e.getMessage().contains("certificate does not verify with supplied key");
}
SignatureException se = assertThrows(SignatureException.class, () -> certificate.verify(certificate.getPublicKey()));
assertTrue(se.getMessage().contains("certificate does not verify with supplied key"));
}
}