chiark / gitweb /
resolved: fix cname handling
[elogind.git] / src / resolve / resolved-dns-stream.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2014 Lennart Poettering
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20 ***/
21
22 #include <netinet/tcp.h>
23
24 #include "missing.h"
25 #include "resolved-dns-stream.h"
26
/* Per-stream timeout: dns_stream_new() arms a timer this far in the future,
 * and when it fires on_stream_timeout() completes the stream with ETIMEDOUT. */
27 #define DNS_STREAM_TIMEOUT_USEC (10 * USEC_PER_SEC)
/* Limit compared against Manager.n_dns_streams in dns_stream_new(), which
 * returns -EBUSY when the limit check fails. */
28 #define DNS_STREAMS_MAX 128
29
30 static void dns_stream_stop(DnsStream *s) {
31         assert(s);
32
33         s->io_event_source = sd_event_source_unref(s->io_event_source);
34         s->timeout_event_source = sd_event_source_unref(s->timeout_event_source);
35         s->fd = safe_close(s->fd);
36 }
37
38 static int dns_stream_update_io(DnsStream *s) {
39         int f = 0;
40
41         assert(s);
42
43         if (s->write_packet && s->n_written < sizeof(s->write_size) + s->write_packet->size)
44                 f |= EPOLLOUT;
45         if (!s->read_packet || s->n_read < sizeof(s->read_size) + s->read_packet->size)
46                 f |= EPOLLIN;
47
48         return sd_event_source_set_io_events(s->io_event_source, f);
49 }
50
51 static int dns_stream_complete(DnsStream *s, int error) {
52         assert(s);
53
54         dns_stream_stop(s);
55
56         if (s->complete)
57                 s->complete(s, error);
58         else
59                 dns_stream_free(s);
60
61         return 0;
62 }
63
/* Gather connection metadata for the stream, lazily and only once (guarded by
 * s->identified):
 *  - local and peer socket addresses, via getsockname()/getpeername();
 *  - the interface index, unless already set, taken from the first available
 *    of: the local IPv6 scope id, the peer IPv6 scope id, or IP(V6)_PKTINFO
 *    in the socket's stored packet options. A loopback index is then
 *    discarded. If the index is still unknown, it is looked up from the local
 *    address via manager_find_ifindex();
 *  - the TTL / hop limit from IP_TTL / IPV6_HOPLIMIT in the stored options.
 * For LLMNR streams with a known interface, outgoing traffic is pinned to
 * that interface using IP_UNICAST_IF / IPV6_UNICAST_IF.
 * Returns 0 on success or a negative errno value on failure.
 */
64 static int dns_stream_identify(DnsStream *s) {
65         union {
66                 struct cmsghdr header; /* For alignment */
67                 uint8_t buffer[CMSG_SPACE(MAX(sizeof(struct in_pktinfo), sizeof(struct in6_pktinfo)))
68                                + EXTRA_CMSG_SPACE /* kernel appears to require extra space */];
69         } control;
70         struct msghdr mh = {};
71         struct cmsghdr *cmsg;
72         socklen_t sl;
73         int r;
74
75         assert(s);
76
        /* Metadata only needs to be collected once per connection. */
77         if (s->identified)
78                 return 0;
79
80         /* Query the local side */
81         s->local_salen = sizeof(s->local);
82         r = getsockname(s->fd, &s->local.sa, &s->local_salen);
83         if (r < 0)
84                 return -errno;
85         if (s->local.sa.sa_family == AF_INET6 && s->ifindex <= 0)
86                 s->ifindex = s->local.in6.sin6_scope_id;
87
88         /* Query the remote side */
89         s->peer_salen = sizeof(s->peer);
90         r = getpeername(s->fd, &s->peer.sa, &s->peer_salen);
91         if (r < 0)
92                 return -errno;
93         if (s->peer.sa.sa_family == AF_INET6 && s->ifindex <= 0)
94                 s->ifindex = s->peer.in6.sin6_scope_id;
95
96         /* Check consistency */
97         assert(s->peer.sa.sa_family == s->local.sa.sa_family);
98         assert(IN_SET(s->peer.sa.sa_family, AF_INET, AF_INET6));
99
        /* NOTE(review): IP_PKTOPTIONS / IPV6_2292PKTOPTIONS only return
         * ancillary data if packet-info/TTL reporting was enabled on the
         * socket — presumably done where the listening socket is set up;
         * TODO confirm. */
100         /* Query connection meta information */
101         sl = sizeof(control);
102         if (s->peer.sa.sa_family == AF_INET) {
103                 r = getsockopt(s->fd, IPPROTO_IP, IP_PKTOPTIONS, &control, &sl);
104                 if (r < 0)
105                         return -errno;
106         } else if (s->peer.sa.sa_family == AF_INET6) {
107
108                 r = getsockopt(s->fd, IPPROTO_IPV6, IPV6_2292PKTOPTIONS, &control, &sl);
109                 if (r < 0)
110                         return -errno;
111         } else
112                 return -EAFNOSUPPORT;
113
        /* Wrap the returned option buffer in a msghdr so the standard CMSG_*
         * macros can walk it. An interface index that is already known takes
         * precedence over the one reported by PKTINFO.
         * NOTE(review): cmsg_len is not checked against the expected payload
         * size before CMSG_DATA() is dereferenced; the kernel is trusted. */
114         mh.msg_control = &control;
115         mh.msg_controllen = sl;
116         for (cmsg = CMSG_FIRSTHDR(&mh); cmsg; cmsg = CMSG_NXTHDR(&mh, cmsg)) {
117
118                 if (cmsg->cmsg_level == IPPROTO_IPV6) {
119                         assert(s->peer.sa.sa_family == AF_INET6);
120
121                         switch (cmsg->cmsg_type) {
122
123                         case IPV6_PKTINFO: {
124                                 struct in6_pktinfo *i = (struct in6_pktinfo*) CMSG_DATA(cmsg);
125
126                                 if (s->ifindex <= 0)
127                                         s->ifindex = i->ipi6_ifindex;
128                                 break;
129                         }
130
131                         case IPV6_HOPLIMIT:
132                                 s->ttl = *(int *) CMSG_DATA(cmsg);
133                                 break;
134                         }
135
136                 } else if (cmsg->cmsg_level == IPPROTO_IP) {
137                         assert(s->peer.sa.sa_family == AF_INET);
138
139                         switch (cmsg->cmsg_type) {
140
141                         case IP_PKTINFO: {
142                                 struct in_pktinfo *i = (struct in_pktinfo*) CMSG_DATA(cmsg);
143
144                                 if (s->ifindex <= 0)
145                                         s->ifindex = i->ipi_ifindex;
146                                 break;
147                         }
148
149                         case IP_TTL:
150                                 s->ttl = *(int *) CMSG_DATA(cmsg);
151                                 break;
152                         }
153                 }
154         }
155
156         /* The Linux kernel sets the interface index to the loopback
157          * device if the connection came from the local host since it
158          * avoids the routing table in such a case. Let's unset the
159          * interface index in such a case. */
        /* NOTE(review): "!= 0" also clears the index if
         * manager_ifindex_is_loopback() reports an error (assuming it returns
         * a negative errno on failure); confirm this is intended. */
160         if (s->ifindex > 0 && manager_ifindex_is_loopback(s->manager, s->ifindex) != 0)
161                 s->ifindex = 0;
162
163         /* If we don't know the interface index still, we look for the
164          * first local interface with a matching address. Yuck! */
165         if (s->ifindex <= 0)
166                 s->ifindex = manager_find_ifindex(s->manager, s->local.sa.sa_family, s->local.sa.sa_family == AF_INET ? (union in_addr_union*) &s->local.in.sin_addr : (union in_addr_union*)  &s->local.in6.sin6_addr);
167
168         if (s->protocol == DNS_PROTOCOL_LLMNR && s->ifindex > 0) {
                /* IP_UNICAST_IF / IPV6_UNICAST_IF take the interface index in
                 * network byte order, hence htobe32(). */
169                 uint32_t ifindex = htobe32(s->ifindex);
170
171                 /* Make sure all packets for this connection are sent on the same interface */
172                 if (s->local.sa.sa_family == AF_INET) {
173                         r = setsockopt(s->fd, IPPROTO_IP, IP_UNICAST_IF, &ifindex, sizeof(ifindex));
174                         if (r < 0)
175                                 return -errno;
176                 } else if (s->local.sa.sa_family == AF_INET6) {
177                         r = setsockopt(s->fd, IPPROTO_IPV6, IPV6_UNICAST_IF, &ifindex, sizeof(ifindex));
178                         if (r < 0)
179                                 return -errno;
180                 }
181         }
182
183         s->identified = true;
184
185         return 0;
186 }
187
188 static int on_stream_timeout(sd_event_source *es, usec_t usec, void *userdata) {
189         DnsStream *s = userdata;
190
191         assert(s);
192
193         return dns_stream_complete(s, ETIMEDOUT);
194 }
195
/* sd-event I/O callback for the stream socket, registered in dns_stream_new().
 *
 * Each invocation makes incremental, non-blocking progress:
 *  - on EPOLLOUT, continue sending the queued write packet, framed as the
 *    length prefix s->write_size followed by the packet data (one writev());
 *  - on EPOLLIN/EPOLLHUP/EPOLLRDHUP, first read the length prefix into
 *    s->read_size, then allocate s->read_packet of that size and read the
 *    payload into it.
 * When the read packet becomes complete, the optional s->on_packet callback
 * is invoked and its return value is returned immediately. Otherwise, if both
 * the write and the read are complete, the stream is completed with 0.
 * Hard errors complete the stream with a positive errno value, for example
 * ECONNRESET on EOF or EBADMSG for a length shorter than a DNS header.
 * Returns 0 otherwise.
 */
196 static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *userdata) {
197         DnsStream *s = userdata;
198         int r;
199
200         assert(s);
201
        /* Cheap after the first call: dns_stream_identify() returns early
         * once s->identified is set. */
202         r = dns_stream_identify(s);
203         if (r < 0)
204                 return dns_stream_complete(s, -r);
205
206         if ((revents & EPOLLOUT) &&
207             s->write_packet &&
208             s->n_written < sizeof(s->write_size) + s->write_packet->size) {
209
210                 struct iovec iov[2];
211                 ssize_t ss;
212
213                 iov[0].iov_base = &s->write_size;
214                 iov[0].iov_len = sizeof(s->write_size);
215                 iov[1].iov_base = DNS_PACKET_DATA(s->write_packet);
216                 iov[1].iov_len = s->write_packet->size;
217
                /* Presumably advances iov past the n_written bytes already
                 * sent, so a short write resumes where it stopped — TODO
                 * confirm IOVEC_INCREMENT semantics. */
218                 IOVEC_INCREMENT(iov, 2, s->n_written);
219
220                 ss = writev(fd, iov, 2);
221                 if (ss < 0) {
                        /* EINTR/EAGAIN are transient: keep state, wait for the next event. */
222                         if (errno != EINTR && errno != EAGAIN)
223                                 return dns_stream_complete(s, errno);
224                 } else
225                         s->n_written += ss;
226
227                 /* Are we done? If so, disable the event source for EPOLLOUT */
228                 if (s->n_written >= sizeof(s->write_size) + s->write_packet->size) {
229                         r = dns_stream_update_io(s);
230                         if (r < 0)
231                                 return dns_stream_complete(s, -r);
232                 }
233         }
234
235         if ((revents & (EPOLLIN|EPOLLHUP|EPOLLRDHUP)) &&
236             (!s->read_packet ||
237              s->n_read < sizeof(s->read_size) + s->read_packet->size)) {
238
                /* Phase 1: read (the rest of) the length prefix. */
239                 if (s->n_read < sizeof(s->read_size)) {
240                         ssize_t ss;
241
242                         ss = read(fd, (uint8_t*) &s->read_size + s->n_read, sizeof(s->read_size) - s->n_read);
243                         if (ss < 0) {
244                                 if (errno != EINTR && errno != EAGAIN)
245                                         return dns_stream_complete(s, errno);
                        /* read() returning 0 means the peer closed the connection mid-frame. */
246                         } else if (ss == 0)
247                                 return dns_stream_complete(s, ECONNRESET);
248                         else
249                                 s->n_read += ss;
250                 }
251
                /* Phase 2: the prefix is complete; read the payload. */
252                 if (s->n_read >= sizeof(s->read_size)) {
253
                        /* Reject frames too short to contain a DNS header. */
254                         if (be16toh(s->read_size) < DNS_PACKET_HEADER_SIZE)
255                                 return dns_stream_complete(s, EBADMSG);
256
257                         if (s->n_read < sizeof(s->read_size) + be16toh(s->read_size)) {
258                                 ssize_t ss;
259
                                /* On the first payload read, allocate the packet and
                                 * stamp it with the connection metadata gathered by
                                 * dns_stream_identify(). */
260                                 if (!s->read_packet) {
261                                         r = dns_packet_new(&s->read_packet, s->protocol, be16toh(s->read_size));
262                                         if (r < 0)
263                                                 return dns_stream_complete(s, -r);
264
265                                         s->read_packet->size = be16toh(s->read_size);
266                                         s->read_packet->ipproto = IPPROTO_TCP;
267                                         s->read_packet->family = s->peer.sa.sa_family;
268                                         s->read_packet->ttl = s->ttl;
269                                         s->read_packet->ifindex = s->ifindex;
270
271                                         if (s->read_packet->family == AF_INET) {
272                                                 s->read_packet->sender.in = s->peer.in.sin_addr;
273                                                 s->read_packet->sender_port = be16toh(s->peer.in.sin_port);
274                                                 s->read_packet->destination.in = s->local.in.sin_addr;
275                                                 s->read_packet->destination_port = be16toh(s->local.in.sin_port);
276                                         } else {
277                                                 assert(s->read_packet->family == AF_INET6);
278                                                 s->read_packet->sender.in6 = s->peer.in6.sin6_addr;
279                                                 s->read_packet->sender_port = be16toh(s->peer.in6.sin6_port);
280                                                 s->read_packet->destination.in6 = s->local.in6.sin6_addr;
281                                                 s->read_packet->destination_port = be16toh(s->local.in6.sin6_port);
282
                                                /* No interface known yet: fall back to the IPv6 scope ids. */
283                                                 if (s->read_packet->ifindex == 0)
284                                                         s->read_packet->ifindex = s->peer.in6.sin6_scope_id;
285                                                 if (s->read_packet->ifindex == 0)
286                                                         s->read_packet->ifindex = s->local.in6.sin6_scope_id;
287                                         }
288                                 }
289
                                /* n_read counts the prefix too, so subtract it to get
                                 * the offset into the payload buffer. */
290                                 ss = read(fd,
291                                           (uint8_t*) DNS_PACKET_DATA(s->read_packet) + s->n_read - sizeof(s->read_size),
292                                           sizeof(s->read_size) + be16toh(s->read_size) - s->n_read);
293                                 if (ss < 0) {
294                                         if (errno != EINTR && errno != EAGAIN)
295                                                 return dns_stream_complete(s, errno);
296                                 } else if (ss == 0)
297                                         return dns_stream_complete(s, ECONNRESET);
298                                 else
299                                         s->n_read += ss;
300                         }
301
302                         /* Are we done? If so, disable the event source for EPOLLIN */
303                         if (s->n_read >= sizeof(s->read_size) + be16toh(s->read_size)) {
304                                 r = dns_stream_update_io(s);
305                                 if (r < 0)
306                                         return dns_stream_complete(s, -r);
307
308                                 /* If there's a packet handler
309                                  * installed, call that. Note that
310                                  * this is optional... */
311                                 if (s->on_packet)
312                                         return s->on_packet(s);
313                         }
314                 }
315         }
316
        /* The queued write and the read packet are both complete: finish
         * the stream successfully. */
317         if ((s->write_packet && s->n_written >= sizeof(s->write_size) + s->write_packet->size) &&
318             (s->read_packet && s->n_read >= sizeof(s->read_size) + s->read_packet->size))
319                 return dns_stream_complete(s, 0);
320
321         return 0;
322 }
323
324 DnsStream *dns_stream_free(DnsStream *s) {
325         if (!s)
326                 return NULL;
327
328         dns_stream_stop(s);
329
330         if (s->manager) {
331                 LIST_REMOVE(streams, s->manager->dns_streams, s);
332                 s->manager->n_dns_streams--;
333         }
334
335         dns_packet_unref(s->write_packet);
336         dns_packet_unref(s->read_packet);
337
338         free(s);
339
340         return 0;
341 }
342
343 DEFINE_TRIVIAL_CLEANUP_FUNC(DnsStream*, dns_stream_free);
344
345 int dns_stream_new(Manager *m, DnsStream **ret, DnsProtocol protocol, int fd) {
346         static const int one = 1;
347         _cleanup_(dns_stream_freep) DnsStream *s = NULL;
348         int r;
349
350         assert(m);
351         assert(fd >= 0);
352
353         if (m->n_dns_streams > DNS_STREAMS_MAX)
354                 return -EBUSY;
355
356         s = new0(DnsStream, 1);
357         if (!s)
358                 return -ENOMEM;
359
360         s->fd = -1;
361         s->protocol = protocol;
362
363         r = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one));
364         if (r < 0)
365                 return -errno;
366
367         r = sd_event_add_io(m->event, &s->io_event_source, fd, EPOLLIN, on_stream_io, s);
368         if (r < 0)
369                 return r;
370
371         r = sd_event_add_time(m->event, &s->timeout_event_source, CLOCK_MONOTONIC, now(CLOCK_MONOTONIC) + DNS_STREAM_TIMEOUT_USEC, 0, on_stream_timeout, s);
372         if (r < 0)
373                 return r;
374
375         LIST_PREPEND(streams, m->dns_streams, s);
376         s->manager = m;
377         s->fd = fd;
378         m->n_dns_streams++;
379
380         *ret = s;
381         s = NULL;
382
383         return 0;
384 }
385
386 int dns_stream_write_packet(DnsStream *s, DnsPacket *p) {
387         assert(s);
388
389         if (s->write_packet)
390                 return -EBUSY;
391
392         s->write_packet = dns_packet_ref(p);
393         s->write_size = htobe16(p->size);
394         s->n_written = 0;
395
396         return dns_stream_update_io(s);
397 }