chiark / gitweb /
resolved: add LLMNR support for looking up names
[elogind.git] / src / resolve / resolved-dns-packet.c
index 9aa073421333912d19ebe0f9d25ecd251243310b..02532dc7212153143b907a3ac4acce3fdc81537b 100644 (file)
  ***/
 
 #include "utf8.h"
-
+#include "util.h"
 #include "resolved-dns-domain.h"
 #include "resolved-dns-packet.h"
 
-int dns_packet_new(DnsPacket **ret, size_t mtu) {
+int dns_packet_new(DnsPacket **ret, DnsProtocol protocol, size_t mtu) {
         DnsPacket *p;
         size_t a;
 
@@ -38,12 +38,20 @@ int dns_packet_new(DnsPacket **ret, size_t mtu) {
         if (a < DNS_PACKET_HEADER_SIZE)
                 a = DNS_PACKET_HEADER_SIZE;
 
+        /* round up to next page size */
+        a = PAGE_ALIGN(ALIGN(sizeof(DnsPacket)) + a) - ALIGN(sizeof(DnsPacket));
+
+        /* make sure we never allocate more than useful */
+        if (a > DNS_PACKET_SIZE_MAX)
+                a = DNS_PACKET_SIZE_MAX;
+
         p = malloc0(ALIGN(sizeof(DnsPacket)) + a);
         if (!p)
                 return -ENOMEM;
 
         p->size = p->rindex = DNS_PACKET_HEADER_SIZE;
         p->allocated = a;
+        p->protocol = protocol;
         p->n_ref = 1;
 
         *ret = p;
@@ -51,19 +59,23 @@ int dns_packet_new(DnsPacket **ret, size_t mtu) {
         return 0;
 }
 
-int dns_packet_new_query(DnsPacket **ret, size_t mtu) {
+int dns_packet_new_query(DnsPacket **ret, DnsProtocol protocol, size_t mtu) {
         DnsPacket *p;
         DnsPacketHeader *h;
         int r;
 
         assert(ret);
 
-        r = dns_packet_new(&p, mtu);
+        r = dns_packet_new(&p, protocol, mtu);
         if (r < 0)
                 return r;
 
         h = DNS_PACKET_HEADER(p);
-        h->flags = htobe16(DNS_PACKET_MAKE_FLAGS(0, 0, 0, 0, 1, 0, 0, 0, 0));
+
+        if (protocol == DNS_PROTOCOL_DNS)
+                h->flags = htobe16(DNS_PACKET_MAKE_FLAGS(0, 0, 0, 0, 1, 0, 0, 0, 0)); /* ask for recursion */
+        else
+                h->flags = htobe16(DNS_PACKET_MAKE_FLAGS(0, 0, 0, 0, 0, 0, 0, 0, 0));
 
         *ret = p;
         return 0;
@@ -84,6 +96,9 @@ static void dns_packet_free(DnsPacket *p) {
 
         assert(p);
 
+        if (p->rrs)
+                dns_resource_record_freev(p->rrs, DNS_PACKET_RRCOUNT(p));
+
         while ((s = hashmap_steal_first_key(p->names)))
                 free(s);
         hashmap_free(p->names);
@@ -112,6 +127,9 @@ int dns_packet_validate(DnsPacket *p) {
         if (p->size < DNS_PACKET_HEADER_SIZE)
                 return -EBADMSG;
 
+        if (p->size > DNS_PACKET_SIZE_MAX)
+                return -EBADMSG;
+
         return 0;
 }
 
@@ -136,8 +154,35 @@ int dns_packet_validate_reply(DnsPacket *p) {
 static int dns_packet_extend(DnsPacket *p, size_t add, void **ret, size_t *start) {
         assert(p);
 
-        if (p->size + add > p->allocated)
-                return -ENOMEM;
+        if (p->size + add > p->allocated) {
+                size_t a;
+
+                a = PAGE_ALIGN((p->size + add) * 2);
+                if (a > DNS_PACKET_SIZE_MAX)
+                        a = DNS_PACKET_SIZE_MAX;
+
+                if (p->size + add > a)
+                        return -EMSGSIZE;
+
+                if (p->data) {
+                        void *d;
+
+                        d = realloc(p->data, a);
+                        if (!d)
+                                return -ENOMEM;
+
+                        p->data = d;
+                } else {
+                        p->data = malloc(a);
+                        if (!p->data)
+                                return -ENOMEM;
+
+                        memcpy(p->data, (uint8_t*) p + ALIGN(sizeof(DnsPacket)), p->size);
+                        memzero((uint8_t*) p->data + p->size, a - p->size);
+                }
+
+                p->allocated = a;
+        }
 
         if (start)
                 *start = p->size;
@@ -358,7 +403,7 @@ int dns_packet_read(DnsPacket *p, size_t sz, const void **ret, size_t *start) {
         return 0;
 }
 
-static void dns_packet_rewind(DnsPacket *p, size_t idx) {
+void dns_packet_rewind(DnsPacket *p, size_t idx) {
         assert(p);
         assert(idx <= p->size);
         assert(idx >= DNS_PACKET_HEADER_SIZE);
@@ -689,11 +734,13 @@ fail:
 }
 
 int dns_packet_skip_question(DnsPacket *p) {
+        unsigned i, n;
         int r;
 
-        unsigned i, n;
         assert(p);
 
+        dns_packet_rewind(p, DNS_PACKET_HEADER_SIZE);
+
         n = DNS_PACKET_QDCOUNT(p);
         for (i = 0; i < n; i++) {
                 _cleanup_(dns_resource_key_free) DnsResourceKey key = {};
@@ -706,6 +753,49 @@ int dns_packet_skip_question(DnsPacket *p) {
         return 0;
 }
 
+int dns_packet_extract_rrs(DnsPacket *p) {
+        DnsResourceRecord **rrs = NULL;
+        size_t saved_rindex;
+        unsigned n, added = 0;
+        int r;
+
+        if (p->rrs)
+                return (int) DNS_PACKET_RRCOUNT(p);
+
+        saved_rindex = p->rindex;
+
+        r = dns_packet_skip_question(p);
+        if (r < 0)
+                goto finish;
+
+        n = DNS_PACKET_RRCOUNT(p);
+        if (n <= 0) {
+                r = 0;
+                goto finish;
+        }
+
+        rrs = new0(DnsResourceRecord*, n);
+        if (!rrs) {
+                r = -ENOMEM;
+                goto finish;
+        }
+
+        for (added = 0; added < n; added++) {
+                r = dns_packet_read_rr(p, &rrs[added], NULL);
+                if (r < 0) {
+                        dns_resource_record_freev(rrs, added);
+                        goto finish;
+                }
+        }
+
+        p->rrs = rrs;
+        r = (int) n;
+
+finish:
+        p->rindex = saved_rindex;
+        return r;
+}
+
 static const char* const dns_rcode_table[_DNS_RCODE_MAX_DEFINED] = {
         [DNS_RCODE_SUCCESS] = "SUCCESS",
         [DNS_RCODE_FORMERR] = "FORMERR",
@@ -727,3 +817,10 @@ static const char* const dns_rcode_table[_DNS_RCODE_MAX_DEFINED] = {
         [DNS_RCODE_BADTRUNC] = "BADTRUNC",
 };
 DEFINE_STRING_TABLE_LOOKUP(dns_rcode, int);
+
+static const char* const dns_protocol_table[_DNS_PROTOCOL_MAX] = {
+        [DNS_PROTOCOL_DNS] = "dns",
+        [DNS_PROTOCOL_MDNS] = "mdns",
+        [DNS_PROTOCOL_LLMNR] = "llmnr",
+};
+DEFINE_STRING_TABLE_LOOKUP(dns_protocol, DnsProtocol);