diff --git a/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/Dispatcher.java b/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/Dispatcher.java index 07f2f736f..ac6bac4a0 100644 --- a/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/Dispatcher.java +++ b/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/Dispatcher.java @@ -15,22 +15,20 @@ */ package com.squareup.okhttp.mockwebserver; -/** - * Handler for mock server requests. - */ +/** Handler for mock server requests. */ public abstract class Dispatcher { - /** - * Returns a response to satisfy {@code request}. This method may block (for - * instance, to wait on a CountdownLatch). - */ - public abstract MockResponse dispatch(RecordedRequest request) throws InterruptedException; + /** + * Returns a response to satisfy {@code request}. This method may block (for + * instance, to wait on a CountdownLatch). + */ + public abstract MockResponse dispatch(RecordedRequest request) throws InterruptedException; - /** - * Returns the socket policy of the next request. Default implementation - * returns {@link SocketPolicy#KEEP_OPEN}. Mischievous implementations can - * return other values to test HTTP edge cases. - */ - public SocketPolicy peekSocketPolicy() { - return SocketPolicy.KEEP_OPEN; - } + /** + * Returns the socket policy of the next request. Default implementation + * returns {@link SocketPolicy#KEEP_OPEN}. Mischievous implementations can + * return other values to test HTTP edge cases. + */ + public SocketPolicy peekSocketPolicy() { + return SocketPolicy.KEEP_OPEN; + } } diff --git a/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/MockResponse.java b/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/MockResponse.java index da0a87e7b..b073c11a7 100644 --- a/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/MockResponse.java +++ b/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/MockResponse.java @@ -13,10 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.squareup.okhttp.mockwebserver; -import static com.squareup.okhttp.mockwebserver.MockWebServer.ASCII; +import com.squareup.okhttp.internal.Util; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -26,215 +25,199 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; -/** - * A scripted response to be replayed by the mock web server. - */ +/** A scripted response to be replayed by the mock web server. */ public final class MockResponse implements Cloneable { - private static final String CHUNKED_BODY_HEADER = "Transfer-encoding: chunked"; + private static final String CHUNKED_BODY_HEADER = "Transfer-encoding: chunked"; - private String status = "HTTP/1.1 200 OK"; - private List headers = new ArrayList(); + private String status = "HTTP/1.1 200 OK"; + private List headers = new ArrayList(); - /** The response body content, or null if {@code bodyStream} is set. */ - private byte[] body; - /** The response body content, or null if {@code body} is set. */ - private InputStream bodyStream; + /** The response body content, or null if {@code bodyStream} is set. */ + private byte[] body; + /** The response body content, or null if {@code body} is set. */ + private InputStream bodyStream; - private int bytesPerSecond = Integer.MAX_VALUE; - private SocketPolicy socketPolicy = SocketPolicy.KEEP_OPEN; + private int bytesPerSecond = Integer.MAX_VALUE; + private SocketPolicy socketPolicy = SocketPolicy.KEEP_OPEN; - /** - * Creates a new mock response with an empty body. - */ - public MockResponse() { - setBody(new byte[0]); + /** Creates a new mock response with an empty body. */ + public MockResponse() { + setBody(new byte[0]); + } + + @Override public MockResponse clone() { + try { + MockResponse result = (MockResponse) super.clone(); + result.headers = new ArrayList(result.headers); + return result; + } catch (CloneNotSupportedException e) { + throw new AssertionError(); } + } - @Override public MockResponse clone() { - try { - MockResponse result = (MockResponse) super.clone(); - result.headers = new ArrayList(result.headers); - return result; - } catch (CloneNotSupportedException e) { - throw new AssertionError(); - } + /** Returns the HTTP response line, such as "HTTP/1.1 200 OK". */ + public String getStatus() { + return status; + } + + public MockResponse setResponseCode(int code) { + this.status = "HTTP/1.1 " + code + " OK"; + return this; + } + + public MockResponse setStatus(String status) { + this.status = status; + return this; + } + + /** Returns the HTTP headers, such as "Content-Length: 0". */ + public List getHeaders() { + return headers; + } + + /** + * Removes all HTTP headers including any "Content-Length" and + * "Transfer-encoding" headers that were added by default. + */ + public MockResponse clearHeaders() { + headers.clear(); + return this; + } + + /** + * Adds {@code header} as an HTTP header. For well-formed HTTP {@code header} + * should contain a name followed by a colon and a value. + */ + public MockResponse addHeader(String header) { + headers.add(header); + return this; + } + + /** + * Adds a new header with the name and value. This may be used to add multiple + * headers with the same name. + */ + public MockResponse addHeader(String name, Object value) { + return addHeader(name + ": " + String.valueOf(value)); + } + + /** + * Removes all headers named {@code name}, then adds a new header with the + * name and value. + */ + public MockResponse setHeader(String name, Object value) { + removeHeader(name); + return addHeader(name, value); + } + + /** Removes all headers named {@code name}. */ + public MockResponse removeHeader(String name) { + name += ":"; + for (Iterator i = headers.iterator(); i.hasNext(); ) { + String header = i.next(); + if (name.regionMatches(true, 0, header, 0, name.length())) { + i.remove(); + } } + return this; + } - /** - * Returns the HTTP response line, such as "HTTP/1.1 200 OK". - */ - public String getStatus() { - return status; + /** Returns the raw HTTP payload, or null if this response is streamed. */ + public byte[] getBody() { + return body; + } + + /** Returns an input stream containing the raw HTTP payload. */ + InputStream getBodyStream() { + return bodyStream != null ? bodyStream : new ByteArrayInputStream(body); + } + + public MockResponse setBody(byte[] body) { + setHeader("Content-Length", body.length); + this.body = body; + this.bodyStream = null; + return this; + } + + public MockResponse setBody(InputStream bodyStream, long bodyLength) { + setHeader("Content-Length", bodyLength); + this.body = null; + this.bodyStream = bodyStream; + return this; + } + + /** Sets the response body to the UTF-8 encoded bytes of {@code body}. */ + public MockResponse setBody(String body) { + try { + return setBody(body.getBytes("UTF-8")); + } catch (UnsupportedEncodingException e) { + throw new AssertionError(); } + } - public MockResponse setResponseCode(int code) { - this.status = "HTTP/1.1 " + code + " OK"; - return this; + /** + * Sets the response body to {@code body}, chunked every {@code maxChunkSize} + * bytes. + */ + public MockResponse setChunkedBody(byte[] body, int maxChunkSize) { + removeHeader("Content-Length"); + headers.add(CHUNKED_BODY_HEADER); + + try { + ByteArrayOutputStream bytesOut = new ByteArrayOutputStream(); + int pos = 0; + while (pos < body.length) { + int chunkSize = Math.min(body.length - pos, maxChunkSize); + bytesOut.write(Integer.toHexString(chunkSize).getBytes(Util.US_ASCII)); + bytesOut.write("\r\n".getBytes(Util.US_ASCII)); + bytesOut.write(body, pos, chunkSize); + bytesOut.write("\r\n".getBytes(Util.US_ASCII)); + pos += chunkSize; + } + bytesOut.write("0\r\n\r\n".getBytes(Util.US_ASCII)); // Last chunk + empty trailer + crlf. + + this.body = bytesOut.toByteArray(); + return this; + } catch (IOException e) { + throw new AssertionError(); // In-memory I/O doesn't throw IOExceptions. } + } - public MockResponse setStatus(String status) { - this.status = status; - return this; + /** + * Sets the response body to the UTF-8 encoded bytes of {@code body}, chunked + * every {@code maxChunkSize} bytes. + */ + public MockResponse setChunkedBody(String body, int maxChunkSize) { + try { + return setChunkedBody(body.getBytes("UTF-8"), maxChunkSize); + } catch (UnsupportedEncodingException e) { + throw new AssertionError(); } + } - /** - * Returns the HTTP headers, such as "Content-Length: 0". - */ - public List getHeaders() { - return headers; - } + public SocketPolicy getSocketPolicy() { + return socketPolicy; + } - /** - * Removes all HTTP headers including any "Content-Length" and - * "Transfer-encoding" headers that were added by default. - */ - public MockResponse clearHeaders() { - headers.clear(); - return this; - } + public MockResponse setSocketPolicy(SocketPolicy socketPolicy) { + this.socketPolicy = socketPolicy; + return this; + } - /** - * Adds {@code header} as an HTTP header. For well-formed HTTP {@code - * header} should contain a name followed by a colon and a value. - */ - public MockResponse addHeader(String header) { - headers.add(header); - return this; - } + public int getBytesPerSecond() { + return bytesPerSecond; + } - /** - * Adds a new header with the name and value. This may be used to add - * multiple headers with the same name. - */ - public MockResponse addHeader(String name, Object value) { - return addHeader(name + ": " + String.valueOf(value)); - } + /** + * Set simulated network speed, in bytes per second. This applies to the + * response body only; response headers are not throttled. + */ + public MockResponse setBytesPerSecond(int bytesPerSecond) { + this.bytesPerSecond = bytesPerSecond; + return this; + } - /** - * Removes all headers named {@code name}, then adds a new header with the - * name and value. - */ - public MockResponse setHeader(String name, Object value) { - removeHeader(name); - return addHeader(name, value); - } - - /** - * Removes all headers named {@code name}. - */ - public MockResponse removeHeader(String name) { - name += ":"; - for (Iterator i = headers.iterator(); i.hasNext(); ) { - String header = i.next(); - if (name.regionMatches(true, 0, header, 0, name.length())) { - i.remove(); - } - } - return this; - } - - /** - * Returns the raw HTTP payload, or null if this response is streamed. - */ - public byte[] getBody() { - return body; - } - - /** - * Returns an input stream containing the raw HTTP payload. - */ - InputStream getBodyStream() { - return bodyStream != null ? bodyStream : new ByteArrayInputStream(body); - } - - public MockResponse setBody(byte[] body) { - setHeader("Content-Length", body.length); - this.body = body; - this.bodyStream = null; - return this; - } - - public MockResponse setBody(InputStream bodyStream, long bodyLength) { - setHeader("Content-Length", bodyLength); - this.body = null; - this.bodyStream = bodyStream; - return this; - } - - /** - * Sets the response body to the UTF-8 encoded bytes of {@code body}. - */ - public MockResponse setBody(String body) { - try { - return setBody(body.getBytes("UTF-8")); - } catch (UnsupportedEncodingException e) { - throw new AssertionError(); - } - } - - /** - * Sets the response body to {@code body}, chunked every {@code - * maxChunkSize} bytes. - */ - public MockResponse setChunkedBody(byte[] body, int maxChunkSize) { - removeHeader("Content-Length"); - headers.add(CHUNKED_BODY_HEADER); - - try { - ByteArrayOutputStream bytesOut = new ByteArrayOutputStream(); - int pos = 0; - while (pos < body.length) { - int chunkSize = Math.min(body.length - pos, maxChunkSize); - bytesOut.write(Integer.toHexString(chunkSize).getBytes(ASCII)); - bytesOut.write("\r\n".getBytes(ASCII)); - bytesOut.write(body, pos, chunkSize); - bytesOut.write("\r\n".getBytes(ASCII)); - pos += chunkSize; - } - bytesOut.write("0\r\n\r\n".getBytes(ASCII)); // last chunk + empty trailer + crlf - - this.body = bytesOut.toByteArray(); - return this; - } catch (IOException e) { - throw new AssertionError(); // In-memory I/O doesn't throw IOExceptions. - } - } - - /** - * Sets the response body to the UTF-8 encoded bytes of {@code body}, - * chunked every {@code maxChunkSize} bytes. - */ - public MockResponse setChunkedBody(String body, int maxChunkSize) { - try { - return setChunkedBody(body.getBytes("UTF-8"), maxChunkSize); - } catch (UnsupportedEncodingException e) { - throw new AssertionError(); - } - } - - public SocketPolicy getSocketPolicy() { - return socketPolicy; - } - - public MockResponse setSocketPolicy(SocketPolicy socketPolicy) { - this.socketPolicy = socketPolicy; - return this; - } - - public int getBytesPerSecond() { - return bytesPerSecond; - } - - /** - * Set simulated network speed, in bytes per second. This applies to the - * response body only; response headers are not throttled. - */ - public MockResponse setBytesPerSecond(int bytesPerSecond) { - this.bytesPerSecond = bytesPerSecond; - return this; - } - - @Override public String toString() { - return status; - } + @Override public String toString() { + return status; + } } diff --git a/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/MockWebServer.java b/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/MockWebServer.java index 2e9b6ef01..6780e1332 100644 --- a/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/MockWebServer.java +++ b/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/MockWebServer.java @@ -1,5 +1,6 @@ /* * Copyright (C) 2011 Google Inc. + * Copyright (C) 2013 Square, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -68,677 +69,636 @@ import static com.squareup.okhttp.mockwebserver.SocketPolicy.FAIL_HANDSHAKE; * replays them upon request in sequence. */ public final class MockWebServer { - private static final byte[] NPN_PROTOCOLS = { - 6, 's', 'p', 'd', 'y', '/', '3', - 8, 'h', 't', 't', 'p', '/', '1', '.', '1' - }; - private static final byte[] SPDY3 = new byte[] { - 's', 'p', 'd', 'y', '/', '3' - }; - private static final byte[] HTTP_11 = new byte[] { - 'h', 't', 't', 'p', '/', '1', '.', '1' - }; + private static final byte[] NPN_PROTOCOLS = { + 6, 's', 'p', 'd', 'y', '/', '3', + 8, 'h', 't', 't', 'p', '/', '1', '.', '1' + }; + private static final byte[] SPDY3 = new byte[] { + 's', 'p', 'd', 'y', '/', '3' + }; + private static final byte[] HTTP_11 = new byte[] { + 'h', 't', 't', 'p', '/', '1', '.', '1' + }; - static final String ASCII = "US-ASCII"; - - private static final Logger logger = Logger.getLogger(MockWebServer.class.getName()); - private final BlockingQueue requestQueue - = new LinkedBlockingQueue(); - /** All map values are Boolean.TRUE. (Collections.newSetFromMap isn't available in Froyo) */ - private final Map openClientSockets = new ConcurrentHashMap(); - private final Map openSpdyConnections - = new ConcurrentHashMap(); - private final AtomicInteger requestCount = new AtomicInteger(); - private int bodyLimit = Integer.MAX_VALUE; - private ServerSocket serverSocket; - private SSLSocketFactory sslSocketFactory; - private ExecutorService executor; - private boolean tunnelProxy; - private Dispatcher dispatcher = new QueueDispatcher(); - - private int port = -1; - private boolean npnEnabled = true; - - public int getPort() { - if (port == -1) { - throw new IllegalStateException("Cannot retrieve port before calling play()"); - } - return port; + private static final X509TrustManager UNTRUSTED_TRUST_MANAGER = new X509TrustManager() { + @Override public void checkClientTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + throw new CertificateException(); } - public String getHostName() { + @Override public void checkServerTrusted(X509Certificate[] chain, String authType) { + throw new AssertionError(); + } + + @Override public X509Certificate[] getAcceptedIssuers() { + throw new AssertionError(); + } + }; + + private static final Logger logger = Logger.getLogger(MockWebServer.class.getName()); + + private final BlockingQueue requestQueue = + new LinkedBlockingQueue(); + + /** All map values are Boolean.TRUE. (Collections.newSetFromMap isn't available in Froyo) */ + private final Map openClientSockets = new ConcurrentHashMap(); + private final Map openSpdyConnections + = new ConcurrentHashMap(); + private final AtomicInteger requestCount = new AtomicInteger(); + private int bodyLimit = Integer.MAX_VALUE; + private ServerSocket serverSocket; + private SSLSocketFactory sslSocketFactory; + private ExecutorService executor; + private boolean tunnelProxy; + private Dispatcher dispatcher = new QueueDispatcher(); + + private int port = -1; + private boolean npnEnabled = true; + + public int getPort() { + if (port == -1) throw new IllegalStateException("Cannot retrieve port before calling play()"); + return port; + } + + public String getHostName() { + try { + return InetAddress.getLocalHost().getHostName(); + } catch (UnknownHostException e) { + throw new AssertionError(e); + } + } + + public Proxy toProxyAddress() { + return new Proxy(Proxy.Type.HTTP, new InetSocketAddress(getHostName(), getPort())); + } + + /** + * Returns a URL for connecting to this server. + * @param path the request path, such as "/". + */ + public URL getUrl(String path) { + try { + return sslSocketFactory != null + ? new URL("https://" + getHostName() + ":" + getPort() + path) + : new URL("http://" + getHostName() + ":" + getPort() + path); + } catch (MalformedURLException e) { + throw new AssertionError(e); + } + } + + /** + * Returns a cookie domain for this server. This returns the server's + * non-loopback host name if it is known. Otherwise this returns ".local" for + * this server's loopback name. + */ + public String getCookieDomain() { + String hostName = getHostName(); + return hostName.contains(".") ? hostName : ".local"; + } + + /** + * Sets the number of bytes of the POST body to keep in memory to the given + * limit. + */ + public void setBodyLimit(int maxBodyLength) { + this.bodyLimit = maxBodyLength; + } + + /** + * Sets whether NPN is used on incoming HTTPS connections to negotiate a + * transport like HTTP/1.1 or SPDY/3. Call this method to disable NPN and + * SPDY. + */ + public void setNpnEnabled(boolean npnEnabled) { + this.npnEnabled = npnEnabled; + } + + /** + * Serve requests with HTTPS rather than otherwise. + * @param tunnelProxy true to expect the HTTP CONNECT method before + * negotiating TLS. + */ + public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) { + this.sslSocketFactory = sslSocketFactory; + this.tunnelProxy = tunnelProxy; + } + + /** + * Awaits the next HTTP request, removes it, and returns it. Callers should + * use this to verify the request was sent as intended. + */ + public RecordedRequest takeRequest() throws InterruptedException { + return requestQueue.take(); + } + + /** + * Returns the number of HTTP requests received thus far by this server. This + * may exceed the number of HTTP connections when connection reuse is in + * practice. + */ + public int getRequestCount() { + return requestCount.get(); + } + + /** + * Scripts {@code response} to be returned to a request made in sequence. The + * first request is served by the first enqueued response; the second request + * by the second enqueued response; and so on. + * + * @throws ClassCastException if the default dispatcher has been replaced + * with {@link #setDispatcher(Dispatcher)}. + */ + public void enqueue(MockResponse response) { + ((QueueDispatcher) dispatcher).enqueueResponse(response.clone()); + } + + /** Equivalent to {@code play(0)}. */ + public void play() throws IOException { + play(0); + } + + /** + * Starts the server, serves all enqueued requests, and shuts the server down. + * + * @param port the port to listen to, or 0 for any available port. Automated + * tests should always use port 0 to avoid flakiness when a specific port + * is unavailable. + */ + public void play(int port) throws IOException { + if (executor != null) throw new IllegalStateException("play() already called"); + executor = Executors.newCachedThreadPool(); + serverSocket = new ServerSocket(port); + serverSocket.setReuseAddress(true); + + this.port = serverSocket.getLocalPort(); + executor.execute(namedRunnable("MockWebServer-accept-" + port, new Runnable() { + public void run() { try { - return InetAddress.getLocalHost().getHostName(); - } catch (UnknownHostException e) { - throw new AssertionError(e); - } - } - - public Proxy toProxyAddress() { - return new Proxy(Proxy.Type.HTTP, new InetSocketAddress(getHostName(), getPort())); - } - - /** - * Returns a URL for connecting to this server. - * - * @param path the request path, such as "/". - */ - public URL getUrl(String path) { - try { - return sslSocketFactory != null - ? new URL("https://" + getHostName() + ":" + getPort() + path) - : new URL("http://" + getHostName() + ":" + getPort() + path); - } catch (MalformedURLException e) { - throw new AssertionError(e); - } - } - - /** - * Returns a cookie domain for this server. This returns the server's - * non-loopback host name if it is known. Otherwise this returns ".local" - * for this server's loopback name. - */ - public String getCookieDomain() { - String hostName = getHostName(); - return hostName.contains(".") ? hostName : ".local"; - } - - /** - * Sets the number of bytes of the POST body to keep in memory to the given - * limit. - */ - public void setBodyLimit(int maxBodyLength) { - this.bodyLimit = maxBodyLength; - } - - /** - * Sets whether NPN is used on incoming HTTPS connections to negotiate a - * transport like HTTP/1.1 or SPDY/3. Call this method to disable NPN and - * SPDY. - */ - public void setNpnEnabled(boolean npnEnabled) { - this.npnEnabled = npnEnabled; - } - - /** - * Serve requests with HTTPS rather than otherwise. - * - * @param tunnelProxy whether to expect the HTTP CONNECT method before - * negotiating TLS. - */ - public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) { - this.sslSocketFactory = sslSocketFactory; - this.tunnelProxy = tunnelProxy; - } - - /** - * Awaits the next HTTP request, removes it, and returns it. Callers should - * use this to verify the request sent was as intended. - */ - public RecordedRequest takeRequest() throws InterruptedException { - return requestQueue.take(); - } - - /** - * Returns the number of HTTP requests received thus far by this server. - * This may exceed the number of HTTP connections when connection reuse is - * in practice. - */ - public int getRequestCount() { - return requestCount.get(); - } - - /** - * Scripts {@code response} to be returned to a request made in sequence. - * The first request is served by the first enqueued response; the second - * request by the second enqueued response; and so on. - * - * @throws ClassCastException if the default dispatcher has been replaced - * with {@link #setDispatcher(Dispatcher)}. - */ - public void enqueue(MockResponse response) { - ((QueueDispatcher) dispatcher).enqueueResponse(response.clone()); - } - - /** - * Equivalent to {@code play(0)}. - */ - public void play() throws IOException { - play(0); - } - - /** - * Starts the server, serves all enqueued requests, and shuts the server - * down. - * - * @param port the port to listen to, or 0 for any available port. - * Automated tests should always use port 0 to avoid flakiness when a - * specific port is unavailable. - */ - public void play(int port) throws IOException { - if (executor != null) { - throw new IllegalStateException("play() already called"); - } - executor = Executors.newCachedThreadPool(); - serverSocket = new ServerSocket(port); - serverSocket.setReuseAddress(true); - - this.port = serverSocket.getLocalPort(); - executor.execute(namedRunnable("MockWebServer-accept-" + port, new Runnable() { - public void run() { - try { - acceptConnections(); - } catch (Throwable e) { - logger.log(Level.WARNING, "MockWebServer connection failed", e); - } - - /* - * This gnarly block of code will release all sockets and - * all thread, even if any close fails. - */ - try { - serverSocket.close(); - } catch (Throwable e) { - logger.log(Level.WARNING, "MockWebServer server socket close failed", e); - } - for (Iterator s = openClientSockets.keySet().iterator(); s.hasNext(); ) { - try { - s.next().close(); - s.remove(); - } catch (Throwable e) { - logger.log(Level.WARNING, "MockWebServer socket close failed", e); - } - } - for (Iterator s = openSpdyConnections.keySet().iterator(); - s.hasNext(); ) { - try { - s.next().close(); - s.remove(); - } catch (Throwable e) { - logger.log(Level.WARNING, "MockWebServer SPDY connection close failed", e); - } - } - try { - executor.shutdown(); - } catch (Throwable e) { - logger.log(Level.WARNING, "MockWebServer executor shutdown failed", e); - } - } - - private void acceptConnections() throws Exception { - while (true) { - Socket socket; - try { - socket = serverSocket.accept(); - } catch (SocketException e) { - return; - } - final SocketPolicy socketPolicy = dispatcher.peekSocketPolicy(); - if (socketPolicy == DISCONNECT_AT_START) { - dispatchBookkeepingRequest(0, socket); - socket.close(); - } else { - openClientSockets.put(socket, true); - serveConnection(socket); - } - } - } - })); - } - - public void shutdown() throws IOException { - if (serverSocket != null) { - serverSocket.close(); // should cause acceptConnections() to break out - } - } - - private void serveConnection(final Socket raw) { - String name = "MockWebServer-" + raw.getRemoteSocketAddress(); - executor.execute(namedRunnable(name, new Runnable() { - int sequenceNumber = 0; - - public void run() { - try { - processConnection(); - } catch (Exception e) { - logger.log(Level.WARNING, "MockWebServer connection failed", e); - } - } - - public void processConnection() throws Exception { - Transport transport = Transport.HTTP_11; - Socket socket; - if (sslSocketFactory != null) { - if (tunnelProxy) { - createTunnel(); - } - final SocketPolicy socketPolicy = dispatcher.peekSocketPolicy(); - if (socketPolicy == FAIL_HANDSHAKE) { - dispatchBookkeepingRequest(sequenceNumber, raw); - processHandshakeFailure(raw); - return; - } - socket = sslSocketFactory.createSocket( - raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true); - SSLSocket sslSocket = (SSLSocket) socket; - sslSocket.setUseClientMode(false); - openClientSockets.put(socket, true); - - if (npnEnabled) { - Platform.get().setNpnProtocols(sslSocket, NPN_PROTOCOLS); - } - - sslSocket.startHandshake(); - - if (npnEnabled) { - byte[] selectedProtocol = Platform.get().getNpnSelectedProtocol(sslSocket); - if (selectedProtocol == null || Arrays.equals(selectedProtocol, HTTP_11)) { - transport = Transport.HTTP_11; - } else if (Arrays.equals(selectedProtocol, SPDY3)) { - transport = Transport.SPDY_3; - } else { - throw new IllegalStateException("Unexpected transport: " - + new String(selectedProtocol, Util.US_ASCII)); - } - } - openClientSockets.remove(raw); - } else { - socket = raw; - } - - if (transport == Transport.SPDY_3) { - SpdySocketHandler spdySocketHandler = new SpdySocketHandler(socket); - SpdyConnection spdyConnection = new SpdyConnection.Builder(false, socket) - .handler(spdySocketHandler) - .build(); - openSpdyConnections.put(spdyConnection, Boolean.TRUE); - openClientSockets.remove(socket); - return; - } - - InputStream in = new BufferedInputStream(socket.getInputStream()); - OutputStream out = new BufferedOutputStream(socket.getOutputStream()); - - while (processOneRequest(socket, in, out)) { - } - - if (sequenceNumber == 0) { - logger.warning("MockWebServer connection didn't make a request"); - } - - in.close(); - out.close(); - socket.close(); - openClientSockets.remove(socket); - } - - /** - * Respond to CONNECT requests until a SWITCH_TO_SSL_AT_END response - * is dispatched. - */ - private void createTunnel() throws IOException, InterruptedException { - while (true) { - final SocketPolicy socketPolicy = dispatcher.peekSocketPolicy(); - if (!processOneRequest(raw, raw.getInputStream(), raw.getOutputStream())) { - throw new IllegalStateException("Tunnel without any CONNECT!"); - } - if (socketPolicy == SocketPolicy.UPGRADE_TO_SSL_AT_END) { - return; - } - } - } - - /** - * Reads a request and writes its response. Returns true if a request - * was processed. - */ - private boolean processOneRequest(Socket socket, InputStream in, OutputStream out) - throws IOException, InterruptedException { - RecordedRequest request = readRequest(socket, in, out, sequenceNumber); - if (request == null) { - return false; - } - requestCount.incrementAndGet(); - requestQueue.add(request); - MockResponse response = dispatcher.dispatch(request); - writeResponse(out, response); - if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AT_END) { - in.close(); - out.close(); - } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_INPUT_AT_END) { - socket.shutdownInput(); - } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_OUTPUT_AT_END) { - socket.shutdownOutput(); - } - logger.info("Received request: " + request + " and responded: " + response); - sequenceNumber++; - return true; - } - })); - } - - private void processHandshakeFailure(Socket raw) throws Exception { - X509TrustManager untrusted = new X509TrustManager() { - @Override public void checkClientTrusted(X509Certificate[] chain, String authType) - throws CertificateException { - throw new CertificateException(); - } - @Override public void checkServerTrusted(X509Certificate[] chain, String authType) { - throw new AssertionError(); - } - @Override public X509Certificate[] getAcceptedIssuers() { - throw new AssertionError(); - } - }; - SSLContext context = SSLContext.getInstance("TLS"); - context.init(null, new TrustManager[] { untrusted }, new SecureRandom()); - SSLSocketFactory sslSocketFactory = context.getSocketFactory(); - SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket( - raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true); - try { - socket.startHandshake(); // we're testing a handshake failure - throw new AssertionError(); - } catch (IOException expected) { - } - socket.close(); - } - - private void dispatchBookkeepingRequest(int sequenceNumber, Socket socket) - throws InterruptedException { - requestCount.incrementAndGet(); - dispatcher.dispatch( - new RecordedRequest(null, null, null, -1, null, sequenceNumber, socket)); - } - - /** - * @param sequenceNumber the index of this request on this connection. - */ - private RecordedRequest readRequest(Socket socket, InputStream in, OutputStream out, - int sequenceNumber) throws IOException { - String request; - try { - request = readAsciiUntilCrlf(in); - } catch (IOException streamIsClosed) { - return null; // no request because we closed the stream - } - if (request.length() == 0) { - return null; // no request because the stream is exhausted + acceptConnections(); + } catch (Throwable e) { + logger.log(Level.WARNING, "MockWebServer connection failed", e); } - List headers = new ArrayList(); - long contentLength = -1; - boolean chunked = false; - boolean expectContinue = false; - String header; - while ((header = readAsciiUntilCrlf(in)).length() != 0) { - headers.add(header); - String lowercaseHeader = header.toLowerCase(Locale.US); - if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) { - contentLength = Long.parseLong(header.substring(15).trim()); - } - if (lowercaseHeader.startsWith("transfer-encoding:") - && lowercaseHeader.substring(18).trim().equals("chunked")) { - chunked = true; - } - if (lowercaseHeader.startsWith("expect:") - && lowercaseHeader.substring(7).trim().equals("100-continue")) { - expectContinue = true; - } + // This gnarly block of code will release all sockets and all thread, + // even if any close fails. + Util.closeQuietly(serverSocket); + for (Iterator s = openClientSockets.keySet().iterator(); s.hasNext(); ) { + Util.closeQuietly(s.next()); + s.remove(); } - - if (expectContinue) { - out.write(("HTTP/1.1 100 Continue\r\n").getBytes(ASCII)); - out.write(("Content-Length: 0\r\n").getBytes(ASCII)); - out.write(("\r\n").getBytes(ASCII)); - out.flush(); + for (Iterator s = openSpdyConnections.keySet().iterator(); s.hasNext(); ) { + Util.closeQuietly(s.next()); + s.remove(); } + executor.shutdown(); + } - boolean hasBody = false; - TruncatingOutputStream requestBody = new TruncatingOutputStream(); - List chunkSizes = new ArrayList(); - if (contentLength != -1) { - hasBody = true; - transfer(contentLength, in, requestBody); - } else if (chunked) { - hasBody = true; - while (true) { - int chunkSize = Integer.parseInt(readAsciiUntilCrlf(in).trim(), 16); - if (chunkSize == 0) { - readEmptyLine(in); - break; - } - chunkSizes.add(chunkSize); - transfer(chunkSize, in, requestBody); - readEmptyLine(in); - } - } - - if (request.startsWith("OPTIONS ") || request.startsWith("GET ") - || request.startsWith("HEAD ") || request.startsWith("DELETE ") - || request.startsWith("TRACE ") || request.startsWith("CONNECT ")) { - if (hasBody) { - throw new IllegalArgumentException("Request must not have a body: " + request); - } - } else if (!request.startsWith("POST ") && !request.startsWith("PUT ")) { - throw new UnsupportedOperationException("Unexpected method: " + request); - } - - return new RecordedRequest(request, headers, chunkSizes, - requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber, socket); - } - - private void writeResponse(OutputStream out, MockResponse response) throws IOException { - out.write((response.getStatus() + "\r\n").getBytes(ASCII)); - for (String header : response.getHeaders()) { - out.write((header + "\r\n").getBytes(ASCII)); - } - out.write(("\r\n").getBytes(ASCII)); - out.flush(); - - final InputStream in = response.getBodyStream(); - if (in == null) { - return; - } - final int bytesPerSecond = response.getBytesPerSecond(); - - // Stream data in MTU-sized increments, with a minimum of one packet per second. - final byte[] buffer = bytesPerSecond >= 1452 - ? new byte[1452] - : new byte[bytesPerSecond]; - final long delayMs; - if (bytesPerSecond == Integer.MAX_VALUE) { - delayMs = 0; - } else { - delayMs = (1000 * buffer.length) / bytesPerSecond; - } - - int read; - long sinceDelay = 0; - while ((read = in.read(buffer)) != -1) { - out.write(buffer, 0, read); - out.flush(); - - sinceDelay += read; - if (sinceDelay >= buffer.length && delayMs > 0) { - sinceDelay %= buffer.length; - try { - Thread.sleep(delayMs); - } catch (InterruptedException e) { - throw new AssertionError(); - } - } - } - } - - /** - * Transfer bytes from {@code in} to {@code out} until either {@code length} - * bytes have been transferred or {@code in} is exhausted. - */ - private void transfer(long length, InputStream in, OutputStream out) throws IOException { - byte[] buffer = new byte[1024]; - while (length > 0) { - int count = in.read(buffer, 0, (int) Math.min(buffer.length, length)); - if (count == -1) { - return; - } - out.write(buffer, 0, count); - length -= count; - } - } - - /** - * Returns the text from {@code in} until the next "\r\n", or null if - * {@code in} is exhausted. - */ - private String readAsciiUntilCrlf(InputStream in) throws IOException { - StringBuilder builder = new StringBuilder(); + private void acceptConnections() throws Exception { while (true) { - int c = in.read(); - if (c == '\n' && builder.length() > 0 && builder.charAt(builder.length() - 1) == '\r') { - builder.deleteCharAt(builder.length() - 1); - return builder.toString(); - } else if (c == -1) { - return builder.toString(); + Socket socket; + try { + socket = serverSocket.accept(); + } catch (SocketException e) { + return; + } + SocketPolicy socketPolicy = dispatcher.peekSocketPolicy(); + if (socketPolicy == DISCONNECT_AT_START) { + dispatchBookkeepingRequest(0, socket); + socket.close(); + } else { + openClientSockets.put(socket, true); + serveConnection(socket); + } + } + } + })); + } + + public void shutdown() throws IOException { + if (serverSocket != null) { + serverSocket.close(); // Should cause acceptConnections() to break out. + } + } + + private void serveConnection(final Socket raw) { + String name = "MockWebServer-" + raw.getRemoteSocketAddress(); + executor.execute(namedRunnable(name, new Runnable() { + int sequenceNumber = 0; + + public void run() { + try { + processConnection(); + } catch (Exception e) { + logger.log(Level.WARNING, "MockWebServer connection failed", e); + } + } + + public void processConnection() throws Exception { + Transport transport = Transport.HTTP_11; + Socket socket; + if (sslSocketFactory != null) { + if (tunnelProxy) { + createTunnel(); + } + SocketPolicy socketPolicy = dispatcher.peekSocketPolicy(); + if (socketPolicy == FAIL_HANDSHAKE) { + dispatchBookkeepingRequest(sequenceNumber, raw); + processHandshakeFailure(raw); + return; + } + socket = sslSocketFactory.createSocket( + raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true); + SSLSocket sslSocket = (SSLSocket) socket; + sslSocket.setUseClientMode(false); + openClientSockets.put(socket, true); + + if (npnEnabled) { + Platform.get().setNpnProtocols(sslSocket, NPN_PROTOCOLS); + } + + sslSocket.startHandshake(); + + if (npnEnabled) { + byte[] selectedProtocol = Platform.get().getNpnSelectedProtocol(sslSocket); + if (selectedProtocol == null || Arrays.equals(selectedProtocol, HTTP_11)) { + transport = Transport.HTTP_11; + } else if (Arrays.equals(selectedProtocol, SPDY3)) { + transport = Transport.SPDY_3; } else { - builder.append((char) c); + throw new IllegalStateException( + "Unexpected transport: " + new String(selectedProtocol, Util.US_ASCII)); } + } + openClientSockets.remove(raw); + } else { + socket = raw; } + + if (transport == Transport.SPDY_3) { + SpdySocketHandler spdySocketHandler = new SpdySocketHandler(socket); + SpdyConnection spdyConnection = new SpdyConnection.Builder(false, socket) + .handler(spdySocketHandler) + .build(); + openSpdyConnections.put(spdyConnection, Boolean.TRUE); + openClientSockets.remove(socket); + return; + } + + InputStream in = new BufferedInputStream(socket.getInputStream()); + OutputStream out = new BufferedOutputStream(socket.getOutputStream()); + + while (processOneRequest(socket, in, out)) { + } + + if (sequenceNumber == 0) { + logger.warning("MockWebServer connection didn't make a request"); + } + + in.close(); + out.close(); + socket.close(); + openClientSockets.remove(socket); + } + + /** + * Respond to CONNECT requests until a SWITCH_TO_SSL_AT_END response is + * dispatched. + */ + private void createTunnel() throws IOException, InterruptedException { + while (true) { + SocketPolicy socketPolicy = dispatcher.peekSocketPolicy(); + if (!processOneRequest(raw, raw.getInputStream(), raw.getOutputStream())) { + throw new IllegalStateException("Tunnel without any CONNECT!"); + } + if (socketPolicy == SocketPolicy.UPGRADE_TO_SSL_AT_END) return; + } + } + + /** + * Reads a request and writes its response. Returns true if a request was + * processed. + */ + private boolean processOneRequest(Socket socket, InputStream in, OutputStream out) + throws IOException, InterruptedException { + RecordedRequest request = readRequest(socket, in, out, sequenceNumber); + if (request == null) return false; + requestCount.incrementAndGet(); + requestQueue.add(request); + MockResponse response = dispatcher.dispatch(request); + writeResponse(out, response); + if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AT_END) { + in.close(); + out.close(); + } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_INPUT_AT_END) { + socket.shutdownInput(); + } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_OUTPUT_AT_END) { + socket.shutdownOutput(); + } + logger.info("Received request: " + request + " and responded: " + response); + sequenceNumber++; + return true; + } + })); + } + + private void processHandshakeFailure(Socket raw) throws Exception { + SSLContext context = SSLContext.getInstance("TLS"); + context.init(null, new TrustManager[] { UNTRUSTED_TRUST_MANAGER }, new SecureRandom()); + SSLSocketFactory sslSocketFactory = context.getSocketFactory(); + SSLSocket socket = + (SSLSocket) sslSocketFactory.createSocket(raw, raw.getInetAddress().getHostAddress(), + raw.getPort(), true); + try { + socket.startHandshake(); // we're testing a handshake failure + throw new AssertionError(); + } catch (IOException expected) { + } + socket.close(); + } + + private void dispatchBookkeepingRequest(int sequenceNumber, Socket socket) + throws InterruptedException { + requestCount.incrementAndGet(); + dispatcher.dispatch(new RecordedRequest(null, null, null, -1, null, sequenceNumber, socket)); + } + + /** @param sequenceNumber the index of this request on this connection. */ + private RecordedRequest readRequest(Socket socket, InputStream in, OutputStream out, + int sequenceNumber) throws IOException { + String request; + try { + request = readAsciiUntilCrlf(in); + } catch (IOException streamIsClosed) { + return null; // no request because we closed the stream + } + if (request.length() == 0) { + return null; // no request because the stream is exhausted } - private void readEmptyLine(InputStream in) throws IOException { - String line = readAsciiUntilCrlf(in); - if (line.length() != 0) { - throw new IllegalStateException("Expected empty but was: " + line); - } + List headers = new ArrayList(); + long contentLength = -1; + boolean chunked = false; + boolean expectContinue = false; + String header; + while ((header = readAsciiUntilCrlf(in)).length() != 0) { + headers.add(header); + String lowercaseHeader = header.toLowerCase(Locale.US); + if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) { + contentLength = Long.parseLong(header.substring(15).trim()); + } + if (lowercaseHeader.startsWith("transfer-encoding:") + && lowercaseHeader.substring(18).trim().equals("chunked")) { + chunked = true; + } + if (lowercaseHeader.startsWith("expect:") + && lowercaseHeader.substring(7).trim().equals("100-continue")) { + expectContinue = true; + } } - /** - * Sets the dispatcher used to match incoming requests to mock responses. - * The default dispatcher simply serves a fixed sequence of responses from - * a {@link #enqueue(MockResponse) queue}; custom dispatchers can vary the - * response based on timing or the content of the request. - */ - public void setDispatcher(Dispatcher dispatcher) { - if (dispatcher == null) { - throw new NullPointerException(); - } - this.dispatcher = dispatcher; + if (expectContinue) { + out.write(("HTTP/1.1 100 Continue\r\n").getBytes(Util.US_ASCII)); + out.write(("Content-Length: 0\r\n").getBytes(Util.US_ASCII)); + out.write(("\r\n").getBytes(Util.US_ASCII)); + out.flush(); } - /** - * An output stream that drops data after bodyLimit bytes. - */ - private class TruncatingOutputStream extends ByteArrayOutputStream { - private long numBytesReceived = 0; - @Override public void write(byte[] buffer, int offset, int len) { - numBytesReceived += len; - super.write(buffer, offset, Math.min(len, bodyLimit - count)); - } - @Override public void write(int oneByte) { - numBytesReceived++; - if (count < bodyLimit) { - super.write(oneByte); - } + boolean hasBody = false; + TruncatingOutputStream requestBody = new TruncatingOutputStream(); + List chunkSizes = new ArrayList(); + if (contentLength != -1) { + hasBody = true; + transfer(contentLength, in, requestBody); + } else if (chunked) { + hasBody = true; + while (true) { + int chunkSize = Integer.parseInt(readAsciiUntilCrlf(in).trim(), 16); + if (chunkSize == 0) { + readEmptyLine(in); + break; } + chunkSizes.add(chunkSize); + transfer(chunkSize, in, requestBody); + readEmptyLine(in); + } } - private static Runnable namedRunnable(final String name, final Runnable runnable) { - return new Runnable() { - public void run() { - String originalName = Thread.currentThread().getName(); - Thread.currentThread().setName(name); - try { - runnable.run(); - } finally { - Thread.currentThread().setName(originalName); - } - } - }; + if (request.startsWith("OPTIONS ") + || request.startsWith("GET ") + || request.startsWith("HEAD ") + || request.startsWith("DELETE ") + || request.startsWith("TRACE ") + || request.startsWith("CONNECT ")) { + if (hasBody) { + throw new IllegalArgumentException("Request must not have a body: " + request); + } + } else if (!request.startsWith("POST ") && !request.startsWith("PUT ")) { + throw new UnsupportedOperationException("Unexpected method: " + request); } - /** Processes HTTP requests layered over SPDY/3. */ - private class SpdySocketHandler implements IncomingStreamHandler { - private final Socket socket; + return new RecordedRequest(request, headers, chunkSizes, requestBody.numBytesReceived, + requestBody.toByteArray(), sequenceNumber, socket); + } - private SpdySocketHandler(Socket socket) { - this.socket = socket; + private void writeResponse(OutputStream out, MockResponse response) throws IOException { + out.write((response.getStatus() + "\r\n").getBytes(Util.US_ASCII)); + for (String header : response.getHeaders()) { + out.write((header + "\r\n").getBytes(Util.US_ASCII)); + } + out.write(("\r\n").getBytes(Util.US_ASCII)); + out.flush(); + + InputStream in = response.getBodyStream(); + if (in == null) return; + int bytesPerSecond = response.getBytesPerSecond(); + + // Stream data in MTU-sized increments, with a minimum of one packet per second. + byte[] buffer = bytesPerSecond >= 1452 ? new byte[1452] : new byte[bytesPerSecond]; + long delayMs = bytesPerSecond == Integer.MAX_VALUE + ? 0 + : (1000 * buffer.length) / bytesPerSecond; + + int read; + long sinceDelay = 0; + while ((read = in.read(buffer)) != -1) { + out.write(buffer, 0, read); + out.flush(); + + sinceDelay += read; + if (sinceDelay >= buffer.length && delayMs > 0) { + sinceDelay %= buffer.length; + try { + Thread.sleep(delayMs); + } catch (InterruptedException e) { + throw new AssertionError(); } + } + } + } - @Override public void receive(final SpdyStream stream) throws IOException { - RecordedRequest request = readRequest(stream); - requestQueue.add(request); - MockResponse response; - try { - response = dispatcher.dispatch(request); - } catch (InterruptedException e) { - throw new AssertionError(e); - } - writeResponse(stream, response); - logger.info("Received request: " + request + " and responded: " + response); - } + /** + * Transfer bytes from {@code in} to {@code out} until either {@code length} + * bytes have been transferred or {@code in} is exhausted. + */ + private void transfer(long length, InputStream in, OutputStream out) throws IOException { + byte[] buffer = new byte[1024]; + while (length > 0) { + int count = in.read(buffer, 0, (int) Math.min(buffer.length, length)); + if (count == -1) return; + out.write(buffer, 0, count); + length -= count; + } + } - private RecordedRequest readRequest(SpdyStream stream) throws IOException { - List spdyHeaders = stream.getRequestHeaders(); - List httpHeaders = new ArrayList(); - String method = "<:method omitted>"; - String path = "<:path omitted>"; - String version = "<:version omitted>"; - for (Iterator i = spdyHeaders.iterator(); i.hasNext(); ) { - String name = i.next(); - String value = i.next(); - if (":method".equals(name)) { - method = value; - } else if (":path".equals(name)) { - path = value; - } else if (":version".equals(name)) { - version = value; - } else { - httpHeaders.add(name + ": " + value); - } - } + /** + * Returns the text from {@code in} until the next "\r\n", or null if {@code + * in} is exhausted. + */ + private String readAsciiUntilCrlf(InputStream in) throws IOException { + StringBuilder builder = new StringBuilder(); + while (true) { + int c = in.read(); + if (c == '\n' && builder.length() > 0 && builder.charAt(builder.length() - 1) == '\r') { + builder.deleteCharAt(builder.length() - 1); + return builder.toString(); + } else if (c == -1) { + return builder.toString(); + } else { + builder.append((char) c); + } + } + } - InputStream bodyIn = stream.getInputStream(); - ByteArrayOutputStream bodyOut = new ByteArrayOutputStream(); - byte[] buffer = new byte[8192]; - int count; - while ((count = bodyIn.read(buffer)) != -1) { - bodyOut.write(buffer, 0, count); - } - bodyIn.close(); - String requestLine = method + ' ' + path + ' ' + version; - List chunkSizes = Collections.emptyList(); // No chunked encoding for SPDY. - return new RecordedRequest(requestLine, httpHeaders, chunkSizes, bodyOut.size(), - bodyOut.toByteArray(), 0, socket); - } + private void readEmptyLine(InputStream in) throws IOException { + String line = readAsciiUntilCrlf(in); + if (line.length() != 0) throw new IllegalStateException("Expected empty but was: " + line); + } - private void writeResponse(SpdyStream stream, MockResponse response) throws IOException { - List spdyHeaders = new ArrayList(); - String[] statusParts = response.getStatus().split(" ", 2); - if (statusParts.length != 2) { - throw new AssertionError("Unexpected status: " + response.getStatus()); - } - spdyHeaders.add(":status"); - spdyHeaders.add(statusParts[1]); - spdyHeaders.add(":version"); - spdyHeaders.add(statusParts[0]); - for (String header : response.getHeaders()) { - String[] headerParts = header.split(":", 2); - if (headerParts.length != 2) { - throw new AssertionError("Unexpected header: " + header); - } - spdyHeaders.add(headerParts[0].toLowerCase(Locale.US).trim()); - spdyHeaders.add(headerParts[1].trim()); - } - byte[] body = response.getBody(); - stream.reply(spdyHeaders, body.length > 0); - if (body.length > 0) { - stream.getOutputStream().write(body); - stream.getOutputStream().close(); - } - } + /** + * Sets the dispatcher used to match incoming requests to mock responses. + * The default dispatcher simply serves a fixed sequence of responses from + * a {@link #enqueue(MockResponse) queue}; custom dispatchers can vary the + * response based on timing or the content of the request. + */ + public void setDispatcher(Dispatcher dispatcher) { + if (dispatcher == null) throw new NullPointerException(); + this.dispatcher = dispatcher; + } + + /** An output stream that drops data after bodyLimit bytes. */ + private class TruncatingOutputStream extends ByteArrayOutputStream { + private long numBytesReceived = 0; + + @Override public void write(byte[] buffer, int offset, int len) { + numBytesReceived += len; + super.write(buffer, offset, Math.min(len, bodyLimit - count)); } - enum Transport { - HTTP_11, SPDY_3 + @Override public void write(int oneByte) { + numBytesReceived++; + if (count < bodyLimit) { + super.write(oneByte); + } } + } + + private static Runnable namedRunnable(final String name, final Runnable runnable) { + return new Runnable() { + public void run() { + String originalName = Thread.currentThread().getName(); + Thread.currentThread().setName(name); + try { + runnable.run(); + } finally { + Thread.currentThread().setName(originalName); + } + } + }; + } + + /** Processes HTTP requests layered over SPDY/3. */ + private class SpdySocketHandler implements IncomingStreamHandler { + private final Socket socket; + + private SpdySocketHandler(Socket socket) { + this.socket = socket; + } + + @Override public void receive(SpdyStream stream) throws IOException { + RecordedRequest request = readRequest(stream); + requestQueue.add(request); + MockResponse response; + try { + response = dispatcher.dispatch(request); + } catch (InterruptedException e) { + throw new AssertionError(e); + } + writeResponse(stream, response); + logger.info("Received request: " + request + " and responded: " + response); + } + + private RecordedRequest readRequest(SpdyStream stream) throws IOException { + List spdyHeaders = stream.getRequestHeaders(); + List httpHeaders = new ArrayList(); + String method = "<:method omitted>"; + String path = "<:path omitted>"; + String version = "<:version omitted>"; + for (Iterator i = spdyHeaders.iterator(); i.hasNext(); ) { + String name = i.next(); + String value = i.next(); + if (":method".equals(name)) { + method = value; + } else if (":path".equals(name)) { + path = value; + } else if (":version".equals(name)) { + version = value; + } else { + httpHeaders.add(name + ": " + value); + } + } + + InputStream bodyIn = stream.getInputStream(); + ByteArrayOutputStream bodyOut = new ByteArrayOutputStream(); + byte[] buffer = new byte[8192]; + int count; + while ((count = bodyIn.read(buffer)) != -1) { + bodyOut.write(buffer, 0, count); + } + bodyIn.close(); + String requestLine = method + ' ' + path + ' ' + version; + List chunkSizes = Collections.emptyList(); // No chunked encoding for SPDY. + return new RecordedRequest(requestLine, httpHeaders, chunkSizes, bodyOut.size(), + bodyOut.toByteArray(), 0, socket); + } + + private void writeResponse(SpdyStream stream, MockResponse response) throws IOException { + List spdyHeaders = new ArrayList(); + String[] statusParts = response.getStatus().split(" ", 2); + if (statusParts.length != 2) { + throw new AssertionError("Unexpected status: " + response.getStatus()); + } + spdyHeaders.add(":status"); + spdyHeaders.add(statusParts[1]); + spdyHeaders.add(":version"); + spdyHeaders.add(statusParts[0]); + for (String header : response.getHeaders()) { + String[] headerParts = header.split(":", 2); + if (headerParts.length != 2) { + throw new AssertionError("Unexpected header: " + header); + } + spdyHeaders.add(headerParts[0].toLowerCase(Locale.US).trim()); + spdyHeaders.add(headerParts[1].trim()); + } + byte[] body = response.getBody(); + stream.reply(spdyHeaders, body.length > 0); + if (body.length > 0) { + stream.getOutputStream().write(body); + stream.getOutputStream().close(); + } + } + } + + enum Transport { + HTTP_11, SPDY_3 + } } diff --git a/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/QueueDispatcher.java b/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/QueueDispatcher.java index 1448095d5..0f0cb280e 100644 --- a/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/QueueDispatcher.java +++ b/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/QueueDispatcher.java @@ -20,53 +20,52 @@ import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; /** - * Default dispatcher that processes a script of responses. Populate the script by calling - * {@link #enqueueResponse(MockResponse)}. + * Default dispatcher that processes a script of responses. Populate the script + * by calling {@link #enqueueResponse(MockResponse)}. */ public class QueueDispatcher extends Dispatcher { - protected final BlockingQueue responseQueue - = new LinkedBlockingQueue(); - private MockResponse failFastResponse; + protected final BlockingQueue responseQueue + = new LinkedBlockingQueue(); + private MockResponse failFastResponse; - @Override public MockResponse dispatch(RecordedRequest request) throws InterruptedException { - // to permit interactive/browser testing, ignore requests for favicons - final String requestLine = request.getRequestLine(); - if (requestLine != null && requestLine.equals("GET /favicon.ico HTTP/1.1")) { - System.out.println("served " + requestLine); - return new MockResponse() - .setResponseCode(HttpURLConnection.HTTP_NOT_FOUND); - } - - if (failFastResponse != null && responseQueue.peek() == null) { - // Fail fast if there's no response queued up. - return failFastResponse; - } - - return responseQueue.take(); + @Override public MockResponse dispatch(RecordedRequest request) throws InterruptedException { + // To permit interactive/browser testing, ignore requests for favicons. + final String requestLine = request.getRequestLine(); + if (requestLine != null && requestLine.equals("GET /favicon.ico HTTP/1.1")) { + System.out.println("served " + requestLine); + return new MockResponse().setResponseCode(HttpURLConnection.HTTP_NOT_FOUND); } - @Override public SocketPolicy peekSocketPolicy() { - MockResponse peek = responseQueue.peek(); - if (peek == null) { - return failFastResponse != null - ? failFastResponse.getSocketPolicy() - : SocketPolicy.KEEP_OPEN; - } - return peek.getSocketPolicy(); + if (failFastResponse != null && responseQueue.peek() == null) { + // Fail fast if there's no response queued up. + return failFastResponse; } - public void enqueueResponse(MockResponse response) { - responseQueue.add(response); - } + return responseQueue.take(); + } - public void setFailFast(boolean failFast) { - MockResponse failFastResponse = failFast - ? new MockResponse().setResponseCode(HttpURLConnection.HTTP_NOT_FOUND) - : null; - setFailFast(failFastResponse); + @Override public SocketPolicy peekSocketPolicy() { + MockResponse peek = responseQueue.peek(); + if (peek == null) { + return failFastResponse != null + ? failFastResponse.getSocketPolicy() + : SocketPolicy.KEEP_OPEN; } + return peek.getSocketPolicy(); + } - public void setFailFast(MockResponse failFastResponse) { - this.failFastResponse = failFastResponse; - } + public void enqueueResponse(MockResponse response) { + responseQueue.add(response); + } + + public void setFailFast(boolean failFast) { + MockResponse failFastResponse = failFast + ? new MockResponse().setResponseCode(HttpURLConnection.HTTP_NOT_FOUND) + : null; + setFailFast(failFastResponse); + } + + public void setFailFast(MockResponse failFastResponse) { + this.failFastResponse = failFastResponse; + } } diff --git a/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/RecordedRequest.java b/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/RecordedRequest.java index 7bdbd3c4a..aceacd184 100644 --- a/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/RecordedRequest.java +++ b/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/RecordedRequest.java @@ -22,146 +22,132 @@ import java.util.ArrayList; import java.util.List; import javax.net.ssl.SSLSocket; -/** - * An HTTP request that came into the mock web server. - */ +/** An HTTP request that came into the mock web server. */ public final class RecordedRequest { - private final String requestLine; - private final String method; - private final String path; - private final List headers; - private final List chunkSizes; - private final long bodySize; - private final byte[] body; - private final int sequenceNumber; - private final String sslProtocol; + private final String requestLine; + private final String method; + private final String path; + private final List headers; + private final List chunkSizes; + private final long bodySize; + private final byte[] body; + private final int sequenceNumber; + private final String sslProtocol; - public RecordedRequest(String requestLine, List headers, List chunkSizes, - long bodySize, byte[] body, int sequenceNumber, Socket socket) { - this.requestLine = requestLine; - this.headers = headers; - this.chunkSizes = chunkSizes; - this.bodySize = bodySize; - this.body = body; - this.sequenceNumber = sequenceNumber; + public RecordedRequest(String requestLine, List headers, List chunkSizes, + long bodySize, byte[] body, int sequenceNumber, Socket socket) { + this.requestLine = requestLine; + this.headers = headers; + this.chunkSizes = chunkSizes; + this.bodySize = bodySize; + this.body = body; + this.sequenceNumber = sequenceNumber; + this.sslProtocol = socket instanceof SSLSocket + ? ((SSLSocket) socket).getSession().getProtocol() + : null; - if (socket instanceof SSLSocket) { - SSLSocket sslSocket = (SSLSocket) socket; - sslProtocol = sslSocket.getSession().getProtocol(); - } else { - sslProtocol = null; - } - - if (requestLine != null) { - int methodEnd = requestLine.indexOf(' '); - int pathEnd = requestLine.indexOf(' ', methodEnd + 1); - this.method = requestLine.substring(0, methodEnd); - this.path = requestLine.substring(methodEnd + 1, pathEnd); - } else { - this.method = null; - this.path = null; - } + if (requestLine != null) { + int methodEnd = requestLine.indexOf(' '); + int pathEnd = requestLine.indexOf(' ', methodEnd + 1); + this.method = requestLine.substring(0, methodEnd); + this.path = requestLine.substring(methodEnd + 1, pathEnd); + } else { + this.method = null; + this.path = null; } + } - public String getRequestLine() { - return requestLine; - } + public String getRequestLine() { + return requestLine; + } - public String getMethod() { - return method; - } + public String getMethod() { + return method; + } - public String getPath() { - return path; - } + public String getPath() { + return path; + } - /** - * Returns all headers. - */ - public List getHeaders() { - return headers; - } + /** Returns all headers. */ + public List getHeaders() { + return headers; + } - /** - * Returns the first header named {@code name}, or null if no such header - * exists. - */ - public String getHeader(String name) { - name += ":"; - for (String header : headers) { - if (name.regionMatches(true, 0, header, 0, name.length())) { - return header.substring(name.length()).trim(); - } - } - return null; + /** + * Returns the first header named {@code name}, or null if no such header + * exists. + */ + public String getHeader(String name) { + name += ":"; + for (String header : headers) { + if (name.regionMatches(true, 0, header, 0, name.length())) { + return header.substring(name.length()).trim(); + } } + return null; + } - /** - * Returns the headers named {@code name}. - */ - public List getHeaders(String name) { - List result = new ArrayList(); - name += ":"; - for (String header : headers) { - if (name.regionMatches(true, 0, header, 0, name.length())) { - result.add(header.substring(name.length()).trim()); - } - } - return result; + /** Returns the headers named {@code name}. */ + public List getHeaders(String name) { + List result = new ArrayList(); + name += ":"; + for (String header : headers) { + if (name.regionMatches(true, 0, header, 0, name.length())) { + result.add(header.substring(name.length()).trim()); + } } + return result; + } - /** - * Returns the sizes of the chunks of this request's body, or an empty list - * if the request's body was empty or unchunked. - */ - public List getChunkSizes() { - return chunkSizes; - } + /** + * Returns the sizes of the chunks of this request's body, or an empty list + * if the request's body was empty or unchunked. + */ + public List getChunkSizes() { + return chunkSizes; + } - /** - * Returns the total size of the body of this POST request (before - * truncation). - */ - public long getBodySize() { - return bodySize; - } + /** + * Returns the total size of the body of this POST request (before + * truncation). + */ + public long getBodySize() { + return bodySize; + } - /** - * Returns the body of this POST request. This may be truncated. - */ - public byte[] getBody() { - return body; - } + /** Returns the body of this POST request. This may be truncated. */ + public byte[] getBody() { + return body; + } - /** - * Returns the body of this POST request decoded as a UTF-8 string. - */ - public String getUtf8Body() { - try { - return new String(body, "UTF-8"); - } catch (UnsupportedEncodingException e) { - throw new AssertionError(); - } + /** Returns the body of this POST request decoded as a UTF-8 string. */ + public String getUtf8Body() { + try { + return new String(body, "UTF-8"); + } catch (UnsupportedEncodingException e) { + throw new AssertionError(); } + } - /** - * Returns the index of this request on its HTTP connection. Since a single - * HTTP connection may serve multiple requests, each request is assigned its - * own sequence number. - */ - public int getSequenceNumber() { - return sequenceNumber; - } + /** + * Returns the index of this request on its HTTP connection. Since a single + * HTTP connection may serve multiple requests, each request is assigned its + * own sequence number. + */ + public int getSequenceNumber() { + return sequenceNumber; + } - /** - * Returns the connection's SSL protocol like {@code TLSv1}, {@code SSLv3}, - * {@code NONE} or null if the connection doesn't use SSL. - */ - public String getSslProtocol() { - return sslProtocol; - } + /** + * Returns the connection's SSL protocol like {@code TLSv1}, {@code SSLv3}, + * {@code NONE} or null if the connection doesn't use SSL. + */ + public String getSslProtocol() { + return sslProtocol; + } - @Override public String toString() { - return requestLine; - } + @Override public String toString() { + return requestLine; + } } diff --git a/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/SocketPolicy.java b/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/SocketPolicy.java index 988bcd28c..7912f3a77 100644 --- a/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/SocketPolicy.java +++ b/mockwebserver/src/main/java/com/squareup/okhttp/mockwebserver/SocketPolicy.java @@ -16,54 +16,46 @@ package com.squareup.okhttp.mockwebserver; -/** - * What should be done with the incoming socket. - */ +/** What should be done with the incoming socket. */ public enum SocketPolicy { - /** - * Keep the socket open after the response. This is the default HTTP/1.1 - * behavior. - */ - KEEP_OPEN, + /** + * Keep the socket open after the response. This is the default HTTP/1.1 + * behavior. + */ + KEEP_OPEN, - /** - * Close the socket after the response. This is the default HTTP/1.0 - * behavior. - */ - DISCONNECT_AT_END, + /** + * Close the socket after the response. This is the default HTTP/1.0 + * behavior. + */ + DISCONNECT_AT_END, - /** - * Wrap the socket with SSL at the completion of this request/response - * pair. Used for CONNECT messages to tunnel SSL over an HTTP proxy. - */ - UPGRADE_TO_SSL_AT_END, + /** + * Wrap the socket with SSL at the completion of this request/response pair. + * Used for CONNECT messages to tunnel SSL over an HTTP proxy. + */ + UPGRADE_TO_SSL_AT_END, - /** - * Request immediate close of connection without even reading the - * request. - * - *

Use to simulate the real life case of losing connection - * because of bugger SSL server close connection when it seems - * something like a compression method or TLS extension it doesn't - * understand, instead of simply ignoring it like it should. - */ - DISCONNECT_AT_START, + /** + * Request immediate close of connection without even reading the request. Use + * to simulate buggy SSL servers closing connections in response to + * unrecognized TLS extensions. + */ + DISCONNECT_AT_START, - /** - * Don't trust the client during the SSL handshake. - */ - FAIL_HANDSHAKE, + /** Don't trust the client during the SSL handshake. */ + FAIL_HANDSHAKE, - /** - * Shutdown the socket input after sending the response. For testing bad - * behavior. - */ - SHUTDOWN_INPUT_AT_END, + /** + * Shutdown the socket input after sending the response. For testing bad + * behavior. + */ + SHUTDOWN_INPUT_AT_END, - /** - * Shutdown the socket output after sending the response. For testing bad - * behavior. - */ - SHUTDOWN_OUTPUT_AT_END + /** + * Shutdown the socket output after sending the response. For testing bad + * behavior. + */ + SHUTDOWN_OUTPUT_AT_END }