1
0
mirror of https://github.com/esp8266/Arduino.git synced 2025-07-27 18:02:17 +03:00

Add DNS forwarder to DNSServer (#7237)

The key functions added are:

`bool enableForwarder(const String &domainName=emptyString, const IPAddress &dns=uint32_t)0)`

If specified, `enableForwarder` will update the `domainName` that is used to match DNS request to this AP's IP Address. A non-matching request will be forwarded to the DNS server specified by `dns`. 

Returns `true` on success.

Returns `false`, 
 * when forwarding `dns` is not set, or 
 * unable to allocate resources for managing the DNS forward function.

`void disableForwarder(const String &domainName=emptyString, bool freeResources=false)`

`disableForwarder` will stop forwarding DNS requests. If specified, updates the `domainName` that is matched for returning this AP's IP Address.
Optionally, resources used for the DNS forward function can be freed.
This commit is contained in:
M Hightower
2022-05-08 04:04:34 -07:00
committed by GitHub
parent 1a49a0449b
commit bcb5464167
8 changed files with 1432 additions and 47 deletions

View File

@ -1,33 +1,139 @@
#include <ESP8266WiFi.h>
#include "DNSServer.h"
#include <lwip/def.h>
#include <Arduino.h>
#include <memory>
extern struct rst_info resetInfo;
#ifdef DEBUG_ESP_PORT
#define DEBUG_OUTPUT DEBUG_ESP_PORT
#define CONSOLE DEBUG_ESP_PORT
#else
#define DEBUG_OUTPUT Serial
#define CONSOLE Serial
#endif
#define _ETS_PRINTF(a, ...) ets_uart_printf(a, ##__VA_ARGS__)
#define _ETS_PRINTFNL(a, ...) ets_uart_printf(a "\n", ##__VA_ARGS__)
#define _PRINTF(a, ...) printf_P(PSTR(a), ##__VA_ARGS__)
#define _PRINT(a) print(String(F(a)))
#define _PRINTLN(a) println(String(F(a)))
#define _PRINTLN2(a, b) println(String(F(a)) + b )
#define ETS_PRINTF _ETS_PRINTF
#define ETS_PRINTFNL _ETS_PRINTFNL
#define CONSOLE_PRINTF CONSOLE._PRINTF
#define CONSOLE_PRINT CONSOLE._PRINT
#define CONSOLE_PRINTLN CONSOLE._PRINTLN
#define CONSOLE_PRINTLN2 CONSOLE._PRINTLN2
#ifdef DEBUG_DNSSERVER
#define DEBUG_PRINTF CONSOLE_PRINTF
#define DEBUG_PRINT CONSOLE_PRINT
#define DEBUG_PRINTLN CONSOLE_PRINTLN
#define DEBUG_PRINTLN2 CONSOLE_PRINTLN2
#define DBGLOG_FAIL LOG_FAIL
#define DEBUG_(...) do { (__VA_ARGS__); } while(false)
#define DEBUG__(...) __VA_ARGS__
#define LOG_FAIL(a, fmt, ...) do { if (!(a)) { CONSOLE.printf_P( PSTR(fmt " line: %d, function: %s\r\n"), ##__VA_ARGS__, __LINE__, __FUNCTION__ ); } } while(false);
#else
#define DEBUG_PRINTF(...) do { } while(false)
#define DEBUG_PRINT(...) do { } while(false)
#define DEBUG_PRINTLN(...) do { } while(false)
#define DEBUG_PRINTLN2(...) do { } while(false)
#define DEBUG_(...) do { } while(false)
#define DEBUG__(...) do { } while(false)
#define LOG_FAIL(a, ...) do { a; } while(false)
#define DBGLOG_FAIL(...) do { } while(false)
#endif
#define DNS_HEADER_SIZE sizeof(DNSHeader)
// Want to keep IDs unique across restarts and continquious
static uint32_t _ids __attribute__((section(".noinit")));
DNSServer::DNSServer()
{
// I have observed that using 0 for captive and non-zero (600) when
// forwarding, will help Android devices recognize the change in connectivity.
// They will then report connected.
_ttl = lwip_htonl(60);
if (REASON_DEFAULT_RST == resetInfo.reason ||
REASON_DEEP_SLEEP_AWAKE <= resetInfo.reason) {
_ids = random(0, BIT(16) - 1);
}
_ids += kDNSSQueSize; // for the case of restart, ignore any inflight responses
_errorReplyCode = DNSReplyCode::NonExistentDomain;
}
bool DNSServer::start(const uint16_t &port, const String &domainName,
const IPAddress &resolvedIP)
void DNSServer::disableForwarder(const String &domainName, bool freeResources)
{
_port = port;
_domainName = domainName;
_forwarder = false;
if (!domainName.isEmpty()) {
_domainName = domainName;
downcaseAndRemoveWwwPrefix(_domainName);
}
if (freeResources) {
_dns = (uint32_t)0;
if (_que) {
_que = nullptr;
DEBUG_PRINTF("from stop, deleted _que\r\n");
DEBUG_(({
if (_que_ov) {
DEBUG_PRINTLN2("DNS forwarder que overflow or no reply to request: ", (_que_ov));
}
if (_que_drop) {
DEBUG_PRINTLN2("DNS forwarder que wrapped, reply dropped: ", (_que_drop));
}
}));
}
}
}
bool DNSServer::enableForwarder(const String &domainName, const IPAddress &dns)
{
disableForwarder(domainName, false); // Just happens to have the same logic needed here.
if (dns.isSet()) {
_dns = dns;
}
if (_dns.isSet()) {
if (!_que) {
_que = std::unique_ptr<DNSS_REQUESTER[]> (new (std::nothrow) DNSS_REQUESTER[kDNSSQueSize]);
DEBUG_PRINTF("Created new _que\r\n");
if (_que) {
for (size_t i = 0; i < kDNSSQueSize; i++) {
_que[i].ip = 0;
}
DEBUG_((_que_ov = 0));
DEBUG_((_que_drop = 0));
}
}
if (_que) {
_forwarder = true;
}
}
return _forwarder;
}
bool DNSServer::start(const uint16_t &port, const String &domainName,
const IPAddress &resolvedIP, const IPAddress &dns)
{
_port = (port) ? port : IANA_DNS_PORT;
_resolvedIP[0] = resolvedIP[0];
_resolvedIP[1] = resolvedIP[1];
_resolvedIP[2] = resolvedIP[2];
_resolvedIP[3] = resolvedIP[3];
downcaseAndRemoveWwwPrefix(_domainName);
if (!enableForwarder(domainName, dns) && (dns.isSet() || _dns.isSet())) {
return false;
}
return _udp.begin(_port) == 1;
}
@ -41,9 +147,15 @@ void DNSServer::setTTL(const uint32_t &ttl)
_ttl = lwip_htonl(ttl);
}
uint32_t DNSServer::getTTL()
{
return lwip_ntohl(_ttl);
}
void DNSServer::stop()
{
_udp.stop();
disableForwarder(emptyString, true);
}
void DNSServer::downcaseAndRemoveWwwPrefix(String &domainName)
@ -53,7 +165,58 @@ void DNSServer::downcaseAndRemoveWwwPrefix(String &domainName)
domainName.remove(0, 4);
}
void DNSServer::respondToRequest(uint8_t *buffer, size_t length)
void DNSServer::forwardReply(uint8_t *buffer, size_t length)
{
if (!_forwarder || !_que) {
return;
}
DNSHeader *dnsHeader = (DNSHeader *)buffer;
uint16_t id = dnsHeader->ID;
// if (kDNSSQueSize <= (uint16_t)((uint16_t)_ids - id)) {
if ((uint16_t)kDNSSQueSize <= (uint16_t)_ids - id) {
DEBUG_((++_que_drop));
DEBUG_PRINTLN2("Forward reply ID: 0x", (String(id, HEX) + F(" dropped!")));
return;
}
size_t i = id & (kDNSSQueSize - 1);
// Drop duplicate packets
if (0 == _que[i].ip) {
DEBUG_PRINTLN2("Duplicate reply dropped ID: 0x", String(id, HEX));
return;
}
dnsHeader->ID = _que[i].id;
_udp.beginPacket(_que[i].ip, _que[i].port);
_udp.write(buffer, length);
_udp.endPacket();
DEBUG_PRINTLN2("Forward reply ID: 0x", (String(id, HEX) + F(" to ") + IPAddress(_que[i].ip).toString()));
_que[i].ip = 0; // This gets used to detect duplicate packets and overflow
}
void DNSServer::forwardRequest(uint8_t *buffer, size_t length)
{
if (!_forwarder || !_dns.isSet() || !_que) {
return;
}
DNSHeader *dnsHeader = (DNSHeader *)buffer;
++_ids;
size_t i = _ids & (kDNSSQueSize - 1);
DEBUG_(({
if (0 != _que[i].ip) {
++_que_ov;
}
}));
_que[i].ip = _udp.remoteIP();
_que[i].port = _udp.remotePort();
_que[i].id = dnsHeader->ID;
dnsHeader->ID = (uint16_t)_ids;
_udp.beginPacket(_dns, IANA_DNS_PORT);
_udp.write(buffer, length);
_udp.endPacket();
DEBUG_PRINTLN2("Forward request ID: 0x", (String(dnsHeader->ID, HEX) + F(" to ") + _dns.toString()));
}
bool DNSServer::respondToRequest(uint8_t *buffer, size_t length)
{
DNSHeader *dnsHeader;
uint8_t *query, *start;
@ -64,23 +227,30 @@ void DNSServer::respondToRequest(uint8_t *buffer, size_t length)
dnsHeader = (DNSHeader *)buffer;
// Must be a query for us to do anything with it
if (dnsHeader->QR != DNS_QR_QUERY)
return;
if (dnsHeader->QR != DNS_QR_QUERY) {
return false;
}
// If operation is anything other than query, we don't do it
if (dnsHeader->OPCode != DNS_OPCODE_QUERY)
return replyWithError(dnsHeader, DNSReplyCode::NotImplemented);
if (dnsHeader->OPCode != DNS_OPCODE_QUERY) {
replyWithError(dnsHeader, DNSReplyCode::NotImplemented);
return false;
}
// Only support requests containing single queries - everything else
// is badly defined
if (dnsHeader->QDCount != lwip_htons(1))
return replyWithError(dnsHeader, DNSReplyCode::FormError);
if (dnsHeader->QDCount != lwip_htons(1)) {
replyWithError(dnsHeader, DNSReplyCode::FormError);
return false;
}
// We must return a FormError in the case of a non-zero ARCount to
// be minimally compatible with EDNS resolvers
if (dnsHeader->ANCount != 0 || dnsHeader->NSCount != 0
|| dnsHeader->ARCount != 0)
return replyWithError(dnsHeader, DNSReplyCode::FormError);
|| dnsHeader->ARCount != 0) {
replyWithError(dnsHeader, DNSReplyCode::FormError);
return false;
}
// Even if we're not going to use the query, we need to parse it
// so we can check the address type that's being queried
@ -89,15 +259,19 @@ void DNSServer::respondToRequest(uint8_t *buffer, size_t length)
remaining = length - DNS_HEADER_SIZE;
while (remaining != 0 && *start != 0) {
labelLength = *start;
if (labelLength + 1 > remaining)
return replyWithError(dnsHeader, DNSReplyCode::FormError);
if (labelLength + 1 > remaining) {
replyWithError(dnsHeader, DNSReplyCode::FormError);
return false;
}
remaining -= (labelLength + 1);
start += (labelLength + 1);
}
// 1 octet labelLength, 2 octet qtype, 2 octet qclass
if (remaining < 5)
return replyWithError(dnsHeader, DNSReplyCode::FormError);
if (remaining < 5) {
replyWithError(dnsHeader, DNSReplyCode::FormError);
return false;
}
start += 1; // Skip the 0 length label that we found above
@ -109,23 +283,33 @@ void DNSServer::respondToRequest(uint8_t *buffer, size_t length)
queryLength = start - query;
if (qclass != lwip_htons(DNS_QCLASS_ANY)
&& qclass != lwip_htons(DNS_QCLASS_IN))
return replyWithError(dnsHeader, DNSReplyCode::NonExistentDomain,
query, queryLength);
&& qclass != lwip_htons(DNS_QCLASS_IN)) {
replyWithError(dnsHeader, DNSReplyCode::NonExistentDomain, query, queryLength);
return false;
}
if (qtype != lwip_htons(DNS_QTYPE_A)
&& qtype != lwip_htons(DNS_QTYPE_ANY))
return replyWithError(dnsHeader, DNSReplyCode::NonExistentDomain,
query, queryLength);
&& qtype != lwip_htons(DNS_QTYPE_ANY)) {
replyWithError(dnsHeader, DNSReplyCode::NonExistentDomain, query, queryLength);
return false;
}
// If we have no domain name configured, just return an error
if (_domainName.isEmpty())
return replyWithError(dnsHeader, _errorReplyCode,
query, queryLength);
if (_domainName.isEmpty()) {
if (_forwarder) {
return true;
} else {
replyWithError(dnsHeader, _errorReplyCode, query, queryLength);
return false;
}
}
// If we're running with a wildcard we can just return a result now
if (_domainName == "*")
return replyWithIP(dnsHeader, query, queryLength);
if (_domainName == "*") {
DEBUG_PRINTF("dnsServer - replyWithIP\r\n");
replyWithIP(dnsHeader, query, queryLength);
return false;
}
matchString = _domainName.c_str();
@ -139,24 +323,32 @@ void DNSServer::respondToRequest(uint8_t *buffer, size_t length)
labelLength = *start;
start += 1;
while (labelLength > 0) {
if (tolower(*start) != *matchString)
return replyWithError(dnsHeader, _errorReplyCode,
query, queryLength);
if (tolower(*start) != *matchString) {
if (_forwarder) {
return true;
} else {
replyWithError(dnsHeader, _errorReplyCode, query, queryLength);
return false;
}
}
++start;
++matchString;
--labelLength;
}
if (*start == 0 && *matchString == '\0')
return replyWithIP(dnsHeader, query, queryLength);
if (*start == 0 && *matchString == '\0') {
replyWithIP(dnsHeader, query, queryLength);
return false;
}
if (*matchString != '.')
return replyWithError(dnsHeader, _errorReplyCode,
query, queryLength);
if (*matchString != '.') {
replyWithError(dnsHeader, _errorReplyCode, query, queryLength);
return false;
}
++matchString;
}
return replyWithError(dnsHeader, _errorReplyCode,
query, queryLength);
replyWithError(dnsHeader, _errorReplyCode, query, queryLength);
return false;
}
void DNSServer::processNextRequest()
@ -182,7 +374,14 @@ void DNSServer::processNextRequest()
return;
_udp.read(buffer.get(), currentPacketSize);
respondToRequest(buffer.get(), currentPacketSize);
if (_dns.isSet() && _udp.remoteIP() == _dns) {
// _forwarder may have been set to false; however, for now allow inflight
// replys to finish. //??
forwardReply(buffer.get(), currentPacketSize);
} else
if (respondToRequest(buffer.get(), currentPacketSize)) {
forwardRequest(buffer.get(), currentPacketSize);
}
}
void DNSServer::writeNBOShort(uint16_t value)

View File

@ -2,6 +2,14 @@
#define DNSServer_h
#include <WiFiUdp.h>
// #define DEBUG_DNSSERVER
// https://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.txt
#ifndef IANA_DNS_PORT
#define IANA_DNS_PORT 53 // AKA domain
constexpr inline uint16_t kIanaDnsPort = IANA_DNS_PORT;
#endif
#define DNS_QR_QUERY 0
#define DNS_QR_RESPONSE 1
#define DNS_OPCODE_QUERY 0
@ -45,6 +53,15 @@ struct DNSHeader
uint16_t ARCount; // number of resource entries
};
constexpr inline size_t kDNSSQueSizeAddrBits = 3; // The number of bits used to address que entries
constexpr inline size_t kDNSSQueSize = BIT(kDNSSQueSizeAddrBits);
struct DNSS_REQUESTER {
uint32_t ip;
uint16_t port;
uint16_t id;
};
class DNSServer
{
public:
@ -52,24 +69,60 @@ class DNSServer
~DNSServer() {
stop();
};
/*
If specified, `enableForwarder` will update the `domainName` that is used
to match DNS request to this AP's IP Address. A non-matching request will
be forwarded to the DNS server specified by `dns`.
Returns `true` on success.
Returns `false`,
* when forwarding `dns` is not set, or
* unable to allocate resources for managing the DNS forward function.
*/
bool enableForwarder(const String &domainName = emptyString, const IPAddress &dns = (uint32_t)0);
/*
`disableForwarder` will stop forwarding DNS requests. If specified,
updates the `domainName` that is matched for returning this AP's IP Address.
Optionally, resources used for the DNS forward function can be freed.
*/
void disableForwarder(const String &domainName = emptyString, bool freeResources = false);
bool isForwarding() { return _forwarder && _dns.isSet(); }
void setDNS(const IPAddress& dns) { _dns = dns; }
IPAddress getDNS() { return _dns; }
bool isDNSSet() { return _dns.isSet(); }
void processNextRequest();
void setErrorReplyCode(const DNSReplyCode &replyCode);
void setTTL(const uint32_t &ttl);
uint32_t getTTL();
String getDomainName() { return _domainName; }
// Returns true if successful, false if there are no sockets available
bool start(const uint16_t &port,
const String &domainName,
const IPAddress &resolvedIP);
const IPAddress &resolvedIP,
const IPAddress &dns = (uint32_t)0);
// stops the DNS server
void stop();
private:
WiFiUDP _udp;
uint16_t _port;
String _domainName;
unsigned char _resolvedIP[4];
IPAddress _dns;
std::unique_ptr<DNSS_REQUESTER[]> _que;
uint32_t _ttl;
#ifdef DEBUG_DNSSERVER
// There are 2 possiblities for OverFlow:
// 1) we have more than kDNSSQueSize request already outstanding.
// 2) we have request that never received a reply.
uint32_t _que_ov;
uint32_t _que_drop;
#endif
DNSReplyCode _errorReplyCode;
bool _forwarder;
unsigned char _resolvedIP[4];
uint16_t _port;
void downcaseAndRemoveWwwPrefix(String &domainName);
void replyWithIP(DNSHeader *dnsHeader,
@ -81,7 +134,9 @@ class DNSServer
size_t queryLength);
void replyWithError(DNSHeader *dnsHeader,
DNSReplyCode rcode);
void respondToRequest(uint8_t *buffer, size_t length);
bool respondToRequest(uint8_t *buffer, size_t length);
void forwardRequest(uint8_t *buffer, size_t length);
void forwardReply(uint8_t *buffer, size_t length);
void writeNBOShort(uint16_t value);
};
#endif