From ee87f8036f99fdfd63f8ef145f4478907682db29 Mon Sep 17 00:00:00 2001 From: jwilson Date: Sun, 29 Nov 2015 22:58:45 -0500 Subject: [PATCH] Change async cancel to cancel the raw socket only. Previously we could close an SSL socket which does synchronous I/O. This made it unreasonable to cancel a call on a UI thread. Closes: https://github.com/square/okhttp/issues/1592 --- .../java/com/squareup/okhttp/CallTest.java | 142 ++++++++++++------ .../okhttp/internal/http/Http2xStream.java | 2 +- .../okhttp/internal/io/RealConnection.java | 26 +++- 3 files changed, 119 insertions(+), 51 deletions(-) diff --git a/okhttp-tests/src/test/java/com/squareup/okhttp/CallTest.java b/okhttp-tests/src/test/java/com/squareup/okhttp/CallTest.java index 6d27939ab..f2f97d0cb 100644 --- a/okhttp-tests/src/test/java/com/squareup/okhttp/CallTest.java +++ b/okhttp-tests/src/test/java/com/squareup/okhttp/CallTest.java @@ -172,13 +172,18 @@ public final class CallTest { .assertNotSuccessful(); } - @Test public void get_SPDY_3() throws Exception { - enableProtocol(Protocol.SPDY_3); + @Test public void get_HTTP_2() throws Exception { + enableProtocol(Protocol.HTTP_2); get(); } - @Test public void get_HTTP_2() throws Exception { - enableProtocol(Protocol.HTTP_2); + @Test public void get_HTTPS() throws Exception { + enableTls(); + get(); + } + + @Test public void get_SPDY_3() throws Exception { + enableProtocol(Protocol.SPDY_3); get(); } @@ -241,8 +246,8 @@ public final class CallTest { assertNull(recordedRequest.getHeader("Content-Length")); } - @Test public void head_SPDY_3() throws Exception { - enableProtocol(Protocol.SPDY_3); + @Test public void head_HTTPS() throws Exception { + enableTls(); head(); } @@ -251,6 +256,11 @@ public final class CallTest { head(); } + @Test public void head_SPDY_3() throws Exception { + enableProtocol(Protocol.SPDY_3); + head(); + } + @Test public void post() throws Exception { server.enqueue(new MockResponse().setBody("abc")); @@ -270,8 +280,8 @@ public final class CallTest { assertEquals("text/plain; charset=utf-8", recordedRequest.getHeader("Content-Type")); } - @Test public void post_SPDY_3() throws Exception { - enableProtocol(Protocol.SPDY_3); + @Test public void post_HTTPS() throws Exception { + enableTls(); post(); } @@ -280,6 +290,11 @@ public final class CallTest { post(); } + @Test public void post_SPDY_3() throws Exception { + enableProtocol(Protocol.SPDY_3); + post(); + } + @Test public void postZeroLength() throws Exception { server.enqueue(new MockResponse().setBody("abc")); @@ -299,8 +314,8 @@ public final class CallTest { assertEquals(null, recordedRequest.getHeader("Content-Type")); } - @Test public void postZeroLength_SPDY_3() throws Exception { - enableProtocol(Protocol.SPDY_3); + @Test public void postZerolength_HTTPS() throws Exception { + enableTls(); postZeroLength(); } @@ -309,12 +324,17 @@ public final class CallTest { postZeroLength(); } + @Test public void postZeroLength_SPDY_3() throws Exception { + enableProtocol(Protocol.SPDY_3); + postZeroLength(); + } + @Test public void postBodyRetransmittedAfterAuthorizationFail() throws Exception { postBodyRetransmittedAfterAuthorizationFail("abc"); } - @Test public void postBodyRetransmittedAfterAuthorizationFail_SPDY_3() throws Exception { - enableProtocol(Protocol.SPDY_3); + @Test public void postBodyRetransmittedAfterAuthorizationFail_HTTPS() throws Exception { + enableTls(); postBodyRetransmittedAfterAuthorizationFail("abc"); } @@ -323,13 +343,18 @@ public final class CallTest { postBodyRetransmittedAfterAuthorizationFail("abc"); } + @Test public void postBodyRetransmittedAfterAuthorizationFail_SPDY_3() throws Exception { + enableProtocol(Protocol.SPDY_3); + postBodyRetransmittedAfterAuthorizationFail("abc"); + } + /** Don't explode when resending an empty post. https://github.com/square/okhttp/issues/1131 */ @Test public void postEmptyBodyRetransmittedAfterAuthorizationFail() throws Exception { postBodyRetransmittedAfterAuthorizationFail(""); } - @Test public void postEmptyBodyRetransmittedAfterAuthorizationFail_SPDY_3() throws Exception { - enableProtocol(Protocol.SPDY_3); + @Test public void postEmptyBodyRetransmittedAfterAuthorizationFail_HTTPS() throws Exception { + enableTls(); postBodyRetransmittedAfterAuthorizationFail(""); } @@ -338,6 +363,11 @@ public final class CallTest { postBodyRetransmittedAfterAuthorizationFail(""); } + @Test public void postEmptyBodyRetransmittedAfterAuthorizationFail_SPDY_3() throws Exception { + enableProtocol(Protocol.SPDY_3); + postBodyRetransmittedAfterAuthorizationFail(""); + } + private void postBodyRetransmittedAfterAuthorizationFail(String body) throws Exception { server.enqueue(new MockResponse().setResponseCode(401)); server.enqueue(new MockResponse()); @@ -414,8 +444,8 @@ public final class CallTest { assertEquals(null, recordedRequest.getHeader("Content-Type")); } - @Test public void delete_SPDY_3() throws Exception { - enableProtocol(Protocol.SPDY_3); + @Test public void delete_HTTPS() throws Exception { + enableTls(); delete(); } @@ -424,6 +454,11 @@ public final class CallTest { delete(); } + @Test public void delete_SPDY_3() throws Exception { + enableProtocol(Protocol.SPDY_3); + delete(); + } + @Test public void deleteWithRequestBody() throws Exception { server.enqueue(new MockResponse().setBody("abc")); @@ -460,8 +495,8 @@ public final class CallTest { assertEquals("text/plain; charset=utf-8", recordedRequest.getHeader("Content-Type")); } - @Test public void put_SPDY_3() throws Exception { - enableProtocol(Protocol.SPDY_3); + @Test public void put_HTTPS() throws Exception { + enableTls(); put(); } @@ -470,6 +505,11 @@ public final class CallTest { put(); } + @Test public void put_SPDY_3() throws Exception { + enableProtocol(Protocol.SPDY_3); + put(); + } + @Test public void patch() throws Exception { server.enqueue(new MockResponse().setBody("abc")); @@ -489,13 +529,18 @@ public final class CallTest { assertEquals("text/plain; charset=utf-8", recordedRequest.getHeader("Content-Type")); } - @Test public void patch_SPDY_3() throws Exception { - enableProtocol(Protocol.SPDY_3); + @Test public void patch_HTTP_2() throws Exception { + enableProtocol(Protocol.HTTP_2); patch(); } - @Test public void patch_HTTP_2() throws Exception { - enableProtocol(Protocol.HTTP_2); + @Test public void patch_HTTPS() throws Exception { + enableTls(); + patch(); + } + + @Test public void patch_SPDY_3() throws Exception { + enableProtocol(Protocol.SPDY_3); patch(); } @@ -807,27 +852,21 @@ public final class CallTest { } @Test public void tls() throws Exception { - server.useHttps(sslContext.getSocketFactory(), false); + enableTls(); server.enqueue(new MockResponse() .setBody("abc") .addHeader("Content-Type: text/plain")); - client.setSslSocketFactory(sslContext.getSocketFactory()); - client.setHostnameVerifier(new RecordingHostnameVerifier()); - executeSynchronously(new Request.Builder().url(server.url("/")).build()) .assertHandshake(); } @Test public void tls_Async() throws Exception { - server.useHttps(sslContext.getSocketFactory(), false); + enableTls(); server.enqueue(new MockResponse() .setBody("abc") .addHeader("Content-Type: text/plain")); - client.setSslSocketFactory(sslContext.getSocketFactory()); - client.setHostnameVerifier(new RecordingHostnameVerifier()); - Request request = new Request.Builder() .url(server.url("/")) .build(); @@ -966,12 +1005,10 @@ public final class CallTest { } @Test public void setFollowSslRedirectsFalse() throws Exception { - server.useHttps(sslContext.getSocketFactory(), false); + enableTls(); server.enqueue(new MockResponse().setResponseCode(301).addHeader("Location: http://square.com")); client.setFollowSslRedirects(false); - client.setSslSocketFactory(sslContext.getSocketFactory()); - client.setHostnameVerifier(new RecordingHostnameVerifier()); Request request = new Request.Builder().url(server.url("/")).build(); Response response = client.newCall(request).execute(); @@ -979,13 +1016,10 @@ public final class CallTest { } @Test public void matchingPinnedCertificate() throws Exception { - server.useHttps(sslContext.getSocketFactory(), false); + enableTls(); server.enqueue(new MockResponse()); server.enqueue(new MockResponse()); - client.setSslSocketFactory(sslContext.getSocketFactory()); - client.setHostnameVerifier(new RecordingHostnameVerifier()); - // Make a first request without certificate pinning. Use it to collect certificates to pin. Request request1 = new Request.Builder().url(server.url("/")).build(); Response response1 = client.newCall(request1).execute(); @@ -1002,12 +1036,9 @@ public final class CallTest { } @Test public void unmatchingPinnedCertificate() throws Exception { - server.useHttps(sslContext.getSocketFactory(), false); + enableTls(); server.enqueue(new MockResponse()); - client.setSslSocketFactory(sslContext.getSocketFactory()); - client.setHostnameVerifier(new RecordingHostnameVerifier()); - // Pin publicobject.com's cert. client.setCertificatePinner(new CertificatePinner.Builder() .add(server.getHostName(), "sha1/DmxUShsZuNiqPQsX2Oi9uv2sCnw=") @@ -1605,6 +1636,11 @@ public final class CallTest { } } + @Test public void cancelInFlightBeforeResponseReadThrowsIOE_HTTPS() throws Exception { + enableTls(); + cancelInFlightBeforeResponseReadThrowsIOE(); + } + @Test public void cancelInFlightBeforeResponseReadThrowsIOE_HTTP_2() throws Exception { enableProtocol(Protocol.HTTP_2); cancelInFlightBeforeResponseReadThrowsIOE(); @@ -1642,6 +1678,11 @@ public final class CallTest { callback.await(requestB.httpUrl()).assertFailure("Canceled"); } + @Test public void canceledBeforeIOSignalsOnFailure_HTTPS() throws Exception { + enableTls(); + canceledBeforeIOSignalsOnFailure(); + } + @Test public void canceledBeforeIOSignalsOnFailure_HTTP_2() throws Exception { enableProtocol(Protocol.HTTP_2); canceledBeforeIOSignalsOnFailure(); @@ -1669,6 +1710,11 @@ public final class CallTest { "Socket closed"); } + @Test public void canceledBeforeResponseReadSignalsOnFailure_HTTPS() throws Exception { + enableTls(); + canceledBeforeResponseReadSignalsOnFailure(); + } + @Test public void canceledBeforeResponseReadSignalsOnFailure_HTTP_2() throws Exception { enableProtocol(Protocol.HTTP_2); canceledBeforeResponseReadSignalsOnFailure(); @@ -1716,6 +1762,12 @@ public final class CallTest { assertFalse(failureRef.get()); } + @Test public void canceledAfterResponseIsDeliveredBreaksStreamButSignalsOnce_HTTPS() + throws Exception { + enableTls(); + canceledAfterResponseIsDeliveredBreaksStreamButSignalsOnce(); + } + @Test public void canceledAfterResponseIsDeliveredBreaksStreamButSignalsOnce_HTTP_2() throws Exception { enableProtocol(Protocol.HTTP_2); @@ -2087,11 +2139,15 @@ public final class CallTest { * -Xbootclasspath/p:/tmp/alpn-boot-8.0.0.v20140317} */ private void enableProtocol(Protocol protocol) { + enableTls(); + client.setProtocols(Arrays.asList(protocol, Protocol.HTTP_1_1)); + server.setProtocols(client.getProtocols()); + } + + private void enableTls() { client.setSslSocketFactory(sslContext.getSocketFactory()); client.setHostnameVerifier(new RecordingHostnameVerifier()); - client.setProtocols(Arrays.asList(protocol, Protocol.HTTP_1_1)); server.useHttps(sslContext.getSocketFactory(), false); - server.setProtocols(client.getProtocols()); } private Buffer gzip(String data) throws IOException { diff --git a/okhttp/src/main/java/com/squareup/okhttp/internal/http/Http2xStream.java b/okhttp/src/main/java/com/squareup/okhttp/internal/http/Http2xStream.java index 040e258b3..6b8b68f5c 100644 --- a/okhttp/src/main/java/com/squareup/okhttp/internal/http/Http2xStream.java +++ b/okhttp/src/main/java/com/squareup/okhttp/internal/http/Http2xStream.java @@ -165,7 +165,7 @@ public final class Http2xStream implements HttpStream { result.add(new Header(TARGET_HOST, Util.hostHeader(request.httpUrl()))); result.add(new Header(TARGET_SCHEME, request.httpUrl().scheme())); - Set names = new LinkedHashSet(); + Set names = new LinkedHashSet<>(); for (int i = 0, size = headers.size(); i < size; i++) { // header names must be lowercase. ByteString name = ByteString.encodeUtf8(headers.name(i).toLowerCase(Locale.US)); diff --git a/okhttp/src/main/java/com/squareup/okhttp/internal/io/RealConnection.java b/okhttp/src/main/java/com/squareup/okhttp/internal/io/RealConnection.java index f2e443a1e..77bfbee38 100644 --- a/okhttp/src/main/java/com/squareup/okhttp/internal/io/RealConnection.java +++ b/okhttp/src/main/java/com/squareup/okhttp/internal/io/RealConnection.java @@ -59,6 +59,14 @@ import static java.util.concurrent.TimeUnit.MILLISECONDS; public final class RealConnection implements Connection { private final Route route; + + /** The low-level TCP socket. */ + private Socket rawSocket; + + /** + * The application layer socket. Either an {@link SSLSocket} layered over {@link #rawSocket}, or + * {@link #rawSocket} itself if this connection does not use SSL. + */ private Socket socket; private Handshake handshake; private Protocol protocol; @@ -92,13 +100,15 @@ public final class RealConnection implements Connection { while (protocol == null) { try { - socket = proxy.type() == Proxy.Type.DIRECT || proxy.type() == Proxy.Type.HTTP + rawSocket = proxy.type() == Proxy.Type.DIRECT || proxy.type() == Proxy.Type.HTTP ? address.getSocketFactory().createSocket() : new Socket(proxy); connectSocket(connectTimeout, readTimeout, writeTimeout, connectionSpecSelector); } catch (IOException e) { Util.closeQuietly(socket); + Util.closeQuietly(rawSocket); socket = null; + rawSocket = null; source = null; sink = null; handshake = null; @@ -121,19 +131,20 @@ public final class RealConnection implements Connection { /** Does all the work necessary to build a full HTTP or HTTPS connection on a raw socket. */ private void connectSocket(int connectTimeout, int readTimeout, int writeTimeout, ConnectionSpecSelector connectionSpecSelector) throws IOException { - socket.setSoTimeout(readTimeout); + rawSocket.setSoTimeout(readTimeout); try { - Platform.get().connectSocket(socket, route.getSocketAddress(), connectTimeout); + Platform.get().connectSocket(rawSocket, route.getSocketAddress(), connectTimeout); } catch (ConnectException e) { throw new ConnectException("Failed to connect to " + route.getSocketAddress()); } - source = Okio.buffer(Okio.source(socket)); - sink = Okio.buffer(Okio.sink(socket)); + source = Okio.buffer(Okio.source(rawSocket)); + sink = Okio.buffer(Okio.sink(rawSocket)); if (route.getAddress().getSslSocketFactory() != null) { connectTls(readTimeout, writeTimeout, connectionSpecSelector); } else { protocol = Protocol.HTTP_1_1; + socket = rawSocket; } if (protocol == Protocol.SPDY_3 || protocol == Protocol.HTTP_2) { @@ -157,7 +168,7 @@ public final class RealConnection implements Connection { try { // Create the wrapper over the connected socket. sslSocket = (SSLSocket) sslSocketFactory.createSocket( - socket, address.getUriHost(), address.getUriPort(), true /* autoClose */); + rawSocket, address.getUriHost(), address.getUriPort(), true /* autoClose */); // Configure the socket's ciphers, TLS versions, and extensions. ConnectionSpec connectionSpec = connectionSpecSelector.configureSecureSocket(sslSocket); @@ -285,7 +296,8 @@ public final class RealConnection implements Connection { } public void cancel() { - Util.closeQuietly(socket); + // Close the raw socket so we don't end up doing synchronous I/O. + Util.closeQuietly(rawSocket); } @Override public Socket getSocket() {