diff --git a/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/MockWebServer.java b/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/MockWebServer.java index ba0803867..a78cf8601 100644 --- a/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/MockWebServer.java +++ b/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/MockWebServer.java @@ -56,7 +56,9 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.logging.Level; @@ -77,6 +79,7 @@ import okio.Timeout; import static com.squareup.okhttp.mockwebserver.SocketPolicy.DISCONNECT_AT_START; import static com.squareup.okhttp.mockwebserver.SocketPolicy.FAIL_HANDSHAKE; +import static java.util.concurrent.TimeUnit.SECONDS; /** * A scriptable web server. Callers supply canned responses and the server @@ -636,9 +639,15 @@ public final class MockWebServer { final WebSocketListener listener = response.getWebSocketListener(); final CountDownLatch connectionClose = new CountDownLatch(1); + + ThreadPoolExecutor replyExecutor = + new ThreadPoolExecutor(1, 1, 1, SECONDS, new LinkedBlockingDeque(), + Util.threadFactory(String.format("MockWebServer %s WebSocket", request.getPath()), + true)); + replyExecutor.allowCoreThreadTimeOut(true); final RealWebSocket webSocket = - new RealWebSocket(false, source, sink, new SecureRandom(), listener, - request.getPath()) { + new RealWebSocket(false /* is server */, source, sink, new SecureRandom(), replyExecutor, + listener, request.getPath()) { @Override protected void closeConnection() throws IOException { connectionClose.countDown(); } diff --git a/okhttp-tests/src/test/java/com/squareup/okhttp/internal/ws/RealWebSocketTest.java b/okhttp-tests/src/test/java/com/squareup/okhttp/internal/ws/RealWebSocketTest.java index 1395e2bf5..401c0cb63 100644 --- a/okhttp-tests/src/test/java/com/squareup/okhttp/internal/ws/RealWebSocketTest.java +++ b/okhttp-tests/src/test/java/com/squareup/okhttp/internal/ws/RealWebSocketTest.java @@ -18,6 +18,7 @@ package com.squareup.okhttp.internal.ws; import java.io.IOException; import java.net.ProtocolException; import java.util.Random; +import java.util.concurrent.Executor; import okio.Buffer; import okio.BufferedSink; import okio.ByteString; @@ -50,7 +51,15 @@ public final class RealWebSocketTest { @Before public void setUp() { Random random = new Random(0); String url = "http://example.com/websocket"; - client = new RealWebSocket(true, server2client, client2Server, random, clientListener, url) { + + Executor synchronousExecutor = new Executor() { + @Override public void execute(Runnable command) { + command.run(); + } + }; + + client = new RealWebSocket(true, server2client, client2Server, random, synchronousExecutor, + clientListener, url) { @Override protected void closeConnection() throws IOException { clientConnectionClosed = true; if (clientConnectionCloseThrows) { @@ -58,7 +67,8 @@ public final class RealWebSocketTest { } } }; - server = new RealWebSocket(false, client2Server, server2client, random, serverListener, url) { + server = new RealWebSocket(false, client2Server, server2client, random, synchronousExecutor, + serverListener, url) { @Override protected void closeConnection() throws IOException { } }; @@ -98,16 +108,14 @@ public final class RealWebSocketTest { sink.close(); server.readMessage(); serverListener.assertTextMessage("Hello!"); - Thread.sleep(1000); // Wait for pong to be written. client.readMessage(); clientListener.assertPong(new Buffer().writeUtf8("Pong?")); } @Test public void pingWritesPong() throws IOException, InterruptedException { client.sendPing(new Buffer().writeUtf8("Hello!")); - server.readMessage(); // Read the ping, enqueue the pong. - Thread.sleep(1000); // Wait for pong to be written. - client.readMessage(); + server.readMessage(); // Read the ping, write the pong. + client.readMessage(); // Read the pong. clientListener.assertPong(new Buffer().writeUtf8("Hello!")); } diff --git a/okhttp/src/main/java/com/squareup/okhttp/internal/ws/RealWebSocket.java b/okhttp/src/main/java/com/squareup/okhttp/internal/ws/RealWebSocket.java index 75ab60892..d7fc5115a 100644 --- a/okhttp/src/main/java/com/squareup/okhttp/internal/ws/RealWebSocket.java +++ b/okhttp/src/main/java/com/squareup/okhttp/internal/ws/RealWebSocket.java @@ -16,18 +16,15 @@ package com.squareup.okhttp.internal.ws; import com.squareup.okhttp.internal.NamedRunnable; -import com.squareup.okhttp.internal.Util; import java.io.IOException; import java.net.ProtocolException; import java.util.Random; -import java.util.concurrent.LinkedBlockingDeque; -import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.Executor; import okio.Buffer; import okio.BufferedSink; import okio.BufferedSource; import static com.squareup.okhttp.internal.ws.WebSocketReader.FrameCallback; -import static java.util.concurrent.TimeUnit.SECONDS; public abstract class RealWebSocket implements WebSocket { /** A close code which indicates that the peer encountered a protocol exception. */ @@ -45,15 +42,9 @@ public abstract class RealWebSocket implements WebSocket { private final Object closeLock = new Object(); public RealWebSocket(boolean isClient, BufferedSource source, BufferedSink sink, Random random, - final WebSocketListener listener, final String url) { + final Executor replyExecutor, final WebSocketListener listener, final String url) { this.listener = listener; - // Pings come in on the reader thread. This executor contends with callers for writing pongs. - final ThreadPoolExecutor pongExecutor = new ThreadPoolExecutor(1, 1, 1, SECONDS, - new LinkedBlockingDeque(), - Util.threadFactory(String.format("OkHttp %s WebSocket", url), true)); - pongExecutor.allowCoreThreadTimeOut(true); - writer = new WebSocketWriter(isClient, sink, random); reader = new WebSocketReader(isClient, source, new FrameCallback() { @Override public void onMessage(BufferedSource source, PayloadType type) throws IOException { @@ -61,7 +52,7 @@ public abstract class RealWebSocket implements WebSocket { } @Override public void onPing(final Buffer buffer) { - pongExecutor.execute(new NamedRunnable("OkHttp %s WebSocket Pong", url) { + replyExecutor.execute(new NamedRunnable("OkHttp %s WebSocket Pong Reply", url) { @Override protected void execute() { try { writer.writePong(buffer); @@ -75,8 +66,12 @@ public abstract class RealWebSocket implements WebSocket { listener.onPong(buffer); } - @Override public void onClose(int code, String reason) { - peerClose(code, reason); + @Override public void onClose(final int code, final String reason) { + replyExecutor.execute(new NamedRunnable("OkHttp %s WebSocket Close Reply", url) { + @Override protected void execute() { + peerClose(code, reason); + } + }); } }); } @@ -134,7 +129,7 @@ public abstract class RealWebSocket implements WebSocket { } } - /** Called on the reader thread when a close frame is encountered. */ + /** Replies and closes this web socket when a close frame is read from the peer. */ private void peerClose(int code, String reason) { boolean writeCloseResponse; synchronized (closeLock) { @@ -146,7 +141,6 @@ public abstract class RealWebSocket implements WebSocket { if (writeCloseResponse) { try { - // The reader thread will read no more frames so use it to send the response. writer.writeClose(code, reason); } catch (IOException ignored) { } diff --git a/okhttp/src/main/java/com/squareup/okhttp/internal/ws/WebSocketCall.java b/okhttp/src/main/java/com/squareup/okhttp/internal/ws/WebSocketCall.java index 9147e5bed..70a85b0d8 100644 --- a/okhttp/src/main/java/com/squareup/okhttp/internal/ws/WebSocketCall.java +++ b/okhttp/src/main/java/com/squareup/okhttp/internal/ws/WebSocketCall.java @@ -30,11 +30,16 @@ import java.net.Socket; import java.security.SecureRandom; import java.util.Collections; import java.util.Random; +import java.util.concurrent.Executor; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.ThreadPoolExecutor; import okio.BufferedSink; import okio.BufferedSource; import okio.ByteString; import okio.Okio; +import static java.util.concurrent.TimeUnit.SECONDS; + // TODO move to public API! public class WebSocketCall { /** @@ -175,7 +180,7 @@ public class WebSocketCall { BufferedSink sink = Okio.buffer(Okio.sink(socket)); final RealWebSocket webSocket = - new ConnectionWebSocket(response, connection, source, sink, random, listener); + ConnectionWebSocket.create(response, connection, source, sink, random, listener); // Start a dedicated thread for reading the web socket. new Thread(new NamedRunnable("OkHttp WebSocket reader %s", request.urlString()) { @@ -193,11 +198,23 @@ public class WebSocketCall { // Keep static so that the WebSocketCall instance can be garbage collected. private static class ConnectionWebSocket extends RealWebSocket { + static RealWebSocket create(Response response, Connection connection, BufferedSource source, + BufferedSink sink, Random random, WebSocketListener listener) { + String url = response.request().urlString(); + ThreadPoolExecutor replyExecutor = + new ThreadPoolExecutor(1, 1, 1, SECONDS, new LinkedBlockingDeque(), + Util.threadFactory(String.format("OkHttp %s WebSocket", url), true)); + replyExecutor.allowCoreThreadTimeOut(true); + + return new ConnectionWebSocket(connection, source, sink, random, replyExecutor, listener, + url); + } + private final Connection connection; - public ConnectionWebSocket(Response response, Connection connection, BufferedSource source, - BufferedSink sink, Random random, WebSocketListener listener) { - super(true /* is client */, source, sink, random, listener, response.request().urlString()); + private ConnectionWebSocket(Connection connection, BufferedSource source, BufferedSink sink, + Random random, Executor replyExecutor, WebSocketListener listener, String url) { + super(true /* is client */, source, sink, random, replyExecutor, listener, url); this.connection = connection; }