diff --git a/src/main/java/com/google/mockwebserver/MockWebServer.java b/src/main/java/com/google/mockwebserver/MockWebServer.java index ba92458cb..65f4547f1 100644 --- a/src/main/java/com/google/mockwebserver/MockWebServer.java +++ b/src/main/java/com/google/mockwebserver/MockWebServer.java @@ -17,6 +17,7 @@ package com.google.mockwebserver; import static com.google.mockwebserver.SocketPolicy.DISCONNECT_AT_START; +import static com.google.mockwebserver.SocketPolicy.FAIL_HANDSHAKE; import java.io.BufferedInputStream; import java.io.BufferedOutputStream; import java.io.ByteArrayOutputStream; @@ -33,6 +34,8 @@ import java.net.Socket; import java.net.SocketException; import java.net.URL; import java.net.UnknownHostException; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; @@ -47,8 +50,11 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicInteger; import java.util.logging.Level; import java.util.logging.Logger; +import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; +import javax.net.ssl.X509TrustManager; /** * A scriptable web server. Callers supply canned responses and the server @@ -267,6 +273,11 @@ public final class MockWebServer { if (tunnelProxy) { createTunnel(); } + MockResponse response = responseQueue.peek(); + if (response != null && response.getSocketPolicy() == FAIL_HANDSHAKE) { + processHandshakeFailure(raw, sequenceNumber++); + return; + } socket = sslSocketFactory.createSocket( raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true); ((SSLSocket) socket).setUseClientMode(false); @@ -279,7 +290,7 @@ public final class MockWebServer { InputStream in = new BufferedInputStream(socket.getInputStream()); OutputStream out = new BufferedOutputStream(socket.getOutputStream()); - while (!responseQueue.isEmpty() && processOneRequest(in, out, socket)) {} + while (!responseQueue.isEmpty() && processOneRequest(socket, in, out)) {} if (sequenceNumber == 0) { logger.warning("MockWebServer connection didn't make a request"); @@ -301,7 +312,7 @@ public final class MockWebServer { private void createTunnel() throws IOException, InterruptedException { while (true) { MockResponse connect = responseQueue.peek(); - if (!processOneRequest(raw.getInputStream(), raw.getOutputStream(), raw)) { + if (!processOneRequest(raw, raw.getInputStream(), raw.getOutputStream())) { throw new IllegalStateException("Tunnel without any CONNECT!"); } if (connect.getSocketPolicy() == SocketPolicy.UPGRADE_TO_SSL_AT_END) { @@ -314,9 +325,9 @@ public final class MockWebServer { * Reads a request and writes its response. Returns true if a request * was processed. */ - private boolean processOneRequest(InputStream in, OutputStream out, Socket socket) + private boolean processOneRequest(Socket socket, InputStream in, OutputStream out) throws IOException, InterruptedException { - RecordedRequest request = readRequest(in, sequenceNumber); + RecordedRequest request = readRequest(socket, in, sequenceNumber); if (request == null) { return false; } @@ -336,10 +347,40 @@ public final class MockWebServer { })); } + private void processHandshakeFailure(Socket raw, int sequenceNumber) throws Exception { + responseQueue.take(); + X509TrustManager untrusted = new X509TrustManager() { + @Override public void checkClientTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + throw new CertificateException(); + } + @Override public void checkServerTrusted(X509Certificate[] chain, String authType) { + throw new AssertionError(); + } + @Override public X509Certificate[] getAcceptedIssuers() { + throw new AssertionError(); + } + }; + SSLContext context = SSLContext.getInstance("TLS"); + context.init(null, new TrustManager[] { untrusted }, new java.security.SecureRandom()); + SSLSocketFactory sslSocketFactory = context.getSocketFactory(); + SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket( + raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true); + try { + socket.startHandshake(); // we're testing a handshake failure + throw new AssertionError(); + } catch (IOException expected) { + } + socket.close(); + requestCount.incrementAndGet(); + requestQueue.add(new RecordedRequest(null, null, null, -1, null, sequenceNumber, socket)); + } + /** * @param sequenceNumber the index of this request on this connection. */ - private RecordedRequest readRequest(InputStream in, int sequenceNumber) throws IOException { + private RecordedRequest readRequest(Socket socket, InputStream in, int sequenceNumber) + throws IOException { String request; try { request = readAsciiUntilCrlf(in); @@ -401,7 +442,7 @@ public final class MockWebServer { } return new RecordedRequest(request, headers, chunkSizes, - requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber); + requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber, socket); } /** diff --git a/src/main/java/com/google/mockwebserver/RecordedRequest.java b/src/main/java/com/google/mockwebserver/RecordedRequest.java index 8f0908457..a06c0bccb 100644 --- a/src/main/java/com/google/mockwebserver/RecordedRequest.java +++ b/src/main/java/com/google/mockwebserver/RecordedRequest.java @@ -16,7 +16,9 @@ package com.google.mockwebserver; +import java.net.Socket; import java.util.List; +import javax.net.ssl.SSLSocket; /** * An HTTP request that came into the mock web server. @@ -28,15 +30,23 @@ public final class RecordedRequest { private final int bodySize; private final byte[] body; private final int sequenceNumber; + private final String sslProtocol; RecordedRequest(String requestLine, List headers, List chunkSizes, - int bodySize, byte[] body, int sequenceNumber) { + int bodySize, byte[] body, int sequenceNumber, Socket socket) { this.requestLine = requestLine; this.headers = headers; this.chunkSizes = chunkSizes; this.bodySize = bodySize; this.body = body; this.sequenceNumber = sequenceNumber; + + if (socket instanceof SSLSocket) { + SSLSocket sslSocket = (SSLSocket) socket; + sslProtocol = sslSocket.getSession().getProtocol(); + } else { + sslProtocol = null; + } } public String getRequestLine() { @@ -79,6 +89,14 @@ public final class RecordedRequest { return sequenceNumber; } + /** + * Returns the connection's SSL protocol like {@code TLSv1}, {@code SSLv3}, + * {@code NONE} or null if the connection doesn't use SSL. + */ + public String getSslProtocol() { + return sslProtocol; + } + @Override public String toString() { return requestLine; } diff --git a/src/main/java/com/google/mockwebserver/SocketPolicy.java b/src/main/java/com/google/mockwebserver/SocketPolicy.java index d256a45ba..3a6797b2b 100644 --- a/src/main/java/com/google/mockwebserver/SocketPolicy.java +++ b/src/main/java/com/google/mockwebserver/SocketPolicy.java @@ -50,6 +50,11 @@ public enum SocketPolicy { */ DISCONNECT_AT_START, + /** + * Don't trust the client during the SSL handshake. + */ + FAIL_HANDSHAKE, + /** * Shutdown the socket input after sending the response. For testing bad * behavior.