diff --git a/okhttp-urlconnection/src/main/java/okhttp3/OkUrlFactory.java b/okhttp-urlconnection/src/main/java/okhttp3/OkUrlFactory.java index 09e18b5f5..973c70578 100644 --- a/okhttp-urlconnection/src/main/java/okhttp3/OkUrlFactory.java +++ b/okhttp-urlconnection/src/main/java/okhttp3/OkUrlFactory.java @@ -21,6 +21,7 @@ import java.net.URL; import java.net.URLConnection; import java.net.URLStreamHandler; import java.net.URLStreamHandlerFactory; +import okhttp3.internal.URLFilter; import okhttp3.internal.huc.HttpURLConnectionImpl; import okhttp3.internal.huc.HttpsURLConnectionImpl; @@ -31,6 +32,7 @@ import okhttp3.internal.huc.HttpsURLConnectionImpl; */ public final class OkUrlFactory implements URLStreamHandlerFactory, Cloneable { private OkHttpClient client; + private URLFilter urlFilter; public OkUrlFactory(OkHttpClient client) { this.client = client; @@ -45,6 +47,10 @@ public final class OkUrlFactory implements URLStreamHandlerFactory, Cloneable { return this; } + void setUrlFilter(URLFilter filter) { + urlFilter = filter; + } + /** * Returns a copy of this stream handler factory that includes a shallow copy of the internal * {@linkplain OkHttpClient HTTP client}. @@ -63,8 +69,8 @@ public final class OkUrlFactory implements URLStreamHandlerFactory, Cloneable { .proxy(proxy) .build(); - if (protocol.equals("http")) return new HttpURLConnectionImpl(url, copy); - if (protocol.equals("https")) return new HttpsURLConnectionImpl(url, copy); + if (protocol.equals("http")) return new HttpURLConnectionImpl(url, copy, urlFilter); + if (protocol.equals("https")) return new HttpsURLConnectionImpl(url, copy, urlFilter); throw new IllegalArgumentException("Unexpected protocol: " + protocol); } diff --git a/okhttp-urlconnection/src/main/java/okhttp3/internal/URLFilter.java b/okhttp-urlconnection/src/main/java/okhttp3/internal/URLFilter.java new file mode 100644 index 000000000..3b077f8bc --- /dev/null +++ b/okhttp-urlconnection/src/main/java/okhttp3/internal/URLFilter.java @@ -0,0 +1,32 @@ +/* + * Copyright (C) 2016 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.internal; +import java.io.IOException; +import java.net.URL; + +/** + * Request filter based on the request's URL. + * + * @deprecated use {@link okhttp3.Interceptor} for non-HttpURLConnection filtering. + */ +public interface URLFilter { + /** + * Check whether request to the provided URL is permitted to be issued. + * + * @throws IOException if the request to the provided URL is not permitted. + */ + void checkURLPermitted(URL url) throws IOException; +} diff --git a/okhttp-urlconnection/src/main/java/okhttp3/internal/huc/HttpURLConnectionImpl.java b/okhttp-urlconnection/src/main/java/okhttp3/internal/huc/HttpURLConnectionImpl.java index fabedaa67..aa90bae4b 100644 --- a/okhttp-urlconnection/src/main/java/okhttp3/internal/huc/HttpURLConnectionImpl.java +++ b/okhttp-urlconnection/src/main/java/okhttp3/internal/huc/HttpURLConnectionImpl.java @@ -53,6 +53,7 @@ import okhttp3.Route; import okhttp3.internal.Internal; import okhttp3.internal.JavaNetHeaders; import okhttp3.internal.Platform; +import okhttp3.internal.URLFilter; import okhttp3.internal.Util; import okhttp3.internal.Version; import okhttp3.internal.http.HttpDate; @@ -107,11 +108,18 @@ public class HttpURLConnectionImpl extends HttpURLConnection { */ Handshake handshake; + private URLFilter urlFilter; + public HttpURLConnectionImpl(URL url, OkHttpClient client) { super(url); this.client = client; } + public HttpURLConnectionImpl(URL url, OkHttpClient client, URLFilter urlFilter) { + this(url, client); + this.urlFilter = urlFilter; + } + @Override public final void connect() throws IOException { initHttpEngine(); boolean success; @@ -456,6 +464,9 @@ public class HttpURLConnectionImpl extends HttpURLConnection { */ private boolean execute(boolean readResponse) throws IOException { boolean releaseConnection = true; + if (urlFilter != null) { + urlFilter.checkURLPermitted(httpEngine.getRequest().url().url()); + } try { httpEngine.sendRequest(); Connection connection = httpEngine.getConnection(); diff --git a/okhttp-urlconnection/src/main/java/okhttp3/internal/huc/HttpsURLConnectionImpl.java b/okhttp-urlconnection/src/main/java/okhttp3/internal/huc/HttpsURLConnectionImpl.java index 3d8b24db7..1bf1fa656 100644 --- a/okhttp-urlconnection/src/main/java/okhttp3/internal/huc/HttpsURLConnectionImpl.java +++ b/okhttp-urlconnection/src/main/java/okhttp3/internal/huc/HttpsURLConnectionImpl.java @@ -21,6 +21,7 @@ import javax.net.ssl.HostnameVerifier; import javax.net.ssl.SSLSocketFactory; import okhttp3.Handshake; import okhttp3.OkHttpClient; +import okhttp3.internal.URLFilter; public final class HttpsURLConnectionImpl extends DelegatingHttpsURLConnection { private final HttpURLConnectionImpl delegate; @@ -29,6 +30,10 @@ public final class HttpsURLConnectionImpl extends DelegatingHttpsURLConnection { this(new HttpURLConnectionImpl(url, client)); } + public HttpsURLConnectionImpl(URL url, OkHttpClient client, URLFilter filter) { + this(new HttpURLConnectionImpl(url, client, filter)); + } + public HttpsURLConnectionImpl(HttpURLConnectionImpl delegate) { super(delegate); this.delegate = delegate; diff --git a/okhttp-urlconnection/src/test/java/okhttp3/OkUrlFactoryTest.java b/okhttp-urlconnection/src/test/java/okhttp3/OkUrlFactoryTest.java index 67be818b2..d357807ef 100644 --- a/okhttp-urlconnection/src/test/java/okhttp3/OkUrlFactoryTest.java +++ b/okhttp-urlconnection/src/test/java/okhttp3/OkUrlFactoryTest.java @@ -3,12 +3,17 @@ package okhttp3; import java.io.File; import java.io.IOException; import java.net.HttpURLConnection; +import java.net.URL; import java.text.DateFormat; import java.text.SimpleDateFormat; import java.util.Date; import java.util.Locale; import java.util.TimeZone; import java.util.concurrent.TimeUnit; +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.SSLContext; +import okhttp3.internal.URLFilter; +import okhttp3.internal.SslContextBuilder; import okhttp3.internal.http.OkHeaders; import okhttp3.internal.io.InMemoryFileSystem; import okhttp3.mockwebserver.MockResponse; @@ -148,6 +153,66 @@ public class OkUrlFactoryTest { assertResponseCode(connection, 302); } + @Test + public void testURLFilter() throws Exception { + server.enqueue(new MockResponse() + .setBody("B")); + final URL blockedURL = server.url("/a").url(); + factory.setUrlFilter(new URLFilter() { + @Override + public void checkURLPermitted(URL url) throws IOException { + if (blockedURL.equals(url)) { + throw new IOException("Blocked"); + } + } + }); + try { + HttpURLConnection connection = factory.open(server.url("/a").url()); + connection.getInputStream(); + fail("Connection was successful"); + } catch (IOException e) { + assertEquals("Blocked", e.getMessage()); + } + HttpURLConnection connection = factory.open(server.url("/b").url()); + assertResponseBody(connection, "B"); + } + + @Test + public void testURLFilterRedirect() throws Exception { + MockWebServer cleartextServer = new MockWebServer(); + cleartextServer.enqueue(new MockResponse() + .setBody("Blocked!")); + final URL blockedURL = cleartextServer.url("/").url(); + + SSLContext context = SslContextBuilder.localhost(); + server.useHttps(context.getSocketFactory(), false); + factory.setClient(factory.client().newBuilder() + .sslSocketFactory(context.getSocketFactory()) + .followSslRedirects(true) + .build()); + factory.setUrlFilter(new URLFilter() { + @Override + public void checkURLPermitted(URL url) throws IOException { + if (blockedURL.equals(url)) { + throw new IOException("Blocked"); + } + } + }); + + server.enqueue(new MockResponse() + .setResponseCode(302) + .addHeader("Location: " + blockedURL) + .setBody("This page has moved")); + URL destination = server.url("/").url(); + try { + HttpsURLConnection httpsConnection = (HttpsURLConnection) factory.open(destination); + httpsConnection.getInputStream(); + fail("Connection was successful"); + } catch (IOException e) { + assertEquals("Blocked", e.getMessage()); + } + } + private void assertResponseBody(HttpURLConnection connection, String expected) throws Exception { BufferedSource source = buffer(source(connection.getInputStream())); String actual = source.readString(US_ASCII);