1
0
mirror of https://github.com/square/okhttp.git synced 2026-01-27 04:22:07 +03:00

Support SSL handshake failures.

Also add an API to RecordedRequest to help differentiate between
original TLSv1 requests and SSLv3 fallback requests.


git-svn-id: https://mockwebserver.googlecode.com/svn/trunk@11 cf848351-439f-e86a-257f-67fa721851d5
This commit is contained in:
jessewilson@google.com
2011-12-20 01:07:46 +00:00
parent f003d2260d
commit 7bb99fd0bd
3 changed files with 71 additions and 7 deletions

View File

@@ -17,6 +17,7 @@
package com.google.mockwebserver;
import static com.google.mockwebserver.SocketPolicy.DISCONNECT_AT_START;
import static com.google.mockwebserver.SocketPolicy.FAIL_HANDSHAKE;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.ByteArrayOutputStream;
@@ -33,6 +34,8 @@ import java.net.Socket;
import java.net.SocketException;
import java.net.URL;
import java.net.UnknownHostException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
@@ -47,8 +50,11 @@ import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
/**
* A scriptable web server. Callers supply canned responses and the server
@@ -267,6 +273,11 @@ public final class MockWebServer {
if (tunnelProxy) {
createTunnel();
}
MockResponse response = responseQueue.peek();
if (response != null && response.getSocketPolicy() == FAIL_HANDSHAKE) {
processHandshakeFailure(raw, sequenceNumber++);
return;
}
socket = sslSocketFactory.createSocket(
raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
((SSLSocket) socket).setUseClientMode(false);
@@ -279,7 +290,7 @@ public final class MockWebServer {
InputStream in = new BufferedInputStream(socket.getInputStream());
OutputStream out = new BufferedOutputStream(socket.getOutputStream());
while (!responseQueue.isEmpty() && processOneRequest(in, out, socket)) {}
while (!responseQueue.isEmpty() && processOneRequest(socket, in, out)) {}
if (sequenceNumber == 0) {
logger.warning("MockWebServer connection didn't make a request");
@@ -301,7 +312,7 @@ public final class MockWebServer {
private void createTunnel() throws IOException, InterruptedException {
while (true) {
MockResponse connect = responseQueue.peek();
if (!processOneRequest(raw.getInputStream(), raw.getOutputStream(), raw)) {
if (!processOneRequest(raw, raw.getInputStream(), raw.getOutputStream())) {
throw new IllegalStateException("Tunnel without any CONNECT!");
}
if (connect.getSocketPolicy() == SocketPolicy.UPGRADE_TO_SSL_AT_END) {
@@ -314,9 +325,9 @@ public final class MockWebServer {
* Reads a request and writes its response. Returns true if a request
* was processed.
*/
private boolean processOneRequest(InputStream in, OutputStream out, Socket socket)
private boolean processOneRequest(Socket socket, InputStream in, OutputStream out)
throws IOException, InterruptedException {
RecordedRequest request = readRequest(in, sequenceNumber);
RecordedRequest request = readRequest(socket, in, sequenceNumber);
if (request == null) {
return false;
}
@@ -336,10 +347,40 @@ public final class MockWebServer {
}));
}
private void processHandshakeFailure(Socket raw, int sequenceNumber) throws Exception {
responseQueue.take();
X509TrustManager untrusted = new X509TrustManager() {
@Override public void checkClientTrusted(X509Certificate[] chain, String authType)
throws CertificateException {
throw new CertificateException();
}
@Override public void checkServerTrusted(X509Certificate[] chain, String authType) {
throw new AssertionError();
}
@Override public X509Certificate[] getAcceptedIssuers() {
throw new AssertionError();
}
};
SSLContext context = SSLContext.getInstance("TLS");
context.init(null, new TrustManager[] { untrusted }, new java.security.SecureRandom());
SSLSocketFactory sslSocketFactory = context.getSocketFactory();
SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket(
raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
try {
socket.startHandshake(); // we're testing a handshake failure
throw new AssertionError();
} catch (IOException expected) {
}
socket.close();
requestCount.incrementAndGet();
requestQueue.add(new RecordedRequest(null, null, null, -1, null, sequenceNumber, socket));
}
/**
* @param sequenceNumber the index of this request on this connection.
*/
private RecordedRequest readRequest(InputStream in, int sequenceNumber) throws IOException {
private RecordedRequest readRequest(Socket socket, InputStream in, int sequenceNumber)
throws IOException {
String request;
try {
request = readAsciiUntilCrlf(in);
@@ -401,7 +442,7 @@ public final class MockWebServer {
}
return new RecordedRequest(request, headers, chunkSizes,
requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber);
requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber, socket);
}
/**

View File

@@ -16,7 +16,9 @@
package com.google.mockwebserver;
import java.net.Socket;
import java.util.List;
import javax.net.ssl.SSLSocket;
/**
* An HTTP request that came into the mock web server.
@@ -28,15 +30,23 @@ public final class RecordedRequest {
private final int bodySize;
private final byte[] body;
private final int sequenceNumber;
private final String sslProtocol;
RecordedRequest(String requestLine, List<String> headers, List<Integer> chunkSizes,
int bodySize, byte[] body, int sequenceNumber) {
int bodySize, byte[] body, int sequenceNumber, Socket socket) {
this.requestLine = requestLine;
this.headers = headers;
this.chunkSizes = chunkSizes;
this.bodySize = bodySize;
this.body = body;
this.sequenceNumber = sequenceNumber;
if (socket instanceof SSLSocket) {
SSLSocket sslSocket = (SSLSocket) socket;
sslProtocol = sslSocket.getSession().getProtocol();
} else {
sslProtocol = null;
}
}
public String getRequestLine() {
@@ -79,6 +89,14 @@ public final class RecordedRequest {
return sequenceNumber;
}
/**
* Returns the connection's SSL protocol like {@code TLSv1}, {@code SSLv3},
* {@code NONE} or null if the connection doesn't use SSL.
*/
public String getSslProtocol() {
return sslProtocol;
}
@Override public String toString() {
return requestLine;
}

View File

@@ -50,6 +50,11 @@ public enum SocketPolicy {
*/
DISCONNECT_AT_START,
/**
* Don't trust the client during the SSL handshake.
*/
FAIL_HANDSHAKE,
/**
* Shutdown the socket input after sending the response. For testing bad
* behavior.