From 0b2486f7f32fbe21c8630f330d240546a15beb66 Mon Sep 17 00:00:00 2001 From: Yuri Schimke Date: Sun, 2 Sep 2018 19:27:32 +0100 Subject: [PATCH] Make separate IPv4 and IPv6 requests for DNS over HTTPS (#4234) Make separate requests because DNS in practice does not support multiple questions A + AAAA in a single message. --- .../okhttp3/dnsoverhttps/DnsOverHttps.java | 141 ++++++++++++++---- .../okhttp3/dnsoverhttps/DnsRecordCodec.java | 16 +- .../dnsoverhttps/DnsOverHttpsTest.java | 60 +++++--- .../dnsoverhttps/DnsRecordCodecTest.java | 12 +- 4 files changed, 162 insertions(+), 67 deletions(-) diff --git a/okhttp-dnsoverhttps/src/main/java/okhttp3/dnsoverhttps/DnsOverHttps.java b/okhttp-dnsoverhttps/src/main/java/okhttp3/dnsoverhttps/DnsOverHttps.java index c5132bf69..67442f331 100644 --- a/okhttp-dnsoverhttps/src/main/java/okhttp3/dnsoverhttps/DnsOverHttps.java +++ b/okhttp-dnsoverhttps/src/main/java/okhttp3/dnsoverhttps/DnsOverHttps.java @@ -18,11 +18,14 @@ package okhttp3.dnsoverhttps; import java.io.IOException; import java.net.InetAddress; import java.net.UnknownHostException; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.concurrent.TimeUnit; +import java.util.concurrent.CountDownLatch; import javax.annotation.Nullable; import okhttp3.CacheControl; +import okhttp3.Call; +import okhttp3.Callback; import okhttp3.Dns; import okhttp3.HttpUrl; import okhttp3.MediaType; @@ -32,6 +35,7 @@ import okhttp3.Request; import okhttp3.RequestBody; import okhttp3.Response; import okhttp3.ResponseBody; +import okhttp3.internal.Util; import okhttp3.internal.platform.Platform; import okhttp3.internal.publicsuffix.PublicSuffixDatabase; import okio.ByteString; @@ -113,8 +117,6 @@ public class DnsOverHttps implements Dns { } @Override public List lookup(String hostname) throws UnknownHostException { - UnknownHostException firstUhe = null; - if (!resolvePrivateAddresses || !resolvePublicAddresses) { boolean privateHost = isPrivateHost(hostname); @@ -131,37 +133,117 @@ public class DnsOverHttps implements Dns { } private List lookupHttps(String hostname) throws UnknownHostException { - try { - ByteString query = DnsRecordCodec.encodeQuery(hostname, includeIPv6); + List networkRequests = new ArrayList<>(2); + List failures = new ArrayList<>(2); + List results = new ArrayList<>(5); - Request request = buildRequest(query); - Response response = executeRequest(request); + buildRequest(hostname, networkRequests, results, failures, DnsRecordCodec.TYPE_A); - return readResponse(hostname, response); - } catch (UnknownHostException uhe) { - throw uhe; - } catch (Exception e) { - UnknownHostException unknownHostException = new UnknownHostException(hostname); - unknownHostException.initCause(e); - throw unknownHostException; + if (includeIPv6) { + buildRequest(hostname, networkRequests, results, failures, DnsRecordCodec.TYPE_AAAA); + } + + executeRequests(hostname, networkRequests, results, failures); + + if (!results.isEmpty()) { + return results; + } + + return throwBestFailure(hostname, failures); + } + + private void buildRequest(String hostname, List networkRequests, List results, + List failures, int type) { + Request request = buildRequest(hostname, type); + Response response = getCacheOnlyResponse(request); + + if (response != null) { + processResponse(response, hostname, results, failures); + } else { + networkRequests.add(client.newCall(request)); } } - private Response executeRequest(Request request) throws IOException { - // cached request + private void executeRequests(final String hostname, List networkRequests, + final List responses, final List failures) { + final CountDownLatch latch = new CountDownLatch(networkRequests.size()); + + for (Call call : networkRequests) { + call.enqueue(new Callback() { + @Override public void onFailure(Call call, IOException e) { + synchronized (failures) { + failures.add(e); + } + latch.countDown(); + } + + @Override public void onResponse(Call call, Response response) { + processResponse(response, hostname, responses, failures); + latch.countDown(); + } + }); + } + + try { + latch.await(); + } catch (InterruptedException e) { + failures.add(e); + } + } + + private void processResponse(Response response, String hostname, List results, + List failures) { + try { + List addresses = readResponse(hostname, response); + synchronized (results) { + results.addAll(addresses); + } + } catch (Exception e) { + synchronized (failures) { + failures.add(e); + } + } + } + + private List throwBestFailure(String hostname, List failures) + throws UnknownHostException { + if (failures.size() == 0) { + throw new UnknownHostException(hostname); + } + + Exception failure = failures.get(0); + + if (failure instanceof UnknownHostException) { + throw (UnknownHostException) failure; + } + + UnknownHostException unknownHostException = new UnknownHostException(hostname); + unknownHostException.initCause(failure); + + for (int i = 1; i < failures.size(); i++) { + Util.addSuppressedIfPossible(unknownHostException, failures.get(i)); + } + + throw unknownHostException; + } + + private @Nullable Response getCacheOnlyResponse(Request request) { if (!post && client.cache() != null) { - CacheControl cacheControl = - new CacheControl.Builder().maxStale(Integer.MAX_VALUE, TimeUnit.SECONDS).build(); - Request cacheRequest = request.newBuilder().cacheControl(cacheControl).build(); + try { + Request cacheRequest = request.newBuilder().cacheControl(CacheControl.FORCE_CACHE).build(); - Response response = client.newCall(cacheRequest).execute(); + Response cacheResponse = client.newCall(cacheRequest).execute(); - if (response.isSuccessful()) { - return response; + if (cacheResponse.code() != 504) { + return cacheResponse; + } + } catch (IOException ioe) { + // Failures are ignored as we can fallback to the network + // and hopefully repopulate the cache. } } - return client.newCall(request).execute(); + return null; } private List readResponse(String hostname, Response response) throws Exception { @@ -192,9 +274,11 @@ public class DnsOverHttps implements Dns { } } - private Request buildRequest(ByteString query) { + private Request buildRequest(String hostname, int type) { Request.Builder requestBuilder = new Request.Builder().header("Accept", DNS_MESSAGE.toString()); + ByteString query = DnsRecordCodec.encodeQuery(hostname, type); + if (post) { requestBuilder = requestBuilder.url(url).post(RequestBody.create(DNS_MESSAGE, query)); } else { @@ -216,12 +300,14 @@ public class DnsOverHttps implements Dns { @Nullable HttpUrl url = null; boolean includeIPv6 = true; boolean post = false; - MediaType contentType = DNS_MESSAGE; Dns systemDns = Dns.SYSTEM; @Nullable List bootstrapDnsHosts = null; boolean resolvePrivateAddresses = false; boolean resolvePublicAddresses = true; + public Builder() { + } + public DnsOverHttps build() { return new DnsOverHttps(this); } @@ -246,11 +332,6 @@ public class DnsOverHttps implements Dns { return this; } - public Builder contentType(MediaType contentType) { - this.contentType = contentType; - return this; - } - public Builder resolvePrivateAddresses(boolean resolvePrivateAddresses) { this.resolvePrivateAddresses = resolvePrivateAddresses; return this; diff --git a/okhttp-dnsoverhttps/src/main/java/okhttp3/dnsoverhttps/DnsRecordCodec.java b/okhttp-dnsoverhttps/src/main/java/okhttp3/dnsoverhttps/DnsRecordCodec.java index 271a05446..4c1cfc481 100644 --- a/okhttp-dnsoverhttps/src/main/java/okhttp3/dnsoverhttps/DnsRecordCodec.java +++ b/okhttp-dnsoverhttps/src/main/java/okhttp3/dnsoverhttps/DnsRecordCodec.java @@ -31,20 +31,20 @@ import okio.Utf8; class DnsRecordCodec { private static final byte SERVFAIL = 2; private static final byte NXDOMAIN = 3; - private static final int TYPE_A = 0x0001; - private static final int TYPE_AAAA = 0x001c; + public static final int TYPE_A = 0x0001; + public static final int TYPE_AAAA = 0x001c; private static final int TYPE_PTR = 0x000c; private static final Charset ASCII = Charset.forName("ASCII"); private DnsRecordCodec() { } - public static ByteString encodeQuery(String host, boolean includeIPv6) { + public static ByteString encodeQuery(String host, int type) { Buffer buf = new Buffer(); buf.writeShort(0); // query id buf.writeShort(256); // flags with recursion - buf.writeShort(includeIPv6 ? 2 : 1); // question count + buf.writeShort(1); // question count buf.writeShort(0); // answerCount buf.writeShort(0); // authorityResourceCount buf.writeShort(0); // additional @@ -62,15 +62,9 @@ class DnsRecordCodec { nameBuf.writeByte(0); // end nameBuf.copyTo(buf, 0, nameBuf.size()); - buf.writeShort(TYPE_A); + buf.writeShort(type); buf.writeShort(1); // CLASS_IN - if (includeIPv6) { - nameBuf.copyTo(buf, 0, nameBuf.size()); - buf.writeShort(TYPE_AAAA); - buf.writeShort(1); // CLASS_IN - } - return buf.readByteString(); } diff --git a/okhttp-dnsoverhttps/src/test/java/okhttp3/dnsoverhttps/DnsOverHttpsTest.java b/okhttp-dnsoverhttps/src/test/java/okhttp3/dnsoverhttps/DnsOverHttpsTest.java index 31c8da16c..c6c16278d 100644 --- a/okhttp-dnsoverhttps/src/test/java/okhttp3/dnsoverhttps/DnsOverHttpsTest.java +++ b/okhttp-dnsoverhttps/src/test/java/okhttp3/dnsoverhttps/DnsOverHttpsTest.java @@ -20,6 +20,8 @@ import java.io.IOException; import java.net.InetAddress; import java.net.UnknownHostException; import java.util.Arrays; +import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.List; import okhttp3.Cache; import okhttp3.Dns; @@ -35,6 +37,7 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; +import static java.util.Arrays.asList; import static java.util.Collections.singletonList; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -44,9 +47,8 @@ public class DnsOverHttpsTest { @Rule public final MockWebServer server = new MockWebServer(); private final OkHttpClient bootstrapClient = - new OkHttpClient.Builder().protocols(Arrays.asList(Protocol.HTTP_2, Protocol.HTTP_1_1)) - .build(); - private final Dns dns = buildLocalhost(bootstrapClient); + new OkHttpClient.Builder().protocols(asList(Protocol.HTTP_2, Protocol.HTTP_1_1)).build(); + private Dns dns = buildLocalhost(bootstrapClient, false); @Before public void setUp() { server.setProtocols(bootstrapClient.protocols()); @@ -64,24 +66,38 @@ public class DnsOverHttpsTest { RecordedRequest recordedRequest = server.takeRequest(); assertEquals("GET", recordedRequest.getMethod()); - assertEquals("/lookup?ct&dns=AAABAAACAAAAAAAABmdvb2dsZQNjb20AAAEAAQZnb29nbGUDY29t" - + "AAAcAAE", recordedRequest.getPath()); + assertEquals("/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ", + recordedRequest.getPath()); } @Test public void getIpv6() throws Exception { + server.enqueue(dnsResponse( + "0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c00050001" + + "00000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010" + + "0010000003b00049df00112")); server.enqueue(dnsResponse( "0000818000010003000000000567726170680866616365626f6f6b03636f6d00001c0001c00c00050001" + "00000a1b000603617069c012c0300005000100000b1f000c04737461720463313072c012c042001c0" + "0010000003b00102a032880f0290011faceb00c00000002")); + dns = buildLocalhost(bootstrapClient, true); + List result = dns.lookup("google.com"); - assertEquals(singletonList(address("2a03:2880:f029:11:face:b00c:0:2")), result); + assertEquals(2, result.size()); + assertTrue(result.contains(address("157.240.1.18"))); + assertTrue(result.contains(address("2a03:2880:f029:11:face:b00c:0:2"))); - RecordedRequest recordedRequest = server.takeRequest(); - assertEquals("GET", recordedRequest.getMethod()); - assertEquals("/lookup?ct&dns=AAABAAACAAAAAAAABmdvb2dsZQNjb20AAAEAAQZnb29nbGUDY29t" - + "AAAcAAE", recordedRequest.getPath()); + RecordedRequest request1 = server.takeRequest(); + assertEquals("GET", request1.getMethod()); + + RecordedRequest request2 = server.takeRequest(); + assertEquals("GET", request2.getMethod()); + + assertEquals(new HashSet<>( + Arrays.asList("/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ", + "/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AABwAAQ")), + new LinkedHashSet<>(Arrays.asList(request1.getPath(), request2.getPath()))); } @Test public void failure() throws Exception { @@ -100,8 +116,8 @@ public class DnsOverHttpsTest { RecordedRequest recordedRequest = server.takeRequest(); assertEquals("GET", recordedRequest.getMethod()); - assertEquals("/lookup?ct&dns=AAABAAACAAAAAAAABmdvb2dsZQNjb20AAAEAAQZnb29nbGUDY29t" - + "AAAcAAE", recordedRequest.getPath()); + assertEquals("/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ", + recordedRequest.getPath()); } @Test public void failOnExcessiveResponse() { @@ -145,13 +161,12 @@ public class DnsOverHttpsTest { @Test public void usesCache() throws Exception { Cache cache = new Cache(new File("./target/DnsOverHttpsTest.cache"), 100 * 1024); OkHttpClient cachedClient = bootstrapClient.newBuilder().cache(cache).build(); - DnsOverHttps cachedDns = buildLocalhost(cachedClient); + DnsOverHttps cachedDns = buildLocalhost(cachedClient, false); server.enqueue(dnsResponse( "0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c00050001" + "00000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010" - + "0010000003b00049df00112") - .setHeader("cache-control", "private, max-age=298")); + + "0010000003b00049df00112").setHeader("cache-control", "private, max-age=298")); List result = cachedDns.lookup("google.com"); @@ -159,23 +174,26 @@ public class DnsOverHttpsTest { RecordedRequest recordedRequest = server.takeRequest(); assertEquals("GET", recordedRequest.getMethod()); - assertEquals("/lookup?ct&dns=AAABAAACAAAAAAAABmdvb2dsZQNjb20AAAEAAQZnb29nbGUDY29t" - + "AAAcAAE", recordedRequest.getPath()); + assertEquals("/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ", + recordedRequest.getPath()); result = cachedDns.lookup("google.com"); assertEquals(singletonList(address("157.240.1.18")), result); } private MockResponse dnsResponse(String s) { - return new MockResponse() - .setBody(new Buffer().write(ByteString.decodeHex(s))) + return new MockResponse().setBody(new Buffer().write(ByteString.decodeHex(s))) .addHeader("content-type", "application/dns-message") .addHeader("content-length", s.length() / 2); } - private DnsOverHttps buildLocalhost(OkHttpClient bootstrapClient) { + private DnsOverHttps buildLocalhost(OkHttpClient bootstrapClient, boolean includeIPv6) { HttpUrl url = server.url("/lookup?ct"); - return new DnsOverHttps.Builder().client(bootstrapClient).resolvePrivateAddresses(true).url(url).build(); + return new DnsOverHttps.Builder().client(bootstrapClient) + .includeIPv6(includeIPv6) + .resolvePrivateAddresses(true) + .url(url) + .build(); } private static InetAddress address(String host) { diff --git a/okhttp-dnsoverhttps/src/test/java/okhttp3/dnsoverhttps/DnsRecordCodecTest.java b/okhttp-dnsoverhttps/src/test/java/okhttp3/dnsoverhttps/DnsRecordCodecTest.java index 32700c266..3a12b4d25 100644 --- a/okhttp-dnsoverhttps/src/test/java/okhttp3/dnsoverhttps/DnsRecordCodecTest.java +++ b/okhttp-dnsoverhttps/src/test/java/okhttp3/dnsoverhttps/DnsRecordCodecTest.java @@ -22,24 +22,26 @@ import java.util.List; import okio.ByteString; import org.junit.Test; +import static okhttp3.dnsoverhttps.DnsRecordCodec.TYPE_A; +import static okhttp3.dnsoverhttps.DnsRecordCodec.TYPE_AAAA; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; public class DnsRecordCodecTest { @Test public void testGoogleDotComEncoding() { - String encoded = encodeQuery("google.com", false); + String encoded = encodeQuery("google.com", TYPE_A); assertEquals("AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ", encoded); } - private String encodeQuery(String host, boolean includeIpv6) { - return DnsRecordCodec.encodeQuery(host, includeIpv6).base64Url().replace("=", ""); + private String encodeQuery(String host, int type) { + return DnsRecordCodec.encodeQuery(host, type).base64Url().replace("=", ""); } @Test public void testGoogleDotComEncodingWithIPv6() { - String encoded = encodeQuery("google.com", true); + String encoded = encodeQuery("google.com", TYPE_AAAA); - assertEquals("AAABAAACAAAAAAAABmdvb2dsZQNjb20AAAEAAQZnb29nbGUDY29tAAAcAAE", encoded); + assertEquals("AAABAAABAAAAAAAABmdvb2dsZQNjb20AABwAAQ", encoded); } @Test public void testGoogleDotComDecodingFromCloudflare() throws Exception {