1
0
mirror of https://github.com/square/okhttp.git synced 2025-07-31 05:04:26 +03:00

Support client authentication in MockWebServer

Also expose the handshake in the RecordedResponse.

https://github.com/square/okhttp/issues/3934
This commit is contained in:
Jesse Wilson
2018-07-06 10:06:10 -04:00
parent dd720d260a
commit d908a676c2
13 changed files with 186 additions and 50 deletions

View File

@ -102,6 +102,10 @@ public final class MockWebServer extends ExternalResource implements Closeable {
Internal.initializeInstanceForTests(); Internal.initializeInstanceForTests();
} }
private static final int CLIENT_AUTH_NONE = 0;
private static final int CLIENT_AUTH_REQUESTED = 1;
private static final int CLIENT_AUTH_REQUIRED = 2;
private static final X509TrustManager UNTRUSTED_TRUST_MANAGER = new X509TrustManager() { private static final X509TrustManager UNTRUSTED_TRUST_MANAGER = new X509TrustManager() {
@Override public void checkClientTrusted(X509Certificate[] chain, String authType) @Override public void checkClientTrusted(X509Certificate[] chain, String authType)
throws CertificateException { throws CertificateException {
@ -132,6 +136,7 @@ public final class MockWebServer extends ExternalResource implements Closeable {
private SSLSocketFactory sslSocketFactory; private SSLSocketFactory sslSocketFactory;
private ExecutorService executor; private ExecutorService executor;
private boolean tunnelProxy; private boolean tunnelProxy;
private int clientAuth = CLIENT_AUTH_NONE;
private Dispatcher dispatcher = new QueueDispatcher(); private Dispatcher dispatcher = new QueueDispatcher();
private int port = -1; private int port = -1;
@ -241,6 +246,36 @@ public final class MockWebServer extends ExternalResource implements Closeable {
this.tunnelProxy = tunnelProxy; this.tunnelProxy = tunnelProxy;
} }
/**
* Configure the server to not perform SSL authentication of the client. This leaves
* authentication to another layer such as in an HTTP cookie or header. This is the default and
* most common configuration.
*/
public void noClientAuth() {
this.clientAuth = CLIENT_AUTH_NONE;
}
/**
* Configure the server to {@linkplain SSLSocket#setWantClientAuth want client auth}. If the
* client presents a certificate that is {@linkplain TrustManager trusted} the handshake will
* proceed normally. The connection will also proceed normally if the client presents no
* certificate at all! But if the client presents an untrusted certificate the handshake will fail
* and no connection will be established.
*/
public void requestClientAuth() {
this.clientAuth = CLIENT_AUTH_REQUESTED;
}
/**
* Configure the server to {@linkplain SSLSocket#setNeedClientAuth need client auth}. If the
* client presents a certificate that is {@linkplain TrustManager trusted} the handshake will
* proceed normally. If the client presents an untrusted certificate or no certificate at all the
* handshake will fail and no connection will be established.
*/
public void requireClientAuth() {
this.clientAuth = CLIENT_AUTH_REQUIRED;
}
/** /**
* Awaits the next HTTP request, removes it, and returns it. Callers should use this to verify the * Awaits the next HTTP request, removes it, and returns it. Callers should use this to verify the
* request was sent as intended. This method will block until the request is available, possibly * request was sent as intended. This method will block until the request is available, possibly
@ -431,6 +466,11 @@ public final class MockWebServer extends ExternalResource implements Closeable {
raw.getPort(), true); raw.getPort(), true);
SSLSocket sslSocket = (SSLSocket) socket; SSLSocket sslSocket = (SSLSocket) socket;
sslSocket.setUseClientMode(false); sslSocket.setUseClientMode(false);
if (clientAuth == CLIENT_AUTH_REQUIRED) {
sslSocket.setNeedClientAuth(true);
} else if (clientAuth == CLIENT_AUTH_REQUESTED) {
sslSocket.setWantClientAuth(true);
}
openClientSockets.add(socket); openClientSockets.add(socket);
if (protocolNegotiationEnabled) { if (protocolNegotiationEnabled) {

View File

@ -16,9 +16,11 @@
package okhttp3.mockwebserver; package okhttp3.mockwebserver;
import java.io.IOException;
import java.net.Socket; import java.net.Socket;
import java.util.List; import java.util.List;
import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocket;
import okhttp3.Handshake;
import okhttp3.Headers; import okhttp3.Headers;
import okhttp3.HttpUrl; import okhttp3.HttpUrl;
import okhttp3.TlsVersion; import okhttp3.TlsVersion;
@ -30,11 +32,11 @@ public final class RecordedRequest {
private final String method; private final String method;
private final String path; private final String path;
private final Headers headers; private final Headers headers;
private final Handshake handshake;
private final List<Integer> chunkSizes; private final List<Integer> chunkSizes;
private final long bodySize; private final long bodySize;
private final Buffer body; private final Buffer body;
private final int sequenceNumber; private final int sequenceNumber;
private final TlsVersion tlsVersion;
private final HttpUrl requestUrl; private final HttpUrl requestUrl;
public RecordedRequest(String requestLine, Headers headers, List<Integer> chunkSizes, public RecordedRequest(String requestLine, Headers headers, List<Integer> chunkSizes,
@ -45,9 +47,15 @@ public final class RecordedRequest {
this.bodySize = bodySize; this.bodySize = bodySize;
this.body = body; this.body = body;
this.sequenceNumber = sequenceNumber; this.sequenceNumber = sequenceNumber;
this.tlsVersion = socket instanceof SSLSocket if (socket instanceof SSLSocket) {
? TlsVersion.forJavaName(((SSLSocket) socket).getSession().getProtocol()) try {
: null; this.handshake = Handshake.get(((SSLSocket) socket).getSession());
} catch (IOException e) {
throw new IllegalArgumentException(e);
}
} else {
this.handshake = null;
}
if (requestLine != null) { if (requestLine != null) {
int methodEnd = requestLine.indexOf(' '); int methodEnd = requestLine.indexOf(' ');
@ -128,7 +136,15 @@ public final class RecordedRequest {
/** Returns the connection's TLS version or null if the connection doesn't use SSL. */ /** Returns the connection's TLS version or null if the connection doesn't use SSL. */
public TlsVersion getTlsVersion() { public TlsVersion getTlsVersion() {
return tlsVersion; return handshake != null ? handshake.tlsVersion() : null;
}
/**
* Returns the TLS handshake of the connection that carried this request, or null if the request
* was received without TLS.
*/
public Handshake getHandshake() {
return handshake;
} }
@Override public String toString() { @Override public String toString() {

View File

@ -87,6 +87,10 @@ public final class SslClient {
* Configure the certificate chain to use when serving HTTPS responses. The first certificate is * Configure the certificate chain to use when serving HTTPS responses. The first certificate is
* the server's certificate, further certificates are included in the handshake so the client * the server's certificate, further certificates are included in the handshake so the client
* can build a trusted path to a CA certificate. * can build a trusted path to a CA certificate.
*
* <p>The chain should include all intermediate certificates but does not need the root
* certificate that we expect to be known by the remote peer. The peer already has that
* certificate so transmitting it is unnecessary.
*/ */
public Builder certificateChain(HeldCertificate localCert, HeldCertificate... chain) { public Builder certificateChain(HeldCertificate localCert, HeldCertificate... chain) {
X509Certificate[] certificates = new X509Certificate[chain.length]; X509Certificate[] certificates = new X509Certificate[chain.length];

View File

@ -33,10 +33,14 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import javax.net.ssl.HttpsURLConnection;
import okhttp3.Handshake;
import okhttp3.Headers; import okhttp3.Headers;
import okhttp3.HttpUrl; import okhttp3.HttpUrl;
import okhttp3.Protocol; import okhttp3.Protocol;
import okhttp3.RecordingHostnameVerifier;
import okhttp3.internal.Util; import okhttp3.internal.Util;
import okhttp3.mockwebserver.internal.tls.SslClient;
import org.junit.After; import org.junit.After;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
@ -48,6 +52,7 @@ import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
@ -489,4 +494,77 @@ public final class MockWebServerTest {
assertEquals(1, server.protocols().size()); assertEquals(1, server.protocols().size());
assertEquals(Protocol.H2_PRIOR_KNOWLEDGE, server.protocols().get(0)); assertEquals(Protocol.H2_PRIOR_KNOWLEDGE, server.protocols().get(0));
} }
@Test public void https() throws Exception {
SslClient sslClient = SslClient.localhost();
server.useHttps(sslClient.socketFactory, false);
server.enqueue(new MockResponse().setBody("abc"));
HttpUrl url = server.url("/");
HttpsURLConnection connection = (HttpsURLConnection) url.url().openConnection();
connection.setSSLSocketFactory(sslClient.socketFactory);
connection.setHostnameVerifier(new RecordingHostnameVerifier());
assertEquals(HttpURLConnection.HTTP_OK, connection.getResponseCode());
BufferedReader reader = new BufferedReader(new InputStreamReader(connection.getInputStream()));
assertEquals("abc", reader.readLine());
RecordedRequest request = server.takeRequest();
assertEquals("https", request.getRequestUrl().scheme());
Handshake handshake = request.getHandshake();
assertNotNull(handshake.tlsVersion());
assertNotNull(handshake.cipherSuite());
assertNotNull(handshake.localPrincipal());
assertEquals(1, handshake.localCertificates().size());
assertNull(handshake.peerPrincipal());
assertEquals(0, handshake.peerCertificates().size());
}
@Test public void httpsWithClientAuth() throws Exception {
HeldCertificate clientCa = new HeldCertificate.Builder()
.certificateAuthority(0)
.build();
HeldCertificate serverCa = new HeldCertificate.Builder()
.certificateAuthority(0)
.build();
HeldCertificate serverCertificate = new HeldCertificate.Builder()
.issuedBy(serverCa)
.addSubjectAlternativeName(server.getHostName())
.build();
SslClient serverSsl = new SslClient.Builder()
.addTrustedCertificate(clientCa.certificate())
.certificateChain(serverCertificate)
.build();
server.useHttps(serverSsl.socketFactory, false);
server.enqueue(new MockResponse().setBody("abc"));
server.requestClientAuth();
HeldCertificate clientCertificate = new HeldCertificate.Builder()
.issuedBy(clientCa)
.build();
SslClient clientSsl = new SslClient.Builder()
.addTrustedCertificate(serverCa.certificate())
.certificateChain(clientCertificate)
.build();
HttpUrl url = server.url("/");
HttpsURLConnection connection = (HttpsURLConnection) url.url().openConnection();
connection.setSSLSocketFactory(clientSsl.socketFactory);
connection.setHostnameVerifier(new RecordingHostnameVerifier());
assertEquals(HttpURLConnection.HTTP_OK, connection.getResponseCode());
BufferedReader reader = new BufferedReader(new InputStreamReader(connection.getInputStream()));
assertEquals("abc", reader.readLine());
RecordedRequest request = server.takeRequest();
assertEquals("https", request.getRequestUrl().scheme());
Handshake handshake = request.getHandshake();
assertNotNull(handshake.tlsVersion());
assertNotNull(handshake.cipherSuite());
assertNotNull(handshake.localPrincipal());
assertEquals(1, handshake.localCertificates().size());
assertNotNull(handshake.peerPrincipal());
assertEquals(1, handshake.peerCertificates().size());
}
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright (C) 2014 Square, Inc. * Copyright (C) 2018 Square, Inc.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -13,7 +13,6 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package okhttp3.dnsoverhttps; package okhttp3.dnsoverhttps;
import java.net.InetAddress; import java.net.InetAddress;

View File

@ -1,5 +1,5 @@
/* /*
* Copyright (C) 2014 Square, Inc. * Copyright (C) 2018 Square, Inc.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/* /*
* Copyright (C) 2014 Square, Inc. * Copyright (C) 2018 Square, Inc.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/* /*
* Copyright (C) 2014 Square, Inc. * Copyright (C) 2018 Square, Inc.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -13,7 +13,6 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package okhttp3.dnsoverhttps; package okhttp3.dnsoverhttps;
import java.net.InetAddress; import java.net.InetAddress;

View File

@ -1,5 +1,5 @@
/* /*
* Copyright (C) 2014 Square, Inc. * Copyright (C) 2018 Square, Inc.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/* /*
* Copyright (C) 2014 Square, Inc. * Copyright (C) 2018 Square, Inc.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,3 +1,18 @@
/*
* Copyright (C) 2018 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp3.internal.sse; package okhttp3.internal.sse;
import java.io.IOException; import java.io.IOException;

View File

@ -1,5 +1,5 @@
/* /*
* Copyright (C) 2014 Square, Inc. * Copyright (C) 2018 Square, Inc.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -15,24 +15,20 @@
*/ */
package okhttp3.internal.tls; package okhttp3.internal.tls;
import java.io.IOException;
import java.net.SocketException; import java.net.SocketException;
import java.security.GeneralSecurityException;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
import javax.net.ssl.SSLContext; import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.SSLSocketFactory;
import javax.security.auth.x500.X500Principal; import javax.security.auth.x500.X500Principal;
import okhttp3.Call; import okhttp3.Call;
import okhttp3.DelegatingSSLSocketFactory;
import okhttp3.OkHttpClient; import okhttp3.OkHttpClient;
import okhttp3.Request; import okhttp3.Request;
import okhttp3.Response; import okhttp3.Response;
import okhttp3.mockwebserver.HeldCertificate;
import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.HeldCertificate;
import okhttp3.mockwebserver.internal.tls.SslClient; import okhttp3.mockwebserver.internal.tls.SslClient;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule; import org.junit.Rule;
@ -41,16 +37,11 @@ import org.junit.Test;
import static okhttp3.TestUtil.defaultClient; import static okhttp3.TestUtil.defaultClient;
import static okhttp3.internal.platform.PlatformTest.getPlatform; import static okhttp3.internal.platform.PlatformTest.getPlatform;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
public final class ClientAuthTest { public final class ClientAuthTest {
@Rule public final MockWebServer server = new MockWebServer(); @Rule public final MockWebServer server = new MockWebServer();
public enum ClientAuth {
NONE, WANTS, NEEDS
}
private HeldCertificate serverRootCa; private HeldCertificate serverRootCa;
private HeldCertificate serverIntermediateCa; private HeldCertificate serverIntermediateCa;
private HeldCertificate serverCert; private HeldCertificate serverCert;
@ -59,7 +50,7 @@ public final class ClientAuthTest {
private HeldCertificate clientCert; private HeldCertificate clientCert;
@Before @Before
public void setUp() throws GeneralSecurityException { public void setUp() {
serverRootCa = new HeldCertificate.Builder() serverRootCa = new HeldCertificate.Builder()
.serialNumber(1L) .serialNumber(1L)
.certificateAuthority(3) .certificateAuthority(3)
@ -106,9 +97,10 @@ public final class ClientAuthTest {
@Test public void clientAuthForWants() throws Exception { @Test public void clientAuthForWants() throws Exception {
OkHttpClient client = buildClient(clientCert, clientIntermediateCa); OkHttpClient client = buildClient(clientCert, clientIntermediateCa);
SSLSocketFactory socketFactory = buildServerSslSocketFactory(ClientAuth.WANTS); SSLSocketFactory socketFactory = buildServerSslSocketFactory();
server.useHttps(socketFactory, false); server.useHttps(socketFactory, false);
server.requestClientAuth();
server.enqueue(new MockResponse().setBody("abc")); server.enqueue(new MockResponse().setBody("abc"));
Call call = client.newCall(new Request.Builder().url(server.url("/")).build()); Call call = client.newCall(new Request.Builder().url(server.url("/")).build());
@ -121,9 +113,10 @@ public final class ClientAuthTest {
@Test public void clientAuthForNeeds() throws Exception { @Test public void clientAuthForNeeds() throws Exception {
OkHttpClient client = buildClient(clientCert, clientIntermediateCa); OkHttpClient client = buildClient(clientCert, clientIntermediateCa);
SSLSocketFactory socketFactory = buildServerSslSocketFactory(ClientAuth.NEEDS); SSLSocketFactory socketFactory = buildServerSslSocketFactory();
server.useHttps(socketFactory, false); server.useHttps(socketFactory, false);
server.requireClientAuth();
server.enqueue(new MockResponse().setBody("abc")); server.enqueue(new MockResponse().setBody("abc"));
Call call = client.newCall(new Request.Builder().url(server.url("/")).build()); Call call = client.newCall(new Request.Builder().url(server.url("/")).build());
@ -136,9 +129,10 @@ public final class ClientAuthTest {
@Test public void clientAuthSkippedForNone() throws Exception { @Test public void clientAuthSkippedForNone() throws Exception {
OkHttpClient client = buildClient(clientCert, clientIntermediateCa); OkHttpClient client = buildClient(clientCert, clientIntermediateCa);
SSLSocketFactory socketFactory = buildServerSslSocketFactory(ClientAuth.NONE); SSLSocketFactory socketFactory = buildServerSslSocketFactory();
server.useHttps(socketFactory, false); server.useHttps(socketFactory, false);
server.noClientAuth();
server.enqueue(new MockResponse().setBody("abc")); server.enqueue(new MockResponse().setBody("abc"));
Call call = client.newCall(new Request.Builder().url(server.url("/")).build()); Call call = client.newCall(new Request.Builder().url(server.url("/")).build());
@ -151,9 +145,10 @@ public final class ClientAuthTest {
@Test public void missingClientAuthSkippedForWantsOnly() throws Exception { @Test public void missingClientAuthSkippedForWantsOnly() throws Exception {
OkHttpClient client = buildClient(null, clientIntermediateCa); OkHttpClient client = buildClient(null, clientIntermediateCa);
SSLSocketFactory socketFactory = buildServerSslSocketFactory(ClientAuth.WANTS); SSLSocketFactory socketFactory = buildServerSslSocketFactory();
server.useHttps(socketFactory, false); server.useHttps(socketFactory, false);
server.requestClientAuth();
server.enqueue(new MockResponse().setBody("abc")); server.enqueue(new MockResponse().setBody("abc"));
Call call = client.newCall(new Request.Builder().url(server.url("/")).build()); Call call = client.newCall(new Request.Builder().url(server.url("/")).build());
@ -166,9 +161,10 @@ public final class ClientAuthTest {
@Test public void missingClientAuthFailsForNeeds() throws Exception { @Test public void missingClientAuthFailsForNeeds() throws Exception {
OkHttpClient client = buildClient(null, clientIntermediateCa); OkHttpClient client = buildClient(null, clientIntermediateCa);
SSLSocketFactory socketFactory = buildServerSslSocketFactory(ClientAuth.NEEDS); SSLSocketFactory socketFactory = buildServerSslSocketFactory();
server.useHttps(socketFactory, false); server.useHttps(socketFactory, false);
server.requireClientAuth();
Call call = client.newCall(new Request.Builder().url(server.url("/")).build()); Call call = client.newCall(new Request.Builder().url(server.url("/")).build());
@ -177,8 +173,7 @@ public final class ClientAuthTest {
fail(); fail();
} catch (SSLHandshakeException expected) { } catch (SSLHandshakeException expected) {
} catch (SocketException expected) { } catch (SocketException expected) {
// JDK 9 assertEquals("jdk9", getPlatform());
assertTrue(getPlatform().equals("jdk9"));
} }
} }
@ -192,9 +187,10 @@ public final class ClientAuthTest {
OkHttpClient client = buildClient(clientCert, clientIntermediateCa); OkHttpClient client = buildClient(clientCert, clientIntermediateCa);
SSLSocketFactory socketFactory = buildServerSslSocketFactory(ClientAuth.NEEDS); SSLSocketFactory socketFactory = buildServerSslSocketFactory();
server.useHttps(socketFactory, false); server.useHttps(socketFactory, false);
server.requireClientAuth();
Call call = client.newCall(new Request.Builder().url(server.url("/")).build()); Call call = client.newCall(new Request.Builder().url(server.url("/")).build());
@ -213,9 +209,10 @@ public final class ClientAuthTest {
OkHttpClient client = buildClient(clientCert2); OkHttpClient client = buildClient(clientCert2);
SSLSocketFactory socketFactory = buildServerSslSocketFactory(ClientAuth.NEEDS); SSLSocketFactory socketFactory = buildServerSslSocketFactory();
server.useHttps(socketFactory, false); server.useHttps(socketFactory, false);
server.requireClientAuth();
Call call = client.newCall(new Request.Builder().url(server.url("/")).build()); Call call = client.newCall(new Request.Builder().url(server.url("/")).build());
@ -224,12 +221,11 @@ public final class ClientAuthTest {
fail(); fail();
} catch (SSLHandshakeException expected) { } catch (SSLHandshakeException expected) {
} catch (SocketException expected) { } catch (SocketException expected) {
// JDK 9 assertEquals("jdk9", getPlatform());
assertTrue(getPlatform().equals("jdk9"));
} }
} }
public OkHttpClient buildClient(HeldCertificate cert, HeldCertificate... chain) { private OkHttpClient buildClient(HeldCertificate cert, HeldCertificate... chain) {
SslClient.Builder sslClientBuilder = new SslClient.Builder() SslClient.Builder sslClientBuilder = new SslClient.Builder()
.addTrustedCertificate(serverRootCa.certificate()); .addTrustedCertificate(serverRootCa.certificate());
@ -243,7 +239,7 @@ public final class ClientAuthTest {
.build(); .build();
} }
public SSLSocketFactory buildServerSslSocketFactory(final ClientAuth clientAuth) { private SSLSocketFactory buildServerSslSocketFactory() {
// The test uses JDK default SSL Context instead of the Platform provided one // The test uses JDK default SSL Context instead of the Platform provided one
// as Conscrypt seems to have some differences, we only want to test client side here. // as Conscrypt seems to have some differences, we only want to test client side here.
SslClient serverSslClient = new SslClient.Builder() SslClient serverSslClient = new SslClient.Builder()
@ -252,18 +248,7 @@ public final class ClientAuthTest {
.certificateChain(serverCert, serverIntermediateCa) .certificateChain(serverCert, serverIntermediateCa)
.sslContext(getSslContext()) .sslContext(getSslContext())
.build(); .build();
return serverSslClient.socketFactory;
return new DelegatingSSLSocketFactory(serverSslClient.socketFactory) {
@Override protected SSLSocket configureSocket(SSLSocket sslSocket) throws IOException {
if (clientAuth == ClientAuth.NEEDS) {
sslSocket.setNeedClientAuth(true);
} else if (clientAuth == ClientAuth.WANTS) {
sslSocket.setWantClientAuth(true);
}
return super.configureSocket(sslSocket);
}
};
} }
private SSLContext getSslContext() { private SSLContext getSslContext() {