Improve exception handling in extension druid-pac4j (#16979)

Changes:
- Simplify exception handling in `CryptoService` by just catching a `Exception`
- Throw a `DruidException` as the exception is user facing
- Log the exception for easier debugging
- Add a test to verify thrown exception
This commit is contained in:
Kashif Faraz 2024-08-30 00:02:49 -07:00 committed by GitHub
parent 358d06abc1
commit d5b64ba2e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 49 additions and 13 deletions

View File

@ -19,6 +19,7 @@
package org.apache.druid.security.pac4j;
import org.apache.druid.error.DruidException;
import org.easymock.Capture;
import org.easymock.EasyMock;
import org.junit.Assert;
@ -41,7 +42,7 @@ public class Pac4jSessionStoreTest
@Test
public void testSetAndGet()
{
Pac4jSessionStore<WebContext> sessionStore = new Pac4jSessionStore(COOKIE_PASSPHRASE);
Pac4jSessionStore<WebContext> sessionStore = new Pac4jSessionStore<>(COOKIE_PASSPHRASE);
WebContext webContext1 = EasyMock.mock(WebContext.class);
EasyMock.expect(webContext1.getScheme()).andReturn("https");
@ -68,7 +69,7 @@ public class Pac4jSessionStoreTest
@Test
public void testSetAndGetClearUserProfile()
{
Pac4jSessionStore<WebContext> sessionStore = new Pac4jSessionStore(COOKIE_PASSPHRASE);
Pac4jSessionStore<WebContext> sessionStore = new Pac4jSessionStore<>(COOKIE_PASSPHRASE);
WebContext webContext1 = EasyMock.mock(WebContext.class);
EasyMock.expect(webContext1.getScheme()).andReturn("https");
@ -78,6 +79,7 @@ public class Pac4jSessionStoreTest
EasyMock.replay(webContext1);
CommonProfile profile = new CommonProfile();
profile.setId("profile1");
profile.addAttribute(CommonProfileDefinition.DISPLAY_NAME, "name");
sessionStore.set(webContext1, "pac4jUserProfiles", profile);
@ -99,7 +101,7 @@ public class Pac4jSessionStoreTest
@Test
public void testSetAndGetClearUserMultipleProfile()
{
Pac4jSessionStore<WebContext> sessionStore = new Pac4jSessionStore(COOKIE_PASSPHRASE);
Pac4jSessionStore<WebContext> sessionStore = new Pac4jSessionStore<>(COOKIE_PASSPHRASE);
WebContext webContext1 = EasyMock.mock(WebContext.class);
EasyMock.expect(webContext1.getScheme()).andReturn("https");
@ -109,8 +111,10 @@ public class Pac4jSessionStoreTest
EasyMock.replay(webContext1);
CommonProfile profile1 = new CommonProfile();
profile1.setId("profile1");
profile1.addAttribute(CommonProfileDefinition.DISPLAY_NAME, "name1");
CommonProfile profile2 = new CommonProfile();
profile2.setId("profile2");
profile2.addAttribute(CommonProfileDefinition.DISPLAY_NAME, "name2");
Map<String, CommonProfile> profiles = new HashMap<>();
profiles.put("profile1", profile1);
@ -131,4 +135,36 @@ public class Pac4jSessionStoreTest
Assert.assertTrue(Objects.requireNonNull(value).isPresent());
Assert.assertEquals(2, ((Map<String, CommonProfile>) value.get()).size());
}
@Test
public void testGetWithWrongPassphraseThrowsDruidException()
{
final WebContext webContext = EasyMock.mock(WebContext.class);
EasyMock.expect(webContext.getScheme()).andReturn("https");
final Capture<Cookie> cookieCapture = EasyMock.newCapture();
EasyMock.expect(webContext.getRequestCookies())
.andAnswer(() -> Collections.singleton(cookieCapture.getValue()));
webContext.addResponseCookie(EasyMock.capture(cookieCapture));
EasyMock.expectLastCall().once();
EasyMock.replay(webContext);
// Create a cookie with an invalid passphrase
new Pac4jSessionStore<>("invalid-passphrase").set(webContext, "key", "value");
// Verify that trying to decrypt the invalid cookie throws an exception
final Pac4jSessionStore<WebContext> sessionStore = new Pac4jSessionStore<>(COOKIE_PASSPHRASE);
DruidException exception = Assert.assertThrows(
DruidException.class,
() -> sessionStore.get(webContext, "key")
);
Assert.assertEquals(
"Decryption failed. Check service logs.",
exception.getMessage()
);
Assert.assertNull(exception.getCause());
EasyMock.verify(webContext);
}
}

View File

@ -20,25 +20,21 @@
package org.apache.druid.crypto;
import com.google.common.base.Preconditions;
import org.apache.druid.error.InternalServerError;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.logger.Logger;
import javax.annotation.Nullable;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.SecretKey;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.PBEKeySpec;
import javax.crypto.spec.SecretKeySpec;
import java.nio.ByteBuffer;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.InvalidParameterSpecException;
import java.security.spec.KeySpec;
/**
@ -50,6 +46,8 @@ import java.security.spec.KeySpec;
*/
public class CryptoService
{
private static final Logger log = new Logger(CryptoService.class);
// Based on Javadocs on SecureRandom, It is threadsafe as well.
private static final SecureRandom SECURE_RANDOM_INSTANCE = new SecureRandom();
@ -125,8 +123,9 @@ public class CryptoService
ecipher.doFinal(plain)
).toByteAray();
}
catch (InvalidKeySpecException | NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException | InvalidParameterSpecException | IllegalBlockSizeException | BadPaddingException ex) {
throw new RuntimeException(ex);
catch (Exception ex) {
log.noStackTrace().warn(ex, "Encryption failed");
throw InternalServerError.exception("Encryption failed. Check service logs.");
}
}
@ -145,8 +144,9 @@ public class CryptoService
dcipher.init(Cipher.DECRYPT_MODE, secret, new IvParameterSpec(encryptedData.getIv()));
return dcipher.doFinal(encryptedData.getCipher());
}
catch (InvalidKeySpecException | NoSuchAlgorithmException | InvalidAlgorithmParameterException | NoSuchPaddingException | InvalidKeyException | IllegalBlockSizeException | BadPaddingException ex) {
throw new RuntimeException(ex);
catch (Exception ex) {
log.noStackTrace().warn(ex, "Decryption failed");
throw InternalServerError.exception("Decryption failed. Check service logs.");
}
}