From 80266a20a2d50a7b6fa374b714da1bacf276c27d Mon Sep 17 00:00:00 2001 From: Michael Buckley Date: Wed, 2 Jun 2021 21:20:49 -0400 Subject: [PATCH] Pull request tracing out for sharing --- .../uhn/fhir/rest/server/RestfulServer.java | 17 ++--- .../rest/server/ServletRequestTracing.java | 65 +++++++++++++++++++ .../server/ServletRequestTracingTest.java | 57 ++++++++++++++++ 3 files changed, 127 insertions(+), 12 deletions(-) create mode 100644 hapi-fhir-server/src/main/java/ca/uhn/fhir/rest/server/ServletRequestTracing.java create mode 100644 hapi-fhir-server/src/test/java/ca/uhn/fhir/rest/server/ServletRequestTracingTest.java diff --git a/hapi-fhir-server/src/main/java/ca/uhn/fhir/rest/server/RestfulServer.java b/hapi-fhir-server/src/main/java/ca/uhn/fhir/rest/server/RestfulServer.java index 381e73f66a0..9df9b113306 100644 --- a/hapi-fhir-server/src/main/java/ca/uhn/fhir/rest/server/RestfulServer.java +++ b/hapi-fhir-server/src/main/java/ca/uhn/fhir/rest/server/RestfulServer.java @@ -1277,7 +1277,7 @@ public class RestfulServer extends HttpServlet implements IRestfulServer * Note that the generated request ID is a random 64-bit long integer encoded as @@ -1286,18 +1286,11 @@ public class RestfulServer extends HttpServlet implements IRestfulServer */ protected String getOrCreateRequestId(HttpServletRequest theRequest) { - String requestId = theRequest.getHeader(Constants.HEADER_REQUEST_ID); - if (isNotBlank(requestId)) { - for (char nextChar : requestId.toCharArray()) { - if (!Character.isLetterOrDigit(nextChar)) { - if (nextChar != '.' && nextChar != '-' && nextChar != '_' && nextChar != ' ') { - requestId = null; - break; - } - } - } - } + String requestId = ServletRequestTracing.maybeGetRequestId(theRequest); + // TODO can we delete this and newRequestId() + // and use ServletRequestTracing.getOrGenerateRequestId() instead? + // newRequestId() is protected. Do you think anyone actually overrode it? if (isBlank(requestId)) { int requestIdLength = Constants.REQUEST_ID_LENGTH; requestId = newRequestId(requestIdLength); diff --git a/hapi-fhir-server/src/main/java/ca/uhn/fhir/rest/server/ServletRequestTracing.java b/hapi-fhir-server/src/main/java/ca/uhn/fhir/rest/server/ServletRequestTracing.java new file mode 100644 index 00000000000..0516a676a59 --- /dev/null +++ b/hapi-fhir-server/src/main/java/ca/uhn/fhir/rest/server/ServletRequestTracing.java @@ -0,0 +1,65 @@ +package ca.uhn.fhir.rest.server; + +import ca.uhn.fhir.rest.api.Constants; +import org.apache.commons.lang3.RandomStringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; +import javax.servlet.ServletRequest; +import javax.servlet.http.HttpServletRequest; + +import static org.apache.commons.lang3.StringUtils.isBlank; +import static org.apache.commons.lang3.StringUtils.isNotBlank; + +public class ServletRequestTracing { + private static final Logger ourLog = LoggerFactory.getLogger(ServletRequestTracing.class); + public static final String ATTRIBUTE_REQUEST_ID = ServletRequestTracing.class.getName() + '.' + Constants.HEADER_REQUEST_ID; + + /** + * Assign a tracing id to this request, using + * the X-Request-ID if present and compatible. + * + * If none present, generate a 64 random alpha-numeric string that is not + * cryptographically secure. + * + * @param theServletRequest the request to trace + * @return the tracing id + */ + public static String getOrGenerateRequestId(ServletRequest theServletRequest) { + String requestId = maybeGetRequestId(theServletRequest); + if (isBlank(requestId)) { + requestId = RandomStringUtils.randomAlphanumeric(Constants.REQUEST_ID_LENGTH); + } + + ourLog.debug("Assigned tracing id {}", requestId); + + theServletRequest.setAttribute(ATTRIBUTE_REQUEST_ID, requestId); + + return requestId; + } + + @Nullable + public static String maybeGetRequestId(ServletRequest theServletRequest) { + // have we already seen this request? + String requestId = (String) theServletRequest.getAttribute(ATTRIBUTE_REQUEST_ID); + + if (requestId == null && theServletRequest instanceof HttpServletRequest) { + // Also applies to non-FHIR (e.g. admin-json) requests). + HttpServletRequest request = (HttpServletRequest) theServletRequest; + requestId = request.getHeader(Constants.HEADER_REQUEST_ID); + if (isNotBlank(requestId)) { + for (char nextChar : requestId.toCharArray()) { + if (!Character.isLetterOrDigit(nextChar)) { + if (nextChar != '.' && nextChar != '-' && nextChar != '_' && nextChar != ' ') { + requestId = null; + break; + } + } + } + } + } + return requestId; + } + +} diff --git a/hapi-fhir-server/src/test/java/ca/uhn/fhir/rest/server/ServletRequestTracingTest.java b/hapi-fhir-server/src/test/java/ca/uhn/fhir/rest/server/ServletRequestTracingTest.java new file mode 100644 index 00000000000..535540d3aea --- /dev/null +++ b/hapi-fhir-server/src/test/java/ca/uhn/fhir/rest/server/ServletRequestTracingTest.java @@ -0,0 +1,57 @@ +package ca.uhn.fhir.rest.server; + +import ca.uhn.fhir.rest.api.Constants; +import org.junit.jupiter.api.Test; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.blankString; +import static org.hamcrest.Matchers.not; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class ServletRequestTracingTest { + + MockHttpServletRequest myRequest = new MockHttpServletRequest(); + String myRequestIdResult; + + void run() { + myRequestIdResult = ServletRequestTracing.getOrGenerateRequestId(myRequest); + } + + @Test + public void emptyRequestGetsGeneratedId() { + // no setup + + run(); + + // verify + assertThat("id generated", myRequestIdResult, not(blankString())); + assertEquals(myRequest.getAttribute(ServletRequestTracing.ATTRIBUTE_REQUEST_ID),myRequestIdResult); + } + + @Test + public void requestWithCallerHapiIdUsesThat() { + // setup + myRequest.addHeader(Constants.HEADER_REQUEST_ID, "a_request_id"); + + run(); + + // verify + assertEquals("a_request_id", myRequestIdResult); + } + + @Test + public void duplicateCallsKeepsSameId() { + // no headers + + myRequestIdResult = ServletRequestTracing.getOrGenerateRequestId(myRequest); + + String secondResult = ServletRequestTracing.getOrGenerateRequestId(myRequest); + + // verify + assertThat("id generated", secondResult, not(blankString())); + assertEquals(myRequestIdResult, secondResult); + } + +}