From 46fcf4f42c6471c4d13e5f2ba0eae6fdebe5d1bd Mon Sep 17 00:00:00 2001 From: Yuri Schimke Date: Thu, 2 Apr 2020 12:58:26 +0100 Subject: [PATCH] Fix for infinite caching in DoH (#5918) Hit the cache only initially, but avoid using stale cached data. --- .../java/okhttp3/dnsoverhttps/DnsOverHttps.kt | 10 ++++-- .../dnsoverhttps/DnsOverHttpsTest.java | 36 +++++++++++++++++++ .../okhttp3/dnsoverhttps/DohProviders.java | 15 ++++---- 3 files changed, 51 insertions(+), 10 deletions(-) diff --git a/okhttp-dnsoverhttps/src/main/java/okhttp3/dnsoverhttps/DnsOverHttps.kt b/okhttp-dnsoverhttps/src/main/java/okhttp3/dnsoverhttps/DnsOverHttps.kt index 0aabaf5a9..bffafe6b5 100644 --- a/okhttp-dnsoverhttps/src/main/java/okhttp3/dnsoverhttps/DnsOverHttps.kt +++ b/okhttp-dnsoverhttps/src/main/java/okhttp3/dnsoverhttps/DnsOverHttps.kt @@ -16,6 +16,7 @@ package okhttp3.dnsoverhttps import java.io.IOException +import java.net.HttpURLConnection import java.net.InetAddress import java.net.UnknownHostException import java.util.ArrayList @@ -185,11 +186,16 @@ class DnsOverHttps internal constructor( private fun getCacheOnlyResponse(request: Request): Response? { if (!post && client.cache != null) { try { - val cacheRequest = request.newBuilder().cacheControl(CacheControl.FORCE_CACHE).build() + // Use the cache without hitting the network first + // 504 code indicates that the Cache is stale + val preferCache = CacheControl.Builder() + .onlyIfCached() + .build() + val cacheRequest = request.newBuilder().cacheControl(preferCache).build() val cacheResponse = client.newCall(cacheRequest).execute() - if (cacheResponse.code != 504) { + if (cacheResponse.code != HttpURLConnection.HTTP_GATEWAY_TIMEOUT) { return cacheResponse } } catch (ioe: IOException) { diff --git a/okhttp-dnsoverhttps/src/test/java/okhttp3/dnsoverhttps/DnsOverHttpsTest.java b/okhttp-dnsoverhttps/src/test/java/okhttp3/dnsoverhttps/DnsOverHttpsTest.java index 8b750f018..cef2155af 100644 --- a/okhttp-dnsoverhttps/src/test/java/okhttp3/dnsoverhttps/DnsOverHttpsTest.java +++ b/okhttp-dnsoverhttps/src/test/java/okhttp3/dnsoverhttps/DnsOverHttpsTest.java @@ -22,6 +22,7 @@ import java.net.InetAddress; import java.net.UnknownHostException; import java.util.Arrays; import java.util.List; +import java.util.concurrent.TimeUnit; import okhttp3.Cache; import okhttp3.Dns; import okhttp3.HttpUrl; @@ -177,6 +178,41 @@ public class DnsOverHttpsTest { assertThat(result).isEqualTo(singletonList(address("157.240.1.18"))); } + @Test public void usesCacheOnlyIfFresh() throws Exception { + Cache cache = new Cache(new File("./target/DnsOverHttpsTest.cache"), 100 * 1024); + OkHttpClient cachedClient = bootstrapClient.newBuilder().cache(cache).build(); + DnsOverHttps cachedDns = buildLocalhost(cachedClient, false); + + server.enqueue(dnsResponse( + "0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c00050001" + + "00000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010" + + "0010000003b00049df00112").setHeader("cache-control", "max-age=1")); + + List result = cachedDns.lookup("google.com"); + + assertThat(result).containsExactly(address("157.240.1.18")); + + RecordedRequest recordedRequest = server.takeRequest(0, TimeUnit.SECONDS); + assertThat(recordedRequest.getMethod()).isEqualTo("GET"); + assertThat(recordedRequest.getPath()).isEqualTo( + "/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ"); + + Thread.sleep(2000); + + server.enqueue(dnsResponse( + "0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c00050001" + + "00000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010" + + "0010000003b00049df00112").setHeader("cache-control", "max-age=1")); + + result = cachedDns.lookup("google.com"); + assertThat(result).isEqualTo(singletonList(address("157.240.1.18"))); + + recordedRequest = server.takeRequest(0, TimeUnit.SECONDS); + assertThat(recordedRequest.getMethod()).isEqualTo("GET"); + assertThat(recordedRequest.getPath()).isEqualTo( + "/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ"); + } + private MockResponse dnsResponse(String s) { return new MockResponse().setBody(new Buffer().write(ByteString.decodeHex(s))) .addHeader("content-type", "application/dns-message") diff --git a/okhttp-dnsoverhttps/src/test/java/okhttp3/dnsoverhttps/DohProviders.java b/okhttp-dnsoverhttps/src/test/java/okhttp3/dnsoverhttps/DohProviders.java index 301eee57c..d11459061 100644 --- a/okhttp-dnsoverhttps/src/test/java/okhttp3/dnsoverhttps/DohProviders.java +++ b/okhttp-dnsoverhttps/src/test/java/okhttp3/dnsoverhttps/DohProviders.java @@ -30,15 +30,15 @@ import okhttp3.OkHttpClient; public class DohProviders { static DnsOverHttps buildGoogle(OkHttpClient bootstrapClient) { return new DnsOverHttps.Builder().client(bootstrapClient) - .url(HttpUrl.get("https://dns.google.com/experimental")) - .bootstrapDnsHosts(getByIp("216.58.204.78"), getByIp("2a00:1450:4009:814:0:0:0:200e")) + .url(HttpUrl.get("https://dns.google/dns-query")) + .bootstrapDnsHosts(getByIp("8.8.4.4"), getByIp("8.8.8.8")) .build(); } static DnsOverHttps buildGooglePost(OkHttpClient bootstrapClient) { return new DnsOverHttps.Builder().client(bootstrapClient) - .url(HttpUrl.get("https://dns.google.com/experimental")) - .bootstrapDnsHosts(getByIp("216.58.204.78"), getByIp("2a00:1450:4009:814:0:0:0:200e")) + .url(HttpUrl.get("https://dns.google/dns-query")) + .bootstrapDnsHosts(getByIp("8.8.4.4"), getByIp("8.8.8.8")) .post(true) .build(); } @@ -52,8 +52,8 @@ public class DohProviders { static DnsOverHttps buildCloudflare(OkHttpClient bootstrapClient) { return new DnsOverHttps.Builder().client(bootstrapClient) - .url(HttpUrl.get("https://cloudflare-dns.com/dns-query")) - .bootstrapDnsHosts(getByIp("1.1.1.1")) + .url(HttpUrl.get("https://1.1.1.1/dns-query")) + .bootstrapDnsHosts(getByIp("1.1.1.1"), getByIp("1.0.0.1")) .includeIPv6(false) .build(); } @@ -61,8 +61,7 @@ public class DohProviders { static DnsOverHttps buildCloudflarePost(OkHttpClient bootstrapClient) { return new DnsOverHttps.Builder().client(bootstrapClient) .url(HttpUrl.get("https://cloudflare-dns.com/dns-query")) - .bootstrapDnsHosts(getByIp("104.16.111.25"), getByIp("104.16.112.25"), - getByIp("2400:cb00:2048:1:0:0:6810:7019"), getByIp("2400:cb00:2048:1:0:0:6810:6f19")) + .bootstrapDnsHosts(getByIp("1.1.1.1"), getByIp("1.0.0.1")) .includeIPv6(false) .post(true) .build();