Resolve 6173 - Log unhandled Exceptions in RestfulServer (#6176)

* 6173 - Log unhandled Exceptions in RestfulServer.

* Use placeholder for failed streams.

* Starting test for server handling.

* Got test working.

* Fixed use of synchronized keyword.

* Applied mvn spotless.

---------

Co-authored-by: Michael Buckley <michaelabuckley@gmail.com>
This commit is contained in:
Kevin Dougan 2024-08-07 10:01:58 -04:00 committed by GitHub
parent 801b8b494b
commit 44c4b87a84
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 90 additions and 7 deletions

View File

@ -20,6 +20,7 @@
package ca.uhn.fhir.rest.api.server; package ca.uhn.fhir.rest.api.server;
import ca.uhn.fhir.context.FhirContext; import ca.uhn.fhir.context.FhirContext;
import ca.uhn.fhir.i18n.Msg;
import ca.uhn.fhir.interceptor.api.IInterceptorBroadcaster; import ca.uhn.fhir.interceptor.api.IInterceptorBroadcaster;
import ca.uhn.fhir.rest.api.Constants; import ca.uhn.fhir.rest.api.Constants;
import ca.uhn.fhir.rest.api.RequestTypeEnum; import ca.uhn.fhir.rest.api.RequestTypeEnum;
@ -30,6 +31,7 @@ import ca.uhn.fhir.util.StopWatch;
import ca.uhn.fhir.util.UrlUtil; import ca.uhn.fhir.util.UrlUtil;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.Validate; import org.apache.commons.lang3.Validate;
import org.hl7.fhir.instance.model.api.IBaseResource; import org.hl7.fhir.instance.model.api.IBaseResource;
import org.hl7.fhir.instance.model.api.IIdType; import org.hl7.fhir.instance.model.api.IIdType;
@ -39,6 +41,7 @@ import java.io.InputStream;
import java.io.Reader; import java.io.Reader;
import java.io.UnsupportedEncodingException; import java.io.UnsupportedEncodingException;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
@ -50,6 +53,8 @@ import static org.apache.commons.lang3.StringUtils.isBlank;
public abstract class RequestDetails { public abstract class RequestDetails {
public static final byte[] BAD_STREAM_PLACEHOLDER =
(Msg.code(2543) + "PLACEHOLDER WHEN READING FROM BAD STREAM").getBytes(StandardCharsets.UTF_8);
private final StopWatch myRequestStopwatch; private final StopWatch myRequestStopwatch;
private IInterceptorBroadcaster myInterceptorBroadcaster; private IInterceptorBroadcaster myInterceptorBroadcaster;
private String myTenantId; private String myTenantId;
@ -523,9 +528,22 @@ public abstract class RequestDetails {
mySubRequest = theSubRequest; mySubRequest = theSubRequest;
} }
public final byte[] loadRequestContents() { public final synchronized byte[] loadRequestContents() {
if (myRequestContents == null) { if (myRequestContents == null) {
myRequestContents = getByteStreamRequestContents(); // Initialize the byte array to a non-null value to avoid repeated calls to getByteStreamRequestContents()
// which can occur when getByteStreamRequestContents() throws an Exception
myRequestContents = ArrayUtils.EMPTY_BYTE_ARRAY;
try {
myRequestContents = getByteStreamRequestContents();
} finally {
if (myRequestContents == null) {
// if reading the stream throws an exception, then our contents are still null, but the stream is
// dead.
// Set a placeholder value so nobody tries to read again.
myRequestContents = BAD_STREAM_PLACEHOLDER;
}
}
assert myRequestContents != null : "We must not re-read the stream.";
} }
return getRequestContentsIfLoaded(); return getRequestContentsIfLoaded();
} }

View File

@ -1038,6 +1038,9 @@ public class RestfulServer extends HttpServlet implements IRestfulServer<Servlet
theRequest.setAttribute(SERVLET_CONTEXT_ATTRIBUTE, getServletContext()); theRequest.setAttribute(SERVLET_CONTEXT_ATTRIBUTE, getServletContext());
// keep track of any unhandled exceptions in case the exception handler throws another exception
Throwable unhandledException = null;
try { try {
/* *********************************** /* ***********************************
@ -1205,6 +1208,8 @@ public class RestfulServer extends HttpServlet implements IRestfulServer<Servlet
} catch (NotModifiedException | AuthenticationException e) { } catch (NotModifiedException | AuthenticationException e) {
unhandledException = e;
HookParams handleExceptionParams = new HookParams(); HookParams handleExceptionParams = new HookParams();
handleExceptionParams.add(RequestDetails.class, requestDetails); handleExceptionParams.add(RequestDetails.class, requestDetails);
handleExceptionParams.add(ServletRequestDetails.class, requestDetails); handleExceptionParams.add(ServletRequestDetails.class, requestDetails);
@ -1216,9 +1221,12 @@ public class RestfulServer extends HttpServlet implements IRestfulServer<Servlet
} }
writeExceptionToResponse(theResponse, e); writeExceptionToResponse(theResponse, e);
unhandledException = null;
} catch (Throwable e) { } catch (Throwable e) {
unhandledException = e;
/* /*
* We have caught an exception during request processing. This might be because a handling method threw * We have caught an exception during request processing. This might be because a handling method threw
* something they wanted to throw (e.g. UnprocessableEntityException because the request * something they wanted to throw (e.g. UnprocessableEntityException because the request
@ -1285,9 +1293,18 @@ public class RestfulServer extends HttpServlet implements IRestfulServer<Servlet
* If nobody handles it, default behaviour is to stream back the OperationOutcome to the client. * If nobody handles it, default behaviour is to stream back the OperationOutcome to the client.
*/ */
DEFAULT_EXCEPTION_HANDLER.handleException(requestDetails, exception, theRequest, theResponse); DEFAULT_EXCEPTION_HANDLER.handleException(requestDetails, exception, theRequest, theResponse);
unhandledException = null;
} finally { } finally {
if (unhandledException != null) {
ourLog.error(
Msg.code(2544) + "Exception handling threw an exception. Initial exception was: {}",
unhandledException.getMessage(),
unhandledException);
unhandledException = null;
}
HookParams params = new HookParams(); HookParams params = new HookParams();
params.add(RequestDetails.class, requestDetails); params.add(RequestDetails.class, requestDetails);
params.addIfMatchesType(ServletRequestDetails.class, requestDetails); params.addIfMatchesType(ServletRequestDetails.class, requestDetails);

View File

@ -1,16 +1,22 @@
package ca.uhn.fhir.rest.server; package ca.uhn.fhir.rest.server;
import ca.uhn.fhir.context.FhirContext; import ca.uhn.fhir.context.FhirContext;
import ca.uhn.fhir.interceptor.api.Pointcut;
import ca.uhn.fhir.rest.annotation.Create; import ca.uhn.fhir.rest.annotation.Create;
import ca.uhn.fhir.rest.annotation.ResourceParam; import ca.uhn.fhir.rest.annotation.ResourceParam;
import ca.uhn.fhir.rest.annotation.Search;
import ca.uhn.fhir.rest.api.Constants; import ca.uhn.fhir.rest.api.Constants;
import ca.uhn.fhir.rest.api.MethodOutcome; import ca.uhn.fhir.rest.api.MethodOutcome;
import ca.uhn.fhir.rest.api.RequestTypeEnum; import ca.uhn.fhir.rest.api.RequestTypeEnum;
import ca.uhn.fhir.test.utilities.HttpClientExtension; import ca.uhn.fhir.test.utilities.HttpClientExtension;
import ca.uhn.fhir.test.utilities.server.RestfulServerExtension; import ca.uhn.fhir.test.utilities.server.RestfulServerExtension;
import ca.uhn.test.util.LogbackTestExtension;
import ca.uhn.test.util.LogbackTestExtensionAssert;
import com.helger.commons.collection.iterate.EmptyEnumeration; import com.helger.commons.collection.iterate.EmptyEnumeration;
import com.helger.commons.io.stream.StringInputStream;
import jakarta.annotation.Nonnull; import jakarta.annotation.Nonnull;
import jakarta.servlet.ReadListener; import jakarta.servlet.ReadListener;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletInputStream; import jakarta.servlet.ServletInputStream;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
@ -36,10 +42,12 @@ import java.io.PrintWriter;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Set; import java.util.Set;
import static org.apache.commons.collections.CollectionUtils.isEmpty; import static org.apache.commons.collections.CollectionUtils.isEmpty;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -52,6 +60,7 @@ public class ServerConcurrencyTest {
@RegisterExtension @RegisterExtension
private static final RestfulServerExtension ourServer = new RestfulServerExtension(ourCtx) private static final RestfulServerExtension ourServer = new RestfulServerExtension(ourCtx)
.registerProvider(new MyPatientProvider()); .registerProvider(new MyPatientProvider());
public static final String SEARCH_TIMEOUT_ERROR = "SEARCH_TIMEOUT_ERROR: Search timed out";
@RegisterExtension @RegisterExtension
private final HttpClientExtension myHttpClient = new HttpClientExtension(); private final HttpClientExtension myHttpClient = new HttpClientExtension();
@ -62,10 +71,12 @@ public class ServerConcurrencyTest {
@Mock @Mock
private PrintWriter myWriter; private PrintWriter myWriter;
private HashMap<String, String> myHeaders; private HashMap<String, String> myHeaders;
@RegisterExtension
LogbackTestExtension myLogbackTestExtension = new LogbackTestExtension();
@Test @Test
public void testExceptionClosingInputStream() throws IOException { public void testExceptionClosingInputStream() throws IOException {
initRequestMocks(); initRequestMocks("/Patient");
DelegatingServletInputStream inputStream = createMockPatientBodyServletInputStream(); DelegatingServletInputStream inputStream = createMockPatientBodyServletInputStream();
inputStream.setExceptionOnClose(true); inputStream.setExceptionOnClose(true);
when(myRequest.getInputStream()).thenReturn(inputStream); when(myRequest.getInputStream()).thenReturn(inputStream);
@ -78,7 +89,7 @@ public class ServerConcurrencyTest {
@Test @Test
public void testExceptionClosingOutputStream() throws IOException { public void testExceptionClosingOutputStream() throws IOException {
initRequestMocks(); initRequestMocks("/Patient");
when(myRequest.getInputStream()).thenReturn(createMockPatientBodyServletInputStream()); when(myRequest.getInputStream()).thenReturn(createMockPatientBodyServletInputStream());
when(myResponse.getWriter()).thenReturn(myWriter); when(myResponse.getWriter()).thenReturn(myWriter);
@ -90,12 +101,44 @@ public class ServerConcurrencyTest {
); );
} }
private void initRequestMocks() { /**
* Exception thrown during SERVER_OUTGOING_FAILURE_OPERATIONOUTCOME
*/
@Test
void testExceptionThrownDuringExceptionHandler_bothExceptionsLogged() throws ServletException, IOException {
// given
initRequestMocks("/Patient?active=true");
ourServer.getInterceptorService().registerAnonymousInterceptor(Pointcut.SERVER_OUTGOING_FAILURE_OPERATIONOUTCOME, (pointcut,params)->{
throw new RuntimeException("MARKER_2: Exception during exception processing");
});
// when
try {
ourServer.getRestfulServer().handleRequest(RequestTypeEnum.GET, myRequest, myResponse);
} catch (Throwable e) {
// eat any ex that escapes. We need to not depend on default logging.
}
// then
// both exceptions should be logged.
LogbackTestExtensionAssert.assertThat(myLogbackTestExtension).hasErrorMessage("HAPI-2544");
LogbackTestExtensionAssert.assertThat(myLogbackTestExtension).hasErrorMessage(SEARCH_TIMEOUT_ERROR);
}
private void initRequestMocks(String theURL) {
myHeaders = new HashMap<>(); myHeaders = new HashMap<>();
myHeaders.put(Constants.HEADER_CONTENT_TYPE, Constants.CT_FHIR_JSON_NEW); myHeaders.put(Constants.HEADER_CONTENT_TYPE, Constants.CT_FHIR_JSON_NEW);
when(myRequest.getRequestURI()).thenReturn("/Patient"); String relativeUri;
when(myRequest.getRequestURL()).thenReturn(new StringBuffer(ourServer.getBaseUrl() + "/Patient")); int endOfUri = theURL.indexOf("?");
if (endOfUri >= 0) {
relativeUri = theURL.substring(0,endOfUri);
} else {
relativeUri = theURL;
}
when(myRequest.getRequestURI()).thenReturn(relativeUri);
when(myRequest.getRequestURL()).thenReturn(new StringBuffer(ourServer.getBaseUrl() + theURL));
when(myRequest.getHeader(any())).thenAnswer(t -> { when(myRequest.getHeader(any())).thenAnswer(t -> {
String header = t.getArgument(0, String.class); String header = t.getArgument(0, String.class);
String value = myHeaders.get(header); String value = myHeaders.get(header);
@ -190,6 +233,11 @@ public class ServerConcurrencyTest {
.setOperationOutcome(oo); .setOperationOutcome(oo);
} }
@Search
public List<Patient> search() throws InterruptedException {
throw new InterruptedException(SEARCH_TIMEOUT_ERROR);
}
@Override @Override
public Class<Patient> getResourceType() { public Class<Patient> getResourceType() {