chiark / gitweb /
resolve: add llmnr responder side for UDP and TCP
[elogind.git] / src / resolve / resolved-dns-stream.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2014 Lennart Poettering
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20 ***/
21
#include <netinet/tcp.h>
#include <stdint.h>

#include "missing.h"
#include "resolved-dns-stream.h"
26
/* Lifetime of a stream: dns_stream_new() arms a timer this far in the future,
 * and when it fires on_stream_timeout() completes the stream with ETIMEDOUT. */
#define DNS_STREAM_TIMEOUT_USEC (10 * USEC_PER_SEC)
/* Limit on Manager.n_dns_streams, checked by dns_stream_new() before it
 * creates a new stream (it fails with -EBUSY when the limit is hit). */
#define DNS_STREAMS_MAX 128
29
30 static void dns_stream_stop(DnsStream *s) {
31         assert(s);
32
33         s->io_event_source = sd_event_source_unref(s->io_event_source);
34         s->timeout_event_source = sd_event_source_unref(s->timeout_event_source);
35         s->fd = safe_close(s->fd);
36 }
37
38 static int dns_stream_update_io(DnsStream *s) {
39         int f = 0;
40
41         assert(s);
42
43         if (s->write_packet && s->n_written < sizeof(s->write_size) + s->write_packet->size)
44                 f |= EPOLLOUT;
45         if (!s->read_packet || s->n_read < sizeof(s->read_size) + s->read_packet->size)
46                 f |= EPOLLIN;
47
48         return sd_event_source_set_io_events(s->io_event_source, f);
49 }
50
51 static int stream_complete(DnsStream *s, int error) {
52         assert(s);
53
54         dns_stream_stop(s);
55
56         if (s->complete)
57                 s->complete(s, error);
58         else
59                 dns_stream_free(s);
60
61         return 0;
62 }
63
64 static int on_stream_timeout(sd_event_source *es, usec_t usec, void *userdata) {
65         DnsStream *s = userdata;
66
67         assert(s);
68
69         return stream_complete(s, ETIMEDOUT);
70 }
71
/* I/O callback for the stream socket. It drives both directions of the
 * TCP DNS framing used here: a 16-bit big-endian size prefix followed by
 * that many bytes of packet data.
 *
 * Write side: while write_packet is only partially sent, push the remaining
 * bytes of prefix + payload with one writev(), resuming at n_written.
 *
 * Read side: first collect the 2-byte size prefix into read_size, then
 * allocate read_packet of that size (tagged with the connection's addresses,
 * ports, TTL and ifindex recorded by dns_stream_new()) and fill it, resuming
 * at n_read. Once a full packet has arrived, the optional on_packet handler
 * is called and its return value is returned.
 *
 * When one complete packet has been written AND one complete packet has been
 * read, the stream is completed with error 0.
 *
 * Every path that calls stream_complete() returns right away, because that
 * may have freed @s (it calls dns_stream_free() when no completion callback
 * is installed).
 *
 * NOTE(review): EPOLLERR is never inspected, and EPOLLHUP is only handled
 * while a read is still pending. Confirm that a hangup arriving after the
 * read has finished (and before a reply is queued) cannot make this callback
 * fire repeatedly until the timeout. */
static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *userdata) {
        DnsStream *s = userdata;
        int r;

        assert(s);

        /* Write side: continue sending the queued packet if part of it is still unsent. */
        if ((revents & EPOLLOUT) &&
            s->write_packet &&
            s->n_written < sizeof(s->write_size) + s->write_packet->size) {

                struct iovec iov[2];
                ssize_t ss;

                /* On the wire: big-endian length prefix, then the packet data. */
                iov[0].iov_base = &s->write_size;
                iov[0].iov_len = sizeof(s->write_size);
                iov[1].iov_base = DNS_PACKET_DATA(s->write_packet);
                iov[1].iov_len = s->write_packet->size;

                /* Skip the n_written bytes already sent by earlier invocations
                 * (presumably what IOVEC_INCREMENT does; it is defined elsewhere). */
                IOVEC_INCREMENT(iov, 2, s->n_written);

                ss = writev(fd, iov, 2);
                if (ss < 0) {
                        /* EINTR/EAGAIN are transient: just wait for the next event. */
                        if (errno != EINTR && errno != EAGAIN)
                                return stream_complete(s, errno);
                } else
                        s->n_written += ss;

                /* Are we done? If so, disable the event source for EPOLLOUT */
                if (s->n_written >= sizeof(s->write_size) + s->write_packet->size) {
                        r = dns_stream_update_io(s);
                        if (r < 0)
                                return stream_complete(s, -r);
                }
        }

        /* Read side. Hangups are routed here too, so that they are noticed via
         * read() (which is then expected to return 0, mapped to ECONNRESET). */
        if ((revents & (EPOLLIN|EPOLLHUP|EPOLLRDHUP)) &&
            (!s->read_packet ||
             s->n_read < sizeof(s->read_size) + s->read_packet->size)) {

                /* Phase 1: collect the 2-byte length prefix, possibly over several calls. */
                if (s->n_read < sizeof(s->read_size)) {
                        ssize_t ss;

                        /* Transient errors wait for the next event; EOF (0) means the
                         * peer closed before the message was complete. */
                        ss = read(fd, (uint8_t*) &s->read_size + s->n_read, sizeof(s->read_size) - s->n_read);
                        if (ss < 0) {
                                if (errno != EINTR && errno != EAGAIN)
                                        return stream_complete(s, errno);
                        } else if (ss == 0)
                                return stream_complete(s, ECONNRESET);
                        else
                                s->n_read += ss;
                }

                /* Phase 2: the prefix is complete, so read the packet body. */
                if (s->n_read >= sizeof(s->read_size)) {

                        /* Anything shorter than a DNS header cannot be a valid message. */
                        if (be16toh(s->read_size) < DNS_PACKET_HEADER_SIZE)
                                return stream_complete(s, EBADMSG);

                        if (s->n_read < sizeof(s->read_size) + be16toh(s->read_size)) {
                                ssize_t ss;

                                /* First body bytes: allocate the packet and tag it with the
                                 * connection metadata gathered in dns_stream_new(). */
                                if (!s->read_packet) {
                                        r = dns_packet_new(&s->read_packet, s->protocol, be16toh(s->read_size));
                                        if (r < 0)
                                                return stream_complete(s, -r);

                                        s->read_packet->size = be16toh(s->read_size);
                                        s->read_packet->ipproto = IPPROTO_TCP;
                                        s->read_packet->family = s->peer.sa.sa_family;
                                        s->read_packet->ttl = s->ttl;
                                        s->read_packet->ifindex = s->ifindex;

                                        /* sockaddr ports are big-endian; the packet stores them in host order. */
                                        if (s->read_packet->family == AF_INET) {
                                                s->read_packet->sender.in = s->peer.in.sin_addr;
                                                s->read_packet->sender_port = be16toh(s->peer.in.sin_port);
                                                s->read_packet->destination.in = s->local.in.sin_addr;
                                                s->read_packet->destination_port = be16toh(s->local.in.sin_port);
                                        } else {
                                                assert(s->read_packet->family == AF_INET6);
                                                s->read_packet->sender.in6 = s->peer.in6.sin6_addr;
                                                s->read_packet->sender_port = be16toh(s->peer.in6.sin6_port);
                                                s->read_packet->destination.in6 = s->local.in6.sin6_addr;
                                                s->read_packet->destination_port = be16toh(s->local.in6.sin6_port);

                                                /* Still no interface? Fall back to the IPv6 scope IDs. */
                                                if (s->read_packet->ifindex == 0)
                                                        s->read_packet->ifindex = s->peer.in6.sin6_scope_id;
                                                if (s->read_packet->ifindex == 0)
                                                        s->read_packet->ifindex = s->local.in6.sin6_scope_id;
                                        }
                                }

                                /* Continue filling the body where the previous read stopped. */
                                ss = read(fd,
                                          (uint8_t*) DNS_PACKET_DATA(s->read_packet) + s->n_read - sizeof(s->read_size),
                                          sizeof(s->read_size) + be16toh(s->read_size) - s->n_read);
                                if (ss < 0) {
                                        if (errno != EINTR && errno != EAGAIN)
                                                return stream_complete(s, errno);
                                } else if (ss == 0)
                                        return stream_complete(s, ECONNRESET);
                                else
                                        s->n_read += ss;
                        }

                        /* Are we done? If so, disable the event source for EPOLLIN */
                        if (s->n_read >= sizeof(s->read_size) + be16toh(s->read_size)) {
                                r = dns_stream_update_io(s);
                                if (r < 0)
                                        return stream_complete(s, -r);

                                /* If there's a packet handler
                                 * installed, call that. Note that
                                 * this is optional... */
                                if (s->on_packet)
                                        return s->on_packet(s);
                        }
                }
        }

        /* One packet fully sent and one fully received: the exchange is over. */
        if ((s->write_packet && s->n_written >= sizeof(s->write_size) + s->write_packet->size) &&
            (s->read_packet && s->n_read >= sizeof(s->read_size) + s->read_packet->size))
                return stream_complete(s, 0);

        return 0;
}
195
196 DnsStream *dns_stream_free(DnsStream *s) {
197         if (!s)
198                 return NULL;
199
200         dns_stream_stop(s);
201
202         if (s->manager) {
203                 LIST_REMOVE(streams, s->manager->dns_streams, s);
204                 s->manager->n_dns_streams--;
205         }
206
207         dns_packet_unref(s->write_packet);
208         dns_packet_unref(s->read_packet);
209
210         free(s);
211
212         return 0;
213 }
214
/* Provides dns_stream_freep(), the _cleanup_() helper that dns_stream_new()
 * uses to free a partially constructed stream on its error paths. */
DEFINE_TRIVIAL_CLEANUP_FUNC(DnsStream*, dns_stream_free);
216
217 int dns_stream_new(Manager *m, DnsStream **ret, DnsProtocol protocol, int fd) {
218         static const int one = 1;
219         union {
220                 struct cmsghdr header; /* For alignment */
221                 uint8_t buffer[CMSG_SPACE(MAX(sizeof(struct in_pktinfo), sizeof(struct in6_pktinfo)))
222                                + EXTRA_CMSG_SPACE /* kernel appears to require extra space */];
223         } control;
224         struct msghdr mh = {};
225         struct cmsghdr *cmsg;
226         _cleanup_(dns_stream_freep) DnsStream *s = NULL;
227         socklen_t sl;
228         int r;
229
230         assert(m);
231         assert(fd >= 0);
232
233         if (m->n_dns_streams > DNS_STREAMS_MAX)
234                 return -EBUSY;
235
236         s = new0(DnsStream, 1);
237         if (!s)
238                 return -ENOMEM;
239
240         s->fd = -1;
241         s->protocol = protocol;
242
243         /* Query the remote side */
244         s->peer_salen = sizeof(s->peer);
245         r = getpeername(fd, &s->peer.sa, &s->peer_salen);
246         if (r < 0)
247                 return -errno;
248         if (s->peer.sa.sa_family == AF_INET6)
249                 s->ifindex = s->peer.in6.sin6_scope_id;
250
251         /* Query the local side */
252         s->local_salen = sizeof(s->local);
253         r = getsockname(fd, &s->local.sa, &s->local_salen);
254         if (r < 0)
255                 return -errno;
256         if (s->local.sa.sa_family == AF_INET6 && s->ifindex <= 0)
257                 s->ifindex = s->local.in6.sin6_scope_id;
258
259         /* Check consistency */
260         assert(s->peer.sa.sa_family == s->local.sa.sa_family);
261         assert(IN_SET(s->peer.sa.sa_family, AF_INET, AF_INET6));
262
263         /* Query connection meta information */
264         sl = sizeof(control);
265         if (s->peer.sa.sa_family == AF_INET) {
266                 r = getsockopt(fd, IPPROTO_IP, IP_PKTOPTIONS, &control, &sl);
267                 if (r < 0)
268                         return -errno;
269         } else {
270                 assert(s->peer.sa.sa_family == AF_INET6);
271
272                 r = getsockopt(fd, IPPROTO_IPV6, IPV6_2292PKTOPTIONS, &control, &sl);
273                 if (r < 0)
274                         return -errno;
275         }
276
277         mh.msg_control = &control;
278         mh.msg_controllen = sl;
279         for (cmsg = CMSG_FIRSTHDR(&mh); cmsg; cmsg = CMSG_NXTHDR(&mh, cmsg)) {
280
281                 if (cmsg->cmsg_level == IPPROTO_IPV6) {
282                         assert(s->peer.sa.sa_family == AF_INET6);
283
284                         switch (cmsg->cmsg_type) {
285
286                         case IPV6_PKTINFO: {
287                                 struct in6_pktinfo *i = (struct in6_pktinfo*) CMSG_DATA(cmsg);
288
289                                 if (s->ifindex <= 0)
290                                         s->ifindex = i->ipi6_ifindex;
291                                 break;
292                         }
293
294                         case IPV6_HOPLIMIT:
295                                 s->ttl = *(int *) CMSG_DATA(cmsg);
296                                 break;
297                         }
298
299                 } else if (cmsg->cmsg_level == IPPROTO_IP) {
300                         assert(s->peer.sa.sa_family == AF_INET);
301
302                         switch (cmsg->cmsg_type) {
303
304                         case IP_PKTINFO: {
305                                 struct in_pktinfo *i = (struct in_pktinfo*) CMSG_DATA(cmsg);
306
307                                 if (s->ifindex <= 0)
308                                         s->ifindex = i->ipi_ifindex;
309                                 break;
310                         }
311
312                         case IP_TTL:
313                                 s->ttl = *(int *) CMSG_DATA(cmsg);
314                                 break;
315                         }
316                 }
317         }
318
319         /* The Linux kernel sets the interface index to the loopback
320          * device if the connection came from the local host since it
321          * avoids the routing table in such a case. Let's unset the
322          * interface index in such a case. */
323         if (s->ifindex > 0 && manager_ifindex_is_loopback(m, s->ifindex) != 0)
324                 s->ifindex = 0;
325
326         /* If we don't know the interface index still, we look for the
327          * first local interface with a matching address. Yuck! */
328         if (s->ifindex <= 0)
329                 s->ifindex = manager_find_ifindex(m, s->local.sa.sa_family, s->local.sa.sa_family == AF_INET ? (union in_addr_union*) &s->local.in.sin_addr : (union in_addr_union*)  &s->local.in6.sin6_addr);
330
331         r = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one));
332         if (r < 0)
333                 return -errno;
334
335         if (s->protocol == DNS_PROTOCOL_LLMNR && s->ifindex > 0) {
336                 uint32_t ifindex = htobe32(s->ifindex);
337
338                 /* Make sure all packets for this connection are sent on the same interface */
339                 if (s->local.sa.sa_family == AF_INET) {
340                         r = setsockopt(fd, IPPROTO_IP, IP_UNICAST_IF, &ifindex, sizeof(ifindex));
341                         if (r < 0)
342                                 return -errno;
343                 } else if (s->local.sa.sa_family == AF_INET6) {
344                         r = setsockopt(fd, IPPROTO_IPV6, IPV6_UNICAST_IF, &ifindex, sizeof(ifindex));
345                         if (r < 0)
346                                 return -errno;
347                 }
348         }
349
350         r = sd_event_add_io(m->event, &s->io_event_source, fd, EPOLLIN, on_stream_io, s);
351         if (r < 0)
352                 return r;
353
354         r = sd_event_add_time(m->event, &s->timeout_event_source, CLOCK_MONOTONIC, now(CLOCK_MONOTONIC) + DNS_STREAM_TIMEOUT_USEC, 0, on_stream_timeout, s);
355         if (r < 0)
356                 return r;
357
358         LIST_PREPEND(streams, m->dns_streams, s);
359         s->manager = m;
360         s->fd = fd;
361         m->n_dns_streams++;
362
363         *ret = s;
364         s = NULL;
365
366         return 0;
367 }
368
369 int dns_stream_write_packet(DnsStream *s, DnsPacket *p) {
370         assert(s);
371
372         if (s->write_packet)
373                 return -EBUSY;
374
375         s->write_packet = dns_packet_ref(p);
376         s->write_size = htobe16(p->size);
377         s->n_written = 0;
378
379         return dns_stream_update_io(s);
380 }