diff --git a/httpclient-cache/src/main/java/org/apache/http/impl/client/cache/ResponseProtocolCompliance.java b/httpclient-cache/src/main/java/org/apache/http/impl/client/cache/ResponseProtocolCompliance.java index f935fcde1..e9fff6e26 100644 --- a/httpclient-cache/src/main/java/org/apache/http/impl/client/cache/ResponseProtocolCompliance.java +++ b/httpclient-cache/src/main/java/org/apache/http/impl/client/cache/ResponseProtocolCompliance.java @@ -56,6 +56,8 @@ import org.apache.http.util.EntityUtils; @Immutable class ResponseProtocolCompliance { + private static final String UNEXPECTED_PARTIAL_CONTENT = "partial content was returned for a request that did not ask for it"; + /** * When we get a response from a down stream server (Origin Server) * we attempt to see if it is HTTP 1.1 Compliant and if not, attempt to @@ -68,8 +70,7 @@ class ResponseProtocolCompliance { public void ensureProtocolCompliance(HttpRequest request, HttpResponse response) throws IOException { if (backendResponseMustNotHaveBody(request, response)) { - HttpEntity body = response.getEntity(); - if (body != null) EntityUtils.consume(body); + consumeBody(response); response.setEntity(null); } @@ -90,6 +91,11 @@ class ResponseProtocolCompliance { warningsWithNonMatchingWarnDatesAreRemoved(response); } + private void consumeBody(HttpResponse response) throws IOException { + HttpEntity body = response.getEntity(); + if (body != null) EntityUtils.consume(body); + } + private void warningsWithNonMatchingWarnDatesAreRemoved( HttpResponse response) { Date responseDate = null; @@ -157,15 +163,13 @@ class ResponseProtocolCompliance { } private void ensurePartialContentIsNotSentToAClientThatDidNotRequestIt(HttpRequest request, - HttpResponse response) throws ClientProtocolException { - if (request.getFirstHeader(HeaderConstants.RANGE) != null) + HttpResponse response) throws IOException { + if (request.getFirstHeader(HeaderConstants.RANGE) != null + || response.getStatusLine().getStatusCode() != HttpStatus.SC_PARTIAL_CONTENT) return; - - if (response.getFirstHeader(HeaderConstants.CONTENT_RANGE) != null) { - throw new ClientProtocolException( - "Content-Range was returned for a request that did not ask for a Content-Range."); - } - + + consumeBody(response); + throw new ClientProtocolException(UNEXPECTED_PARTIAL_CONTENT); } private void ensure200ForOPTIONSRequestWithNoBodyHasContentLengthZero(HttpRequest request, diff --git a/httpclient-cache/src/test/java/org/apache/http/impl/client/cache/TestResponseProtocolCompliance.java b/httpclient-cache/src/test/java/org/apache/http/impl/client/cache/TestResponseProtocolCompliance.java index 6326963bf..932070fb7 100644 --- a/httpclient-cache/src/test/java/org/apache/http/impl/client/cache/TestResponseProtocolCompliance.java +++ b/httpclient-cache/src/test/java/org/apache/http/impl/client/cache/TestResponseProtocolCompliance.java @@ -35,6 +35,8 @@ import org.apache.http.HttpRequest; import org.apache.http.HttpResponse; import org.apache.http.HttpStatus; import org.apache.http.HttpVersion; +import org.apache.http.client.ClientProtocolException; +import org.apache.http.client.methods.HttpGet; import org.apache.http.client.methods.HttpHead; import org.apache.http.entity.InputStreamEntity; import org.apache.http.impl.cookie.DateUtils; @@ -51,32 +53,76 @@ public class TestResponseProtocolCompliance { public void setUp() { impl = new ResponseProtocolCompliance(); } + + private static class Flag { + public boolean set; + } - @Test - public void consumesBodyIfOriginSendsOneInResponseToHEAD() throws Exception { - HttpRequest req = new HttpHead("http://foo.example.com/"); - HttpResponse resp = new BasicHttpResponse(HttpVersion.HTTP_1_1, HttpStatus.SC_OK, "OK"); + private void setMinimalResponseHeaders(HttpResponse resp) { resp.setHeader("Date", DateUtils.formatDate(new Date())); resp.setHeader("Server", "MyServer/1.0"); - - int nbytes = 128; - resp.setHeader("Content-Length","" + nbytes); + } + + private ByteArrayInputStream makeTrackableBody(int nbytes, final Flag closed) { byte[] buf = HttpTestUtils.getRandomBytes(nbytes); - final Flag closed = new Flag(); ByteArrayInputStream bais = new ByteArrayInputStream(buf) { @Override public void close() { closed.set = true; } }; + return bais; + } + + private HttpResponse makePartialResponse(int nbytes) { + HttpResponse resp = new BasicHttpResponse(HttpVersion.HTTP_1_1, HttpStatus.SC_PARTIAL_CONTENT, "Partial Content"); + setMinimalResponseHeaders(resp); + resp.setHeader("Content-Length","" + nbytes); + resp.setHeader("Content-Range","0-127/256"); + return resp; + } + + @Test + public void consumesBodyIfOriginSendsOneInResponseToHEAD() throws Exception { + HttpRequest req = new HttpHead("http://foo.example.com/"); + int nbytes = 128; + HttpResponse resp = new BasicHttpResponse(HttpVersion.HTTP_1_1, HttpStatus.SC_OK, "OK"); + setMinimalResponseHeaders(resp); + resp.setHeader("Content-Length","" + nbytes); + + final Flag closed = new Flag(); + ByteArrayInputStream bais = makeTrackableBody(nbytes, closed); resp.setEntity(new InputStreamEntity(bais, -1)); impl.ensureProtocolCompliance(req, resp); assertNull(resp.getEntity()); assertTrue(closed.set || bais.read() == -1); } - - private static class Flag { - public boolean set; + + @Test(expected=ClientProtocolException.class) + public void throwsExceptionIfOriginReturnsPartialResponseWhenNotRequested() throws Exception { + HttpRequest req = new HttpGet("http://foo.example.com/"); + int nbytes = 128; + HttpResponse resp = makePartialResponse(nbytes); + resp.setEntity(HttpTestUtils.makeBody(nbytes)); + + impl.ensureProtocolCompliance(req, resp); + } + + @Test + public void consumesPartialContentFromOriginEvenIfNotRequested() throws Exception { + HttpRequest req = new HttpGet("http://foo.example.com/"); + int nbytes = 128; + HttpResponse resp = makePartialResponse(nbytes); + + final Flag closed = new Flag(); + ByteArrayInputStream bais = makeTrackableBody(nbytes, closed); + resp.setEntity(new InputStreamEntity(bais, -1)); + + try { + impl.ensureProtocolCompliance(req, resp); + } catch (ClientProtocolException expected) { + } + assertTrue(closed.set || bais.read() == -1); } }