diff --git a/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/MockWebServer.java b/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/MockWebServer.java index 0e4c5c543..8cd5d945b 100644 --- a/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/MockWebServer.java +++ b/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/MockWebServer.java @@ -670,19 +670,8 @@ public final class MockWebServer { .protocol(Protocol.HTTP_1_1) .build(); - // The callback might act synchronously. Give it its own thread. - new Thread(new Runnable() { - @Override public void run() { - try { - listener.onOpen(webSocket, fancyResponse); - } catch (IOException e) { - // TODO try to write close frame? - connectionClose.countDown(); - } - } - }, "MockWebServer WebSocket Writer " + request.getPath()).start(); + listener.onOpen(webSocket, fancyResponse); - // Use this thread to continuously read messages. while (webSocket.readMessage()) { } diff --git a/okhttp-ws-tests/src/test/java/com/squareup/okhttp/ws/AutobahnTester.java b/okhttp-ws-tests/src/test/java/com/squareup/okhttp/ws/AutobahnTester.java index e3b73f8fb..a592624c6 100644 --- a/okhttp-ws-tests/src/test/java/com/squareup/okhttp/ws/AutobahnTester.java +++ b/okhttp-ws-tests/src/test/java/com/squareup/okhttp/ws/AutobahnTester.java @@ -69,8 +69,7 @@ public final class AutobahnTester { private final ExecutorService sendExecutor = Executors.newSingleThreadExecutor(); private WebSocket webSocket; - @Override public void onOpen(WebSocket webSocket, Response response) - throws IOException { + @Override public void onOpen(WebSocket webSocket, Response response) { System.out.println("Executing test case " + number + "/" + count); this.webSocket = webSocket; } @@ -118,8 +117,7 @@ public final class AutobahnTester { final AtomicLong countRef = new AtomicLong(); final AtomicReference failureRef = new AtomicReference<>(); newWebSocket("/getCaseCount").enqueue(new WebSocketListener() { - @Override public void onOpen(WebSocket webSocket, Response response) - throws IOException { + @Override public void onOpen(WebSocket webSocket, Response response) { } @Override public void onMessage(BufferedSource payload, WebSocket.PayloadType type) @@ -157,8 +155,7 @@ public final class AutobahnTester { private void updateReports() { final CountDownLatch latch = new CountDownLatch(1); newWebSocket("/updateReports?agent=" + Version.userAgent()).enqueue(new WebSocketListener() { - @Override public void onOpen(WebSocket webSocket, Response response) - throws IOException { + @Override public void onOpen(WebSocket webSocket, Response response) { } @Override public void onMessage(BufferedSource payload, WebSocket.PayloadType type) diff --git a/okhttp-ws-tests/src/test/java/com/squareup/okhttp/ws/WebSocketCallTest.java b/okhttp-ws-tests/src/test/java/com/squareup/okhttp/ws/WebSocketCallTest.java index 02b8f9f38..abbfc60f6 100644 --- a/okhttp-ws-tests/src/test/java/com/squareup/okhttp/ws/WebSocketCallTest.java +++ b/okhttp-ws-tests/src/test/java/com/squareup/okhttp/ws/WebSocketCallTest.java @@ -70,9 +70,16 @@ public final class WebSocketCallTest { @Test public void serverMessage() throws IOException { WebSocketListener serverListener = new EmptyWebSocketListener() { - @Override public void onOpen(WebSocket webSocket, Response response) - throws IOException { - webSocket.sendMessage(TEXT, new Buffer().writeUtf8("Hello, WebSockets!")); + @Override public void onOpen(final WebSocket webSocket, Response response) { + new Thread() { + @Override public void run() { + try { + webSocket.sendMessage(TEXT, new Buffer().writeUtf8("Hello, WebSockets!")); + } catch (IOException e) { + throw new AssertionError(e); + } + } + }.start(); } }; server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); @@ -96,12 +103,19 @@ public final class WebSocketCallTest { @Test public void serverStreamingMessage() throws IOException { WebSocketListener serverListener = new EmptyWebSocketListener() { - @Override public void onOpen(WebSocket webSocket, Response response) - throws IOException { - BufferedSink sink = webSocket.newMessageSink(TEXT); - sink.writeUtf8("Hello, ").flush(); - sink.writeUtf8("WebSockets!").flush(); - sink.close(); + @Override public void onOpen(final WebSocket webSocket, Response response) { + new Thread() { + @Override public void run() { + try { + BufferedSink sink = webSocket.newMessageSink(TEXT); + sink.writeUtf8("Hello, ").flush(); + sink.writeUtf8("WebSockets!").flush(); + sink.close(); + } catch (IOException e) { + throw new AssertionError(e); + } + } + }.start(); } }; server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); @@ -235,8 +249,7 @@ public final class WebSocketCallTest { final AtomicReference failureRef = new AtomicReference<>(); final CountDownLatch latch = new CountDownLatch(1); call.enqueue(new WebSocketListener() { - @Override public void onOpen(WebSocket webSocket, Response response) - throws IOException { + @Override public void onOpen(WebSocket webSocket, Response response) { webSocketRef.set(webSocket); responseRef.set(response); latch.countDown(); @@ -274,8 +287,7 @@ public final class WebSocketCallTest { } private static class EmptyWebSocketListener implements WebSocketListener { - @Override public void onOpen(WebSocket webSocket, Response response) - throws IOException { + @Override public void onOpen(WebSocket webSocket, Response response) { } @Override public void onMessage(BufferedSource payload, WebSocket.PayloadType type) diff --git a/okhttp-ws/src/main/java/com/squareup/okhttp/ws/WebSocketCall.java b/okhttp-ws/src/main/java/com/squareup/okhttp/ws/WebSocketCall.java index 766032096..46ee8a133 100644 --- a/okhttp-ws/src/main/java/com/squareup/okhttp/ws/WebSocketCall.java +++ b/okhttp-ws/src/main/java/com/squareup/okhttp/ws/WebSocketCall.java @@ -22,7 +22,6 @@ import com.squareup.okhttp.OkHttpClient; import com.squareup.okhttp.Request; import com.squareup.okhttp.Response; import com.squareup.okhttp.internal.Internal; -import com.squareup.okhttp.internal.NamedRunnable; import com.squareup.okhttp.internal.Util; import com.squareup.okhttp.internal.ws.RealWebSocket; import com.squareup.okhttp.internal.ws.WebSocketProtocol; @@ -169,13 +168,8 @@ public final class WebSocketCall { listener.onOpen(webSocket, response); - // Start a dedicated thread for reading the web socket. - new Thread(new NamedRunnable("OkHttp WebSocket reader %s", request.urlString()) { - @Override protected void execute() { - while (webSocket.readMessage()) { - } - } - }).start(); + while (webSocket.readMessage()) { + } } // Keep static so that the WebSocketCall instance can be garbage collected. diff --git a/okhttp-ws/src/main/java/com/squareup/okhttp/ws/WebSocketListener.java b/okhttp-ws/src/main/java/com/squareup/okhttp/ws/WebSocketListener.java index eba8e7661..8941b7443 100644 --- a/okhttp-ws/src/main/java/com/squareup/okhttp/ws/WebSocketListener.java +++ b/okhttp-ws/src/main/java/com/squareup/okhttp/ws/WebSocketListener.java @@ -25,9 +25,15 @@ import static com.squareup.okhttp.ws.WebSocket.PayloadType; /** Listener for server-initiated messages on a connected {@link WebSocket}. */ public interface WebSocketListener { /** - * Called when the request has successfully been upgraded to a web socket. + * Called when the request has successfully been upgraded to a web socket. This method is called + * on the message reading thread to allow setting up any state before the + * {@linkplain #onMessage message}, {@linkplain #onPong pong}, and {@link #onClose close} + * callbacks start. + *

+ * Do not use this callback to write to the web socket. Start a new thread or use + * another thread in your application. */ - void onOpen(WebSocket webSocket, Response response) throws IOException; + void onOpen(WebSocket webSocket, Response response); /** * Called when the transport or protocol layer of this web socket errors during communication. diff --git a/samples/guide/src/main/java/com/squareup/okhttp/recipes/WebSocketEcho.java b/samples/guide/src/main/java/com/squareup/okhttp/recipes/WebSocketEcho.java index 0ab738cea..d439e99b6 100644 --- a/samples/guide/src/main/java/com/squareup/okhttp/recipes/WebSocketEcho.java +++ b/samples/guide/src/main/java/com/squareup/okhttp/recipes/WebSocketEcho.java @@ -7,6 +7,8 @@ import com.squareup.okhttp.ws.WebSocket; import com.squareup.okhttp.ws.WebSocketCall; import com.squareup.okhttp.ws.WebSocketListener; import java.io.IOException; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; import okio.Buffer; import okio.BufferedSource; @@ -15,6 +17,8 @@ import static com.squareup.okhttp.ws.WebSocket.PayloadType.BINARY; import static com.squareup.okhttp.ws.WebSocket.PayloadType.TEXT; public final class WebSocketEcho implements WebSocketListener { + private final Executor writeExecutor = Executors.newSingleThreadExecutor(); + private void run() throws IOException { OkHttpClient client = new OkHttpClient(); @@ -27,21 +31,28 @@ public final class WebSocketEcho implements WebSocketListener { client.getDispatcher().getExecutorService().shutdown(); } - @Override public void onOpen(WebSocket webSocket, Response response) - throws IOException { - webSocket.sendMessage(TEXT, new Buffer().writeUtf8("Hello...")); - webSocket.sendMessage(TEXT, new Buffer().writeUtf8("...World!")); - webSocket.sendMessage(BINARY, new Buffer().writeInt(0xdeadbeef)); - webSocket.close(1000, "Goodbye, World!"); + @Override public void onOpen(final WebSocket webSocket, Response response) { + writeExecutor.execute(new Runnable() { + @Override public void run() { + try { + webSocket.sendMessage(TEXT, new Buffer().writeUtf8("Hello...")); + webSocket.sendMessage(TEXT, new Buffer().writeUtf8("...World!")); + webSocket.sendMessage(BINARY, new Buffer().writeInt(0xdeadbeef)); + webSocket.close(1000, "Goodbye, World!"); + } catch (IOException e) { + System.err.println("Unable to send messages: " + e.getMessage()); + } + } + }); } @Override public void onMessage(BufferedSource payload, PayloadType type) throws IOException { switch (type) { case TEXT: - System.out.println(payload.readUtf8()); + System.out.println("MESSAGE: " + payload.readUtf8()); break; case BINARY: - System.out.println(payload.readByteString().hex()); + System.out.println("MESSAGE: " + payload.readByteString().hex()); break; default: throw new IllegalStateException("Unknown payload type: " + type);