SOLR-14420 Declare ServletRequests as HttpRequests in AuthenticationPlugin (#1442)

Declare ServletRequests as HttpRequests in AuthenticationPlugin
This commit is contained in:
Mike Drob 2020-04-22 12:06:18 -05:00 committed by GitHub
parent 5d60ff4613
commit fe05a6d380
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 56 additions and 92 deletions

View File

@ -60,6 +60,8 @@ Other Changes
* SOLR-9909: The deprecated SolrjNamedThreadFactory has been removed. Use SolrNamedThreadFactory instead. * SOLR-9909: The deprecated SolrjNamedThreadFactory has been removed. Use SolrNamedThreadFactory instead.
(Andras Salamon, shalin) (Andras Salamon, shalin)
* SOLR-14420: AuthenticationPlugin.authenticate accepts HttpServletRequest instead of ServletRequest. (Mike Drob)
================== 8.6.0 ================== ================== 8.6.0 ==================
Consult the LUCENE_CHANGES.txt file for additional, low level, changes in this release. Consult the LUCENE_CHANGES.txt file for additional, low level, changes in this release.

View File

@ -17,8 +17,10 @@
package org.apache.solr.security; package org.apache.solr.security;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
import javax.servlet.ServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import java.security.Principal;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
@ -72,15 +74,14 @@ public abstract class AuthenticationPlugin implements SolrInfoBean {
* the response and status code have already been sent. * the response and status code have already been sent.
* @throws Exception any exception thrown during the authentication, e.g. PrivilegedActionException * @throws Exception any exception thrown during the authentication, e.g. PrivilegedActionException
*/ */
//TODO redeclare params as HttpServletRequest & HttpServletResponse public abstract boolean doAuthenticate(HttpServletRequest request, HttpServletResponse response,
public abstract boolean doAuthenticate(ServletRequest request, ServletResponse response,
FilterChain filterChain) throws Exception; FilterChain filterChain) throws Exception;
/** /**
* This method is called by SolrDispatchFilter in order to initiate authentication. * This method is called by SolrDispatchFilter in order to initiate authentication.
* It does some standard metrics counting. * It does some standard metrics counting.
*/ */
public final boolean authenticate(ServletRequest request, ServletResponse response, FilterChain filterChain) throws Exception { public final boolean authenticate(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws Exception {
Timer.Context timer = requestTimes.time(); Timer.Context timer = requestTimes.time();
requests.inc(); requests.inc();
try { try {
@ -94,6 +95,15 @@ public abstract class AuthenticationPlugin implements SolrInfoBean {
} }
} }
HttpServletRequest wrapWithPrincipal(HttpServletRequest request, Principal principal) {
return new HttpServletRequestWrapper(request) {
@Override
public Principal getUserPrincipal() {
return principal;
}
};
}
/** /**
* Override this method to intercept internode requests. This allows your authentication * Override this method to intercept internode requests. This allows your authentication
* plugin to decide on per-request basis whether it should handle inter-node requests or * plugin to decide on per-request basis whether it should handle inter-node requests or

View File

@ -18,10 +18,7 @@ package org.apache.solr.security;
import javax.security.auth.Subject; import javax.security.auth.Subject;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import java.io.IOException; import java.io.IOException;
import java.io.Serializable; import java.io.Serializable;
@ -124,11 +121,7 @@ public class BasicAuthPlugin extends AuthenticationPlugin implements ConfigEdita
} }
@Override @Override
public boolean doAuthenticate(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws Exception { public boolean doAuthenticate(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws Exception {
HttpServletRequest request = (HttpServletRequest) servletRequest;
HttpServletResponse response = (HttpServletResponse) servletResponse;
String authHeader = request.getHeader("Authorization"); String authHeader = request.getHeader("Authorization");
boolean isAjaxRequest = isAjaxRequest(request); boolean isAjaxRequest = isAjaxRequest(request);
@ -151,14 +144,10 @@ public class BasicAuthPlugin extends AuthenticationPlugin implements ConfigEdita
authenticationFailure(response, isAjaxRequest, "Bad credentials"); authenticationFailure(response, isAjaxRequest, "Bad credentials");
return false; return false;
} else { } else {
HttpServletRequestWrapper wrapper = new HttpServletRequestWrapper(request) { Principal principal = new BasicAuthUserPrincipal(username, pwd);
@Override request = wrapWithPrincipal(request, principal);
public Principal getUserPrincipal() {
return new BasicAuthUserPrincipal(username, pwd);
}
};
numAuthenticated.inc(); numAuthenticated.inc();
filterChain.doFilter(wrapper, response); filterChain.doFilter(request, response);
return true; return true;
} }
} else { } else {

View File

@ -38,7 +38,6 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.core.JsonGenerator;
import org.apache.commons.collections.iterators.IteratorEnumeration;
import org.apache.hadoop.security.authentication.server.AuthenticationFilter; import org.apache.hadoop.security.authentication.server.AuthenticationFilter;
import org.apache.hadoop.security.token.delegation.web.DelegationTokenAuthenticationHandler; import org.apache.hadoop.security.token.delegation.web.DelegationTokenAuthenticationHandler;
import org.apache.solr.client.solrj.impl.Krb5HttpClientBuilder; import org.apache.solr.client.solrj.impl.Krb5HttpClientBuilder;
@ -212,7 +211,7 @@ public class HadoopAuthPlugin extends AuthenticationPlugin {
@Override @Override
public Enumeration<String> getInitParameterNames() { public Enumeration<String> getInitParameterNames() {
return new IteratorEnumeration(params.keySet().iterator()); return Collections.enumeration(params.keySet());
} }
@Override @Override
@ -230,24 +229,21 @@ public class HadoopAuthPlugin extends AuthenticationPlugin {
} }
@Override @Override
public boolean doAuthenticate(ServletRequest request, ServletResponse response, FilterChain filterChain) public boolean doAuthenticate(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws Exception { throws Exception {
final HttpServletResponse frsp = (HttpServletResponse)response;
if (TRACE_HTTP) { if (TRACE_HTTP) {
HttpServletRequest req = (HttpServletRequest) request;
log.info("----------HTTP Request---------{}"); log.info("----------HTTP Request---------{}");
if (log.isInfoEnabled()) { if (log.isInfoEnabled()) {
log.info("{} : {}", req.getMethod(), req.getRequestURI()); log.info("{} : {}", request.getMethod(), request.getRequestURI());
} }
if (log.isInfoEnabled()) { if (log.isInfoEnabled()) {
log.info("Query : {}", req.getQueryString()); log.info("Query : {}", request.getQueryString());
} }
log.info("Headers :"); log.info("Headers :");
Enumeration<String> headers = req.getHeaderNames(); Enumeration<String> headers = request.getHeaderNames();
while (headers.hasMoreElements()) { while (headers.hasMoreElements()) {
String name = headers.nextElement(); String name = headers.nextElement();
Enumeration<String> hvals = req.getHeaders(name); Enumeration<String> hvals = request.getHeaders(name);
while (hvals.hasMoreElements()) { while (hvals.hasMoreElements()) {
if (log.isInfoEnabled()) { if (log.isInfoEnabled()) {
log.info("{} : {}", name, hvals.nextElement()); log.info("{} : {}", name, hvals.nextElement());
@ -257,9 +253,9 @@ public class HadoopAuthPlugin extends AuthenticationPlugin {
log.info("-------------------------------"); log.info("-------------------------------");
} }
authFilter.doFilter(request, frsp, filterChain); authFilter.doFilter(request, response, filterChain);
switch (frsp.getStatus()) { switch (response.getStatus()) {
case HttpServletResponse.SC_UNAUTHORIZED: case HttpServletResponse.SC_UNAUTHORIZED:
// Cannot tell whether the 401 is due to wrong or missing credentials // Cannot tell whether the 401 is due to wrong or missing credentials
numWrongCredentials.inc(); numWrongCredentials.inc();
@ -270,7 +266,7 @@ public class HadoopAuthPlugin extends AuthenticationPlugin {
numErrors.mark(); numErrors.mark();
break; break;
default: default:
if (frsp.getStatus() >= 200 && frsp.getStatus() <= 299) { if (response.getStatus() >= 200 && response.getStatus() <= 299) {
numAuthenticated.inc(); numAuthenticated.inc();
} else { } else {
numErrors.mark(); numErrors.mark();
@ -280,11 +276,11 @@ public class HadoopAuthPlugin extends AuthenticationPlugin {
if (TRACE_HTTP) { if (TRACE_HTTP) {
log.info("----------HTTP Response---------"); log.info("----------HTTP Response---------");
if (log.isInfoEnabled()) { if (log.isInfoEnabled()) {
log.info("Status : {}", frsp.getStatus()); log.info("Status : {}", response.getStatus());
} }
log.info("Headers :"); log.info("Headers :");
for (String name : frsp.getHeaderNames()) { for (String name : response.getHeaderNames()) {
for (String value : frsp.getHeaders(name)) { for (String value : response.getHeaders(name)) {
log.info("{} : {}", name, value); log.info("{} : {}", name, value);
} }
} }

View File

@ -17,10 +17,7 @@
package org.apache.solr.security; package org.apache.solr.security;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import java.io.IOException; import java.io.IOException;
import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodHandles;
@ -274,10 +271,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
* Main authentication method that looks for correct JWT token in the Authorization header * Main authentication method that looks for correct JWT token in the Authorization header
*/ */
@Override @Override
public boolean doAuthenticate(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws Exception { public boolean doAuthenticate(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws Exception {
HttpServletRequest request = (HttpServletRequest) servletRequest;
HttpServletResponse response = (HttpServletResponse) servletResponse;
String header = request.getHeader(HttpHeaders.AUTHORIZATION); String header = request.getHeader(HttpHeaders.AUTHORIZATION);
if (jwtConsumer == null) { if (jwtConsumer == null) {
@ -320,12 +314,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
switch (authResponse.getAuthCode()) { switch (authResponse.getAuthCode()) {
case AUTHENTICATED: case AUTHENTICATED:
final Principal principal = authResponse.getPrincipal(); final Principal principal = authResponse.getPrincipal();
HttpServletRequestWrapper wrapper = new HttpServletRequestWrapper(request) { request = wrapWithPrincipal(request, principal);
@Override
public Principal getUserPrincipal() {
return principal;
}
};
if (!(principal instanceof JWTPrincipal)) { if (!(principal instanceof JWTPrincipal)) {
numErrors.mark(); numErrors.mark();
throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "JWTAuth plugin says AUTHENTICATED but no token extracted"); throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "JWTAuth plugin says AUTHENTICATED but no token extracted");
@ -333,7 +322,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
if (log.isDebugEnabled()) if (log.isDebugEnabled())
log.debug("Authentication SUCCESS"); log.debug("Authentication SUCCESS");
numAuthenticated.inc(); numAuthenticated.inc();
filterChain.doFilter(wrapper, response); filterChain.doFilter(request, response);
return true; return true;
case PASS_THROUGH: case PASS_THROUGH:

View File

@ -17,6 +17,7 @@
package org.apache.solr.security; package org.apache.solr.security;
import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodHandles;
import java.util.Collections;
import java.util.Enumeration; import java.util.Enumeration;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@ -26,12 +27,11 @@ import javax.servlet.FilterChain;
import javax.servlet.FilterConfig; import javax.servlet.FilterConfig;
import javax.servlet.ServletContext; import javax.servlet.ServletContext;
import javax.servlet.ServletException; import javax.servlet.ServletException;
import javax.servlet.ServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletResponse;
import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.core.JsonGenerator;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import org.apache.commons.collections.iterators.IteratorEnumeration;
import org.apache.hadoop.security.token.delegation.web.DelegationTokenAuthenticationHandler; import org.apache.hadoop.security.token.delegation.web.DelegationTokenAuthenticationHandler;
import org.apache.http.HttpRequest; import org.apache.http.HttpRequest;
import org.apache.http.protocol.HttpContext; import org.apache.http.protocol.HttpContext;
@ -195,7 +195,7 @@ public class KerberosPlugin extends AuthenticationPlugin implements HttpClientBu
@Override @Override
public Enumeration<String> getInitParameterNames() { public Enumeration<String> getInitParameterNames() {
return new IteratorEnumeration(params.keySet().iterator()); return Collections.enumeration(params.keySet());
} }
@Override @Override
@ -228,7 +228,7 @@ public class KerberosPlugin extends AuthenticationPlugin implements HttpClientBu
} }
@Override @Override
public boolean doAuthenticate(ServletRequest req, ServletResponse rsp, public boolean doAuthenticate(HttpServletRequest req, HttpServletResponse rsp,
FilterChain chain) throws Exception { FilterChain chain) throws Exception {
log.debug("Request to authenticate using kerberos: {}", req); log.debug("Request to authenticate using kerberos: {}", req);
kerberosFilter.doFilter(req, rsp, chain); kerberosFilter.doFilter(req, rsp, chain);

View File

@ -17,10 +17,8 @@
package org.apache.solr.security; package org.apache.solr.security;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper; import javax.servlet.http.HttpServletResponse;
import java.io.IOException; import java.io.IOException;
import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodHandles;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
@ -85,16 +83,16 @@ public class PKIAuthenticationPlugin extends AuthenticationPlugin implements Htt
@SuppressForbidden(reason = "Needs currentTimeMillis to compare against time in header") @SuppressForbidden(reason = "Needs currentTimeMillis to compare against time in header")
@Override @Override
public boolean doAuthenticate(ServletRequest request, ServletResponse response, FilterChain filterChain) throws Exception { public boolean doAuthenticate(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws Exception {
String requestURI = ((HttpServletRequest) request).getRequestURI(); String requestURI = request.getRequestURI();
if (requestURI.endsWith(PublicKeyHandler.PATH)) { if (requestURI.endsWith(PublicKeyHandler.PATH)) {
numPassThrough.inc(); numPassThrough.inc();
filterChain.doFilter(request, response); filterChain.doFilter(request, response);
return true; return true;
} }
long receivedTime = System.currentTimeMillis(); long receivedTime = System.currentTimeMillis();
String header = ((HttpServletRequest) request).getHeader(HEADER); String header = request.getHeader(HEADER);
if (header == null) { if (header == null) {
//this must not happen //this must not happen
log.error("No SolrAuth header present"); log.error("No SolrAuth header present");
@ -133,19 +131,10 @@ public class PKIAuthenticationPlugin extends AuthenticationPlugin implements Htt
new BasicUserPrincipal(decipher.userName); new BasicUserPrincipal(decipher.userName);
numAuthenticated.inc(); numAuthenticated.inc();
filterChain.doFilter(getWrapper((HttpServletRequest) request, principal), response); filterChain.doFilter(wrapWithPrincipal(request, principal), response);
return true; return true;
} }
private static HttpServletRequestWrapper getWrapper(final HttpServletRequest request, final Principal principal) {
return new HttpServletRequestWrapper(request) {
@Override
public Principal getUserPrincipal() {
return principal;
}
};
}
public static class PKIHeaderData { public static class PKIHeaderData {
String userName; String userName;
long timestamp; long timestamp;

View File

@ -17,8 +17,6 @@
package org.apache.solr.cloud; package org.apache.solr.cloud;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodHandles;
@ -141,7 +139,7 @@ public class TestAuthenticationFramework extends SolrCloudTestCase {
public void init(Map<String,Object> pluginConfig) {} public void init(Map<String,Object> pluginConfig) {}
@Override @Override
public boolean doAuthenticate(ServletRequest request, ServletResponse response, FilterChain filterChain) public boolean doAuthenticate(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws Exception { throws Exception {
if (expectedUsername == null) { if (expectedUsername == null) {
filterChain.doFilter(request, response); filterChain.doFilter(request, response);

View File

@ -21,7 +21,7 @@ import javax.servlet.ServletException;
import javax.servlet.ServletRequest; import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse; import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper; import javax.servlet.http.HttpServletResponse;
import java.io.IOException; import java.io.IOException;
import java.security.Principal; import java.security.Principal;
import java.util.Map; import java.util.Map;
@ -38,7 +38,7 @@ public class MockAuthenticationPlugin extends AuthenticationPlugin {
} }
@Override @Override
public boolean doAuthenticate(ServletRequest request, ServletResponse response, FilterChain filterChain) throws IOException, ServletException { public boolean doAuthenticate(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws IOException, ServletException {
String user = null; String user = null;
if (predicate != null) { if (predicate != null) {
if (predicate.test(request)) { if (predicate.test(request)) {
@ -47,28 +47,19 @@ public class MockAuthenticationPlugin extends AuthenticationPlugin {
} }
} }
final FilterChain ffc = filterChain;
final AtomicBoolean requestContinues = new AtomicBoolean(false); final AtomicBoolean requestContinues = new AtomicBoolean(false);
forward(user, request, response, new FilterChain() { forward(user, request, response, (req, res) -> {
@Override filterChain.doFilter(req, res);
public void doFilter(ServletRequest req, ServletResponse res) throws IOException, ServletException {
ffc.doFilter(req, res);
requestContinues.set(true); requestContinues.set(true);
}
}); });
return requestContinues.get(); return requestContinues.get();
} }
protected void forward(String user, ServletRequest req, ServletResponse rsp, protected void forward(String user, HttpServletRequest req, ServletResponse rsp,
FilterChain chain) throws IOException, ServletException { FilterChain chain) throws IOException, ServletException {
if(user != null) { if(user != null) {
final Principal p = new BasicUserPrincipal(user); final Principal p = new BasicUserPrincipal(user);
req = new HttpServletRequestWrapper((HttpServletRequest) req) { req = wrapWithPrincipal(req, p);
@Override
public Principal getUserPrincipal() {
return p;
}
};
} }
chain.doFilter(req, rsp); chain.doFilter(req, rsp);
} }