diff --git a/httpclient-cache/src/main/java/org/apache/http/impl/client/cache/ResponseProtocolCompliance.java b/httpclient-cache/src/main/java/org/apache/http/impl/client/cache/ResponseProtocolCompliance.java index e9fff6e26..3590aa8be 100644 --- a/httpclient-cache/src/main/java/org/apache/http/impl/client/cache/ResponseProtocolCompliance.java +++ b/httpclient-cache/src/main/java/org/apache/http/impl/client/cache/ResponseProtocolCompliance.java @@ -56,6 +56,8 @@ import org.apache.http.util.EntityUtils; @Immutable class ResponseProtocolCompliance { + private static final String UNEXPECTED_100_CONTINUE = "The incoming request did not contain a " + + "100-continue header, but the response was a Status 100, continue."; private static final String UNEXPECTED_PARTIAL_CONTENT = "partial content was returned for a request that did not ask for it"; /** @@ -207,26 +209,18 @@ class ResponseProtocolCompliance { } private void requestDidNotExpect100ContinueButResponseIsOne(HttpRequest request, - HttpResponse response) throws ClientProtocolException { + HttpResponse response) throws IOException { if (response.getStatusLine().getStatusCode() != HttpStatus.SC_CONTINUE) { return; } - - if (!requestWasWrapped(request)) { - return; - } - - ProtocolVersion originalProtocol = getOriginalRequestProtocol((RequestWrapper) request); - - if (originalProtocol.compareToVersion(HttpVersion.HTTP_1_1) >= 0) { - return; - } - - if (originalRequestDidNotExpectContinue((RequestWrapper) request)) { - throw new ClientProtocolException("The incoming request did not contain a " - + "100-continue header, but the response was a Status 100, continue."); - + + HttpRequest originalRequest = requestWasWrapped(request) ? + ((RequestWrapper)request).getOriginal() : request; + if (originalRequest instanceof HttpEntityEnclosingRequest) { + if (((HttpEntityEnclosingRequest)originalRequest).expectContinue()) return; } + consumeBody(response); + throw new ClientProtocolException(UNEXPECTED_100_CONTINUE); } private void transferEncodingIsNotReturnedTo1_0Client(HttpRequest request, HttpResponse response) { @@ -248,18 +242,6 @@ class ResponseProtocolCompliance { response.removeHeaders(HTTP.TRANSFER_ENCODING); } - private boolean originalRequestDidNotExpectContinue(RequestWrapper request) { - - try { - HttpEntityEnclosingRequest original = (HttpEntityEnclosingRequest) request - .getOriginal(); - - return !original.expectContinue(); - } catch (ClassCastException ex) { - return false; - } - } - private ProtocolVersion getOriginalRequestProtocol(RequestWrapper request) { return request.getOriginal().getProtocolVersion(); } diff --git a/httpclient-cache/src/test/java/org/apache/http/impl/client/cache/TestResponseProtocolCompliance.java b/httpclient-cache/src/test/java/org/apache/http/impl/client/cache/TestResponseProtocolCompliance.java index 932070fb7..72bc0b4ce 100644 --- a/httpclient-cache/src/test/java/org/apache/http/impl/client/cache/TestResponseProtocolCompliance.java +++ b/httpclient-cache/src/test/java/org/apache/http/impl/client/cache/TestResponseProtocolCompliance.java @@ -31,6 +31,8 @@ import static org.junit.Assert.*; import java.io.ByteArrayInputStream; import java.util.Date; +import org.apache.http.HttpEntity; +import org.apache.http.HttpEntityEnclosingRequest; import org.apache.http.HttpRequest; import org.apache.http.HttpResponse; import org.apache.http.HttpStatus; @@ -38,13 +40,14 @@ import org.apache.http.HttpVersion; import org.apache.http.client.ClientProtocolException; import org.apache.http.client.methods.HttpGet; import org.apache.http.client.methods.HttpHead; +import org.apache.http.entity.ByteArrayEntity; import org.apache.http.entity.InputStreamEntity; import org.apache.http.impl.cookie.DateUtils; +import org.apache.http.message.BasicHttpEntityEnclosingRequest; import org.apache.http.message.BasicHttpResponse; import org.junit.Before; import org.junit.Test; - public class TestResponseProtocolCompliance { private ResponseProtocolCompliance impl; @@ -125,4 +128,25 @@ public class TestResponseProtocolCompliance { } assertTrue(closed.set || bais.read() == -1); } + + @Test + public void consumesBodyOf100ContinueResponseIfItArrives() throws Exception { + HttpEntityEnclosingRequest req = new BasicHttpEntityEnclosingRequest("POST", "/", HttpVersion.HTTP_1_1); + int nbytes = 128; + req.setHeader("Content-Length","" + nbytes); + req.setHeader("Content-Type", "application/octet-stream"); + HttpEntity postBody = new ByteArrayEntity(HttpTestUtils.getRandomBytes(nbytes)); + req.setEntity(postBody); + + HttpResponse resp = new BasicHttpResponse(HttpVersion.HTTP_1_1, HttpStatus.SC_CONTINUE, "Continue"); + final Flag closed = new Flag(); + ByteArrayInputStream bais = makeTrackableBody(nbytes, closed); + resp.setEntity(new InputStreamEntity(bais, -1)); + + try { + impl.ensureProtocolCompliance(req, resp); + } catch (ClientProtocolException expected) { + } + assertTrue(closed.set || bais.read() == -1); + } }