diff --git a/okhttp-tests/src/test/java/com/squareup/okhttp/DelegatingSSLSocketFactory.java b/okhttp-tests/src/test/java/com/squareup/okhttp/DelegatingSSLSocketFactory.java index 38a5de8f8..f72bd1ace 100644 --- a/okhttp-tests/src/test/java/com/squareup/okhttp/DelegatingSSLSocketFactory.java +++ b/okhttp-tests/src/test/java/com/squareup/okhttp/DelegatingSSLSocketFactory.java @@ -35,51 +35,43 @@ public class DelegatingSSLSocketFactory extends SSLSocketFactory { this.delegate = delegate; } - @Override - public SSLSocket createSocket() throws IOException { + @Override public SSLSocket createSocket() throws IOException { SSLSocket sslSocket = (SSLSocket) delegate.createSocket(); return configureSocket(sslSocket); } - @Override - public SSLSocket createSocket(String host, int port) throws IOException, UnknownHostException { + @Override public SSLSocket createSocket(String host, int port) throws IOException { SSLSocket sslSocket = (SSLSocket) delegate.createSocket(host, port); return configureSocket(sslSocket); } - @Override - public SSLSocket createSocket(String host, int port, InetAddress localAddress, int localPort) - throws IOException, UnknownHostException { + @Override public SSLSocket createSocket( + String host, int port, InetAddress localAddress, int localPort) throws IOException { SSLSocket sslSocket = (SSLSocket) delegate.createSocket(host, port, localAddress, localPort); return configureSocket(sslSocket); } - @Override - public SSLSocket createSocket(InetAddress host, int port) throws IOException { + @Override public SSLSocket createSocket(InetAddress host, int port) throws IOException { SSLSocket sslSocket = (SSLSocket) delegate.createSocket(host, port); return configureSocket(sslSocket); } - @Override - public SSLSocket createSocket(InetAddress host, int port, InetAddress localAddress, int localPort) - throws IOException { + @Override public SSLSocket createSocket( + InetAddress host, int port, InetAddress localAddress, int localPort) throws IOException { SSLSocket sslSocket = (SSLSocket) delegate.createSocket(host, port, localAddress, localPort); return configureSocket(sslSocket); } - @Override - public String[] getDefaultCipherSuites() { + @Override public String[] getDefaultCipherSuites() { return delegate.getDefaultCipherSuites(); } - @Override - public String[] getSupportedCipherSuites() { + @Override public String[] getSupportedCipherSuites() { return delegate.getSupportedCipherSuites(); } - @Override - public SSLSocket createSocket(Socket socket, String host, int port, boolean autoClose) - throws IOException { + @Override public SSLSocket createSocket( + Socket socket, String host, int port, boolean autoClose) throws IOException { SSLSocket sslSocket = (SSLSocket) delegate.createSocket(socket, host, port, autoClose); return configureSocket(sslSocket); } diff --git a/okhttp/src/main/java/com/squareup/okhttp/OkHttpClient.java b/okhttp/src/main/java/com/squareup/okhttp/OkHttpClient.java index aabc2d2de..12a0c58bf 100644 --- a/okhttp/src/main/java/com/squareup/okhttp/OkHttpClient.java +++ b/okhttp/src/main/java/com/squareup/okhttp/OkHttpClient.java @@ -17,6 +17,7 @@ package com.squareup.okhttp; import com.squareup.okhttp.internal.Internal; import com.squareup.okhttp.internal.InternalCache; +import com.squareup.okhttp.internal.Platform; import com.squareup.okhttp.internal.RouteDatabase; import com.squareup.okhttp.internal.Util; import com.squareup.okhttp.internal.http.AuthenticatorAdapter; @@ -38,6 +39,7 @@ import javax.net.ssl.HostnameVerifier; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.X509TrustManager; /** * Configures and creates HTTP connections. Most applications can use a single @@ -116,6 +118,7 @@ public class OkHttpClient implements Cloneable { /** Lazily-initialized. */ private static SSLSocketFactory defaultSslSocketFactory; + private static X509TrustManager defaultTrustManager; private final RouteDatabase routeDatabase; private Dispatcher dispatcher; @@ -133,6 +136,7 @@ public class OkHttpClient implements Cloneable { private SocketFactory socketFactory; private SSLSocketFactory sslSocketFactory; + private X509TrustManager trustManager; private HostnameVerifier hostnameVerifier; private CertificatePinner certificatePinner; private Authenticator authenticator; @@ -164,6 +168,7 @@ public class OkHttpClient implements Cloneable { this.internalCache = cache != null ? cache.internalCache : okHttpClient.internalCache; this.socketFactory = okHttpClient.socketFactory; this.sslSocketFactory = okHttpClient.sslSocketFactory; + this.trustManager = okHttpClient.trustManager; this.hostnameVerifier = okHttpClient.hostnameVerifier; this.certificatePinner = okHttpClient.certificatePinner; this.authenticator = okHttpClient.authenticator; @@ -343,6 +348,11 @@ public class OkHttpClient implements Cloneable { */ public OkHttpClient setSslSocketFactory(SSLSocketFactory sslSocketFactory) { this.sslSocketFactory = sslSocketFactory; + this.trustManager = Platform.get().trustManager(sslSocketFactory); + if (this.trustManager == null) { + throw new IllegalStateException("Unable to extract the trust manager on " + Platform.get() + + ", sslSocketFactory is " + sslSocketFactory.getClass()); + } return this; } @@ -589,6 +599,7 @@ public class OkHttpClient implements Cloneable { } if (result.sslSocketFactory == null) { result.sslSocketFactory = getDefaultSSLSocketFactory(); + result.trustManager = getDefaultTrustManager(); } if (result.hostnameVerifier == null) { result.hostnameVerifier = OkHostnameVerifier.INSTANCE; @@ -638,6 +649,17 @@ public class OkHttpClient implements Cloneable { return defaultSslSocketFactory; } + private synchronized X509TrustManager getDefaultTrustManager() { + if (defaultTrustManager == null) { + defaultTrustManager = Platform.get().trustManager(defaultSslSocketFactory); + if (defaultTrustManager == null) { + throw new IllegalStateException("Unable to extract the trust manager on " + Platform.get() + + ", sslSocketFactory is " + defaultSslSocketFactory.getClass()); + } + } + return defaultTrustManager; + } + /** Returns a shallow copy of this OkHttpClient. */ @Override public OkHttpClient clone() { return new OkHttpClient(this); diff --git a/okhttp/src/main/java/com/squareup/okhttp/internal/Platform.java b/okhttp/src/main/java/com/squareup/okhttp/internal/Platform.java index 4e578aef0..044ed0842 100644 --- a/okhttp/src/main/java/com/squareup/okhttp/internal/Platform.java +++ b/okhttp/src/main/java/com/squareup/okhttp/internal/Platform.java @@ -19,6 +19,7 @@ package com.squareup.okhttp.internal; import android.util.Log; import com.squareup.okhttp.Protocol; import java.io.IOException; +import java.lang.reflect.Field; import java.lang.reflect.InvocationHandler; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; @@ -30,6 +31,8 @@ import java.util.ArrayList; import java.util.List; import java.util.logging.Level; import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.X509TrustManager; import okio.Buffer; import static com.squareup.okhttp.internal.Internal.logger; @@ -51,6 +54,11 @@ import static com.squareup.okhttp.internal.Internal.logger; * unstable. * * Supported on OpenJDK 7 and 8 (via the JettyALPN-boot library). + * + *

Trust Manager Extraction

+ * + *

Supported on Android 2.3+ and OpenJDK 7+. There are no public APIs to recover the trust + * manager that was used to create an {@link SSLSocketFactory}. */ public class Platform { private static final Platform PLATFORM = findPlatform(); @@ -74,6 +82,10 @@ public class Platform { public void untagSocket(Socket socket) throws SocketException { } + public X509TrustManager trustManager(SSLSocketFactory sslSocketFactory) { + return null; + } + /** * Configure TLS extensions on {@code sslSocket} for {@code route}. * @@ -109,11 +121,13 @@ public class Platform { private static Platform findPlatform() { // Attempt to find Android 2.3+ APIs. try { + Class sslParametersClass; try { - Class.forName("com.android.org.conscrypt.OpenSSLSocketImpl"); + sslParametersClass = Class.forName("com.android.org.conscrypt.SSLParametersImpl"); } catch (ClassNotFoundException e) { // Older platform before being unbundled. - Class.forName("org.apache.harmony.xnet.provider.jsse.OpenSSLSocketImpl"); + sslParametersClass = Class.forName( + "org.apache.harmony.xnet.provider.jsse.SSLParametersImpl"); } OptionalMethod setUseSessionTickets @@ -141,25 +155,34 @@ public class Platform { } catch (ClassNotFoundException | NoSuchMethodException ignored) { } - return new Android(setUseSessionTickets, setHostname, trafficStatsTagSocket, - trafficStatsUntagSocket, getAlpnSelectedProtocol, setAlpnProtocols); + return new Android(sslParametersClass, setUseSessionTickets, setHostname, + trafficStatsTagSocket, trafficStatsUntagSocket, getAlpnSelectedProtocol, + setAlpnProtocols); } catch (ClassNotFoundException ignored) { // This isn't an Android runtime. } - // Find Jetty's ALPN extension for OpenJDK. + // Find an Oracle JDK. try { - String negoClassName = "org.eclipse.jetty.alpn.ALPN"; - Class negoClass = Class.forName(negoClassName); - Class providerClass = Class.forName(negoClassName + "$Provider"); - Class clientProviderClass = Class.forName(negoClassName + "$ClientProvider"); - Class serverProviderClass = Class.forName(negoClassName + "$ServerProvider"); - Method putMethod = negoClass.getMethod("put", SSLSocket.class, providerClass); - Method getMethod = negoClass.getMethod("get", SSLSocket.class); - Method removeMethod = negoClass.getMethod("remove", SSLSocket.class); - return new JdkWithJettyBootPlatform( - putMethod, getMethod, removeMethod, clientProviderClass, serverProviderClass); - } catch (ClassNotFoundException | NoSuchMethodException ignored) { + Class sslContextClass = Class.forName("sun.security.ssl.SSLContextImpl"); + + // Find Jetty's ALPN extension for OpenJDK. + try { + String negoClassName = "org.eclipse.jetty.alpn.ALPN"; + Class negoClass = Class.forName(negoClassName); + Class providerClass = Class.forName(negoClassName + "$Provider"); + Class clientProviderClass = Class.forName(negoClassName + "$ClientProvider"); + Class serverProviderClass = Class.forName(negoClassName + "$ServerProvider"); + Method putMethod = negoClass.getMethod("put", SSLSocket.class, providerClass); + Method getMethod = negoClass.getMethod("get", SSLSocket.class); + Method removeMethod = negoClass.getMethod("remove", SSLSocket.class); + return new JdkWithJettyBootPlatform(sslContextClass, + putMethod, getMethod, removeMethod, clientProviderClass, serverProviderClass); + } catch (ClassNotFoundException | NoSuchMethodException ignored) { + } + + return new JdkPlatform(sslContextClass); + } catch (ClassNotFoundException ignored) { } return new Platform(); @@ -169,6 +192,7 @@ public class Platform { private static class Android extends Platform { private static final int MAX_LOG_LENGTH = 4000; + private final Class sslParametersClass; private final OptionalMethod setUseSessionTickets; private final OptionalMethod setHostname; @@ -180,9 +204,11 @@ public class Platform { private final OptionalMethod getAlpnSelectedProtocol; private final OptionalMethod setAlpnProtocols; - public Android(OptionalMethod setUseSessionTickets, OptionalMethod setHostname, - Method trafficStatsTagSocket, Method trafficStatsUntagSocket, - OptionalMethod getAlpnSelectedProtocol, OptionalMethod setAlpnProtocols) { + public Android(Class sslParametersClass, OptionalMethod setUseSessionTickets, + OptionalMethod setHostname, Method trafficStatsTagSocket, + Method trafficStatsUntagSocket, OptionalMethod getAlpnSelectedProtocol, + OptionalMethod setAlpnProtocols) { + this.sslParametersClass = sslParametersClass; this.setUseSessionTickets = setUseSessionTickets; this.setHostname = setHostname; this.trafficStatsTagSocket = trafficStatsTagSocket; @@ -207,6 +233,17 @@ public class Platform { } } + @Override public X509TrustManager trustManager(SSLSocketFactory sslSocketFactory) { + Object context = readFieldOrNull(sslSocketFactory, sslParametersClass, "sslParameters"); + if (context == null) return null; + + X509TrustManager x509TrustManager = readFieldOrNull( + context, X509TrustManager.class, "x509TrustManager"); + if (x509TrustManager != null) return x509TrustManager; + + return readFieldOrNull(context, X509TrustManager.class, "trustManager"); + } + @Override public void configureTlsExtensions( SSLSocket sslSocket, String hostname, List protocols) { // Enable SNI and session tickets. @@ -268,18 +305,34 @@ public class Platform { } } + /** JDK 1.7 or better. */ + private static class JdkPlatform extends Platform { + private final Class sslContextClass; + + public JdkPlatform(Class sslContextClass) { + this.sslContextClass = sslContextClass; + } + + @Override public X509TrustManager trustManager(SSLSocketFactory sslSocketFactory) { + Object context = readFieldOrNull(sslSocketFactory, sslContextClass, "context"); + if (context == null) return null; + return readFieldOrNull(context, X509TrustManager.class, "trustManager"); + } + } + /** * OpenJDK 7+ with {@code org.mortbay.jetty.alpn/alpn-boot} in the boot class path. */ - private static class JdkWithJettyBootPlatform extends Platform { + private static class JdkWithJettyBootPlatform extends JdkPlatform { private final Method putMethod; private final Method getMethod; private final Method removeMethod; private final Class clientProviderClass; private final Class serverProviderClass; - public JdkWithJettyBootPlatform(Method putMethod, Method getMethod, Method removeMethod, - Class clientProviderClass, Class serverProviderClass) { + public JdkWithJettyBootPlatform(Class sslContextClass, Method putMethod, Method getMethod, + Method removeMethod, Class clientProviderClass, Class serverProviderClass) { + super(sslContextClass); this.putMethod = putMethod; this.getMethod = getMethod; this.removeMethod = removeMethod; @@ -391,4 +444,27 @@ public class Platform { } return result.readByteArray(); } + + static T readFieldOrNull(Object instance, Class fieldType, String fieldName) { + for (Class c = instance.getClass(); c != Object.class; c = c.getSuperclass()) { + try { + Field field = c.getDeclaredField(fieldName); + field.setAccessible(true); + Object value = field.get(instance); + if (value == null || !fieldType.isInstance(value)) return null; + return fieldType.cast(value); + } catch (NoSuchFieldException ignored) { + } catch (IllegalAccessException e) { + throw new AssertionError(); + } + } + + // Didn't find the field we wanted. As a last gasp attempt, try to find the value on a delegate. + if (!fieldName.equals("delegate")) { + Object delegate = readFieldOrNull(instance, Object.class, "delegate"); + if (delegate != null) return readFieldOrNull(delegate, fieldType, fieldName); + } + + return null; + } }