diff --git a/okhttp-tests/src/test/java/okhttp3/DelegatingSSLSocketFactory.java b/okhttp-tests/src/test/java/okhttp3/DelegatingSSLSocketFactory.java index 00a968e7b..5a14d0fbd 100644 --- a/okhttp-tests/src/test/java/okhttp3/DelegatingSSLSocketFactory.java +++ b/okhttp-tests/src/test/java/okhttp3/DelegatingSSLSocketFactory.java @@ -34,51 +34,43 @@ public class DelegatingSSLSocketFactory extends SSLSocketFactory { this.delegate = delegate; } - @Override - public SSLSocket createSocket() throws IOException { + @Override public SSLSocket createSocket() throws IOException { SSLSocket sslSocket = (SSLSocket) delegate.createSocket(); return configureSocket(sslSocket); } - @Override - public SSLSocket createSocket(String host, int port) throws IOException, UnknownHostException { + @Override public SSLSocket createSocket(String host, int port) throws IOException { SSLSocket sslSocket = (SSLSocket) delegate.createSocket(host, port); return configureSocket(sslSocket); } - @Override - public SSLSocket createSocket(String host, int port, InetAddress localAddress, int localPort) - throws IOException, UnknownHostException { + @Override public SSLSocket createSocket( + String host, int port, InetAddress localAddress, int localPort) throws IOException { SSLSocket sslSocket = (SSLSocket) delegate.createSocket(host, port, localAddress, localPort); return configureSocket(sslSocket); } - @Override - public SSLSocket createSocket(InetAddress host, int port) throws IOException { + @Override public SSLSocket createSocket(InetAddress host, int port) throws IOException { SSLSocket sslSocket = (SSLSocket) delegate.createSocket(host, port); return configureSocket(sslSocket); } - @Override - public SSLSocket createSocket(InetAddress host, int port, InetAddress localAddress, int localPort) - throws IOException { + @Override public SSLSocket createSocket( + InetAddress host, int port, InetAddress localAddress, int localPort) throws IOException { SSLSocket sslSocket = (SSLSocket) delegate.createSocket(host, port, localAddress, localPort); return configureSocket(sslSocket); } - @Override - public String[] getDefaultCipherSuites() { + @Override public String[] getDefaultCipherSuites() { return delegate.getDefaultCipherSuites(); } - @Override - public String[] getSupportedCipherSuites() { + @Override public String[] getSupportedCipherSuites() { return delegate.getSupportedCipherSuites(); } - @Override - public SSLSocket createSocket(Socket socket, String host, int port, boolean autoClose) - throws IOException { + @Override public SSLSocket createSocket( + Socket socket, String host, int port, boolean autoClose) throws IOException { SSLSocket sslSocket = (SSLSocket) delegate.createSocket(socket, host, port, autoClose); return configureSocket(sslSocket); } diff --git a/okhttp/src/main/java/okhttp3/OkHttpClient.java b/okhttp/src/main/java/okhttp3/OkHttpClient.java index 43a2a7906..bcbf14d25 100644 --- a/okhttp/src/main/java/okhttp3/OkHttpClient.java +++ b/okhttp/src/main/java/okhttp3/OkHttpClient.java @@ -28,8 +28,10 @@ import javax.net.ssl.HostnameVerifier; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.X509TrustManager; import okhttp3.internal.Internal; import okhttp3.internal.InternalCache; +import okhttp3.internal.Platform; import okhttp3.internal.RouteDatabase; import okhttp3.internal.Util; import okhttp3.internal.http.StreamAllocation; @@ -130,6 +132,7 @@ public final class OkHttpClient implements Cloneable, Call.Factory { final InternalCache internalCache; final SocketFactory socketFactory; final SSLSocketFactory sslSocketFactory; + final X509TrustManager trustManager; final HostnameVerifier hostnameVerifier; final CertificatePinner certificatePinner; final Authenticator proxyAuthenticator; @@ -160,7 +163,7 @@ public final class OkHttpClient implements Cloneable, Call.Factory { this.internalCache = builder.internalCache; this.socketFactory = builder.socketFactory; - boolean isTLS = true; + boolean isTLS = false; for (ConnectionSpec spec : connectionSpecs) { isTLS = isTLS || spec.isTls(); } @@ -176,6 +179,16 @@ public final class OkHttpClient implements Cloneable, Call.Factory { throw new AssertionError(); // The system has no TLS. Just give up. } } + if (this.sslSocketFactory != null) { + this.trustManager = Platform.get().trustManager(sslSocketFactory); + if (trustManager == null) { + throw new IllegalStateException("Unable to extract the trust manager on " + Platform.get() + + ", sslSocketFactory is " + sslSocketFactory.getClass()); + } + } else { + this.trustManager = null; + } + this.hostnameVerifier = builder.hostnameVerifier; this.certificatePinner = builder.certificatePinner; this.proxyAuthenticator = builder.proxyAuthenticator; diff --git a/okhttp/src/main/java/okhttp3/internal/Platform.java b/okhttp/src/main/java/okhttp3/internal/Platform.java index 9e5e5545d..bf5144155 100644 --- a/okhttp/src/main/java/okhttp3/internal/Platform.java +++ b/okhttp/src/main/java/okhttp3/internal/Platform.java @@ -18,6 +18,7 @@ package okhttp3.internal; import android.util.Log; import java.io.IOException; +import java.lang.reflect.Field; import java.lang.reflect.InvocationHandler; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; @@ -29,6 +30,8 @@ import java.util.ArrayList; import java.util.List; import java.util.logging.Level; import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.X509TrustManager; import okhttp3.Protocol; import okio.Buffer; @@ -55,6 +58,11 @@ import static okhttp3.internal.Internal.logger; * unstable. * * Supported on OpenJDK 7 and 8 (via the JettyALPN-boot library). + * + *

Trust Manager Extraction

+ * + *

Supported on Android 2.3+ and OpenJDK 7+. There are no public APIs to recover the trust + * manager that was used to create an {@link SSLSocketFactory}. */ public class Platform { private static final Platform PLATFORM = findPlatform(); @@ -78,6 +86,10 @@ public class Platform { public void untagSocket(Socket socket) throws SocketException { } + public X509TrustManager trustManager(SSLSocketFactory sslSocketFactory) { + return null; + } + /** * Configure TLS extensions on {@code sslSocket} for {@code route}. * @@ -112,11 +124,13 @@ public class Platform { private static Platform findPlatform() { // Attempt to find Android 2.3+ APIs. try { + Class sslParametersClass; try { - Class.forName("com.android.org.conscrypt.OpenSSLSocketImpl"); + sslParametersClass = Class.forName("com.android.org.conscrypt.SSLParametersImpl"); } catch (ClassNotFoundException e) { // Older platform before being unbundled. - Class.forName("org.apache.harmony.xnet.provider.jsse.OpenSSLSocketImpl"); + sslParametersClass = Class.forName( + "org.apache.harmony.xnet.provider.jsse.SSLParametersImpl"); } OptionalMethod setUseSessionTickets @@ -144,25 +158,34 @@ public class Platform { } catch (ClassNotFoundException | NoSuchMethodException ignored) { } - return new Android(setUseSessionTickets, setHostname, trafficStatsTagSocket, - trafficStatsUntagSocket, getAlpnSelectedProtocol, setAlpnProtocols); + return new Android(sslParametersClass, setUseSessionTickets, setHostname, + trafficStatsTagSocket, trafficStatsUntagSocket, getAlpnSelectedProtocol, + setAlpnProtocols); } catch (ClassNotFoundException ignored) { // This isn't an Android runtime. } - // Find Jetty's ALPN extension for OpenJDK. + // Find an Oracle JDK. try { - String negoClassName = "org.eclipse.jetty.alpn.ALPN"; - Class negoClass = Class.forName(negoClassName); - Class providerClass = Class.forName(negoClassName + "$Provider"); - Class clientProviderClass = Class.forName(negoClassName + "$ClientProvider"); - Class serverProviderClass = Class.forName(negoClassName + "$ServerProvider"); - Method putMethod = negoClass.getMethod("put", SSLSocket.class, providerClass); - Method getMethod = negoClass.getMethod("get", SSLSocket.class); - Method removeMethod = negoClass.getMethod("remove", SSLSocket.class); - return new JdkWithJettyBootPlatform( - putMethod, getMethod, removeMethod, clientProviderClass, serverProviderClass); - } catch (ClassNotFoundException | NoSuchMethodException ignored) { + Class sslContextClass = Class.forName("sun.security.ssl.SSLContextImpl"); + + // Find Jetty's ALPN extension for OpenJDK. + try { + String negoClassName = "org.eclipse.jetty.alpn.ALPN"; + Class negoClass = Class.forName(negoClassName); + Class providerClass = Class.forName(negoClassName + "$Provider"); + Class clientProviderClass = Class.forName(negoClassName + "$ClientProvider"); + Class serverProviderClass = Class.forName(negoClassName + "$ServerProvider"); + Method putMethod = negoClass.getMethod("put", SSLSocket.class, providerClass); + Method getMethod = negoClass.getMethod("get", SSLSocket.class); + Method removeMethod = negoClass.getMethod("remove", SSLSocket.class); + return new JdkWithJettyBootPlatform(sslContextClass, + putMethod, getMethod, removeMethod, clientProviderClass, serverProviderClass); + } catch (ClassNotFoundException | NoSuchMethodException ignored) { + } + + return new JdkPlatform(sslContextClass); + } catch (ClassNotFoundException ignored) { } return new Platform(); @@ -172,6 +195,7 @@ public class Platform { private static class Android extends Platform { private static final int MAX_LOG_LENGTH = 4000; + private final Class sslParametersClass; private final OptionalMethod setUseSessionTickets; private final OptionalMethod setHostname; @@ -183,9 +207,11 @@ public class Platform { private final OptionalMethod getAlpnSelectedProtocol; private final OptionalMethod setAlpnProtocols; - public Android(OptionalMethod setUseSessionTickets, OptionalMethod setHostname, - Method trafficStatsTagSocket, Method trafficStatsUntagSocket, - OptionalMethod getAlpnSelectedProtocol, OptionalMethod setAlpnProtocols) { + public Android(Class sslParametersClass, OptionalMethod setUseSessionTickets, + OptionalMethod setHostname, Method trafficStatsTagSocket, + Method trafficStatsUntagSocket, OptionalMethod getAlpnSelectedProtocol, + OptionalMethod setAlpnProtocols) { + this.sslParametersClass = sslParametersClass; this.setUseSessionTickets = setUseSessionTickets; this.setHostname = setHostname; this.trafficStatsTagSocket = trafficStatsTagSocket; @@ -210,6 +236,17 @@ public class Platform { } } + @Override public X509TrustManager trustManager(SSLSocketFactory sslSocketFactory) { + Object context = readFieldOrNull(sslSocketFactory, sslParametersClass, "sslParameters"); + if (context == null) return null; + + X509TrustManager x509TrustManager = readFieldOrNull( + context, X509TrustManager.class, "x509TrustManager"); + if (x509TrustManager != null) return x509TrustManager; + + return readFieldOrNull(context, X509TrustManager.class, "trustManager"); + } + @Override public void configureTlsExtensions( SSLSocket sslSocket, String hostname, List protocols) { // Enable SNI and session tickets. @@ -271,18 +308,34 @@ public class Platform { } } + /** JDK 1.7 or better. */ + private static class JdkPlatform extends Platform { + private final Class sslContextClass; + + public JdkPlatform(Class sslContextClass) { + this.sslContextClass = sslContextClass; + } + + @Override public X509TrustManager trustManager(SSLSocketFactory sslSocketFactory) { + Object context = readFieldOrNull(sslSocketFactory, sslContextClass, "context"); + if (context == null) return null; + return readFieldOrNull(context, X509TrustManager.class, "trustManager"); + } + } + /** * OpenJDK 7+ with {@code org.mortbay.jetty.alpn/alpn-boot} in the boot class path. */ - private static class JdkWithJettyBootPlatform extends Platform { + private static class JdkWithJettyBootPlatform extends JdkPlatform { private final Method putMethod; private final Method getMethod; private final Method removeMethod; private final Class clientProviderClass; private final Class serverProviderClass; - public JdkWithJettyBootPlatform(Method putMethod, Method getMethod, Method removeMethod, - Class clientProviderClass, Class serverProviderClass) { + public JdkWithJettyBootPlatform(Class sslContextClass, Method putMethod, Method getMethod, + Method removeMethod, Class clientProviderClass, Class serverProviderClass) { + super(sslContextClass); this.putMethod = putMethod; this.getMethod = getMethod; this.removeMethod = removeMethod; @@ -394,4 +447,27 @@ public class Platform { } return result.readByteArray(); } + + static T readFieldOrNull(Object instance, Class fieldType, String fieldName) { + for (Class c = instance.getClass(); c != Object.class; c = c.getSuperclass()) { + try { + Field field = c.getDeclaredField(fieldName); + field.setAccessible(true); + Object value = field.get(instance); + if (value == null || !fieldType.isInstance(value)) return null; + return fieldType.cast(value); + } catch (NoSuchFieldException ignored) { + } catch (IllegalAccessException e) { + throw new AssertionError(); + } + } + + // Didn't find the field we wanted. As a last gasp attempt, try to find the value on a delegate. + if (!fieldName.equals("delegate")) { + Object delegate = readFieldOrNull(instance, Object.class, "delegate"); + if (delegate != null) return readFieldOrNull(delegate, fieldType, fieldName); + } + + return null; + } }