chiark / gitweb /
b0dc82259128566b1690daa474c952813833cc16
[elogind.git] / src / libsystemd / sd-resolve / sd-resolve.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2005-2008 Lennart Poettering
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20 ***/
21
22 #include <signal.h>
23 #include <unistd.h>
24 #include <stdio.h>
25 #include <string.h>
26 #include <stdlib.h>
27 #include <errno.h>
28 #include <resolv.h>
29 #include <stdint.h>
30 #include <pthread.h>
31 #include <sys/prctl.h>
32 #include <poll.h>
33
34 #include "util.h"
35 #include "list.h"
36 #include "socket-util.h"
37 #include "missing.h"
38 #include "resolve-util.h"
39 #include "sd-resolve.h"
40
/* Lower and upper bound for the number of worker threads of one resolver. */
#define WORKERS_MIN 1U
#define WORKERS_MAX 16U

/* Maximum number of queries allocated at the same time (size of query_array). */
#define QUERIES_MAX 256U

/* Size of the buffers used for one request or response datagram. */
#define BUFSIZE 10240U

/* Message types exchanged over the request/response sockets. */
typedef enum {
        REQUEST_ADDRINFO,       /* client -> worker: run getaddrinfo() */
        RESPONSE_ADDRINFO,      /* worker -> client: getaddrinfo() result */
        REQUEST_NAMEINFO,       /* client -> worker: run getnameinfo() */
        RESPONSE_NAMEINFO,      /* worker -> client: getnameinfo() result */
        REQUEST_RES_QUERY,      /* client -> worker: run res_query() */
        REQUEST_RES_SEARCH,     /* client -> worker: run res_search() */
        RESPONSE_RES,           /* worker -> client: result of either res request */
        REQUEST_TERMINATE,      /* client -> worker: exit the thread */
        RESPONSE_DIED           /* worker -> client: the thread has exited */
} QueryType;

/* Indices into sd_resolve.fds: one socketpair per direction, created by
 * sd_resolve_new(). */
enum {
        REQUEST_RECV_FD,        /* workers read requests here */
        REQUEST_SEND_FD,        /* client writes requests here */
        RESPONSE_RECV_FD,       /* client reads responses here */
        RESPONSE_SEND_FD,       /* workers write responses here */
        _FD_MAX
};
65
66 struct sd_resolve {
67         unsigned n_ref;
68
69         bool dead:1;
70         pid_t original_pid;
71
72         int fds[_FD_MAX];
73
74         pthread_t workers[WORKERS_MAX];
75         unsigned n_valid_workers;
76
77         unsigned current_id;
78         sd_resolve_query* query_array[QUERIES_MAX];
79         unsigned n_queries, n_done, n_outstanding;
80
81         sd_event_source *event_source;
82         sd_event *event;
83
84         sd_resolve_query *current;
85
86         sd_resolve **default_resolve_ptr;
87         pid_t tid;
88
89         LIST_HEAD(sd_resolve_query, queries);
90 };
91
/* One asynchronous lookup and, once completed, its result. */
struct sd_resolve_query {
        unsigned n_ref;

        sd_resolve *resolve;

        /* REQUEST_* type of this query (9 enum values fit in 4 bits). */
        QueryType type:4;
        bool done:1;
        /* Floating queries do not hold a reference on their resolver; they
         * are released by the resolver itself (see complete_query() and
         * resolve_free()). */
        bool floating:1;
        /* Matches responses to queries, see lookup_query(). */
        unsigned id;

        /* Result fields, filled in by handle_response(). */
        int ret;
        int _errno;
        int _h_errno;
        struct addrinfo *addrinfo;
        char *serv;
        char *host;
        unsigned char *answer;

        /* Completion callback; which member is valid depends on type. */
        union {
                sd_resolve_getaddrinfo_handler_t getaddrinfo_handler;
                sd_resolve_getnameinfo_handler_t getnameinfo_handler;
                sd_resolve_res_handler_t res_handler;
        };

        void *userdata;

        LIST_FIELDS(sd_resolve_query, queries);
};
119
120 typedef struct RHeader {
121         QueryType type;
122         unsigned id;
123         size_t length;
124 } RHeader;
125
126 typedef struct AddrInfoRequest {
127         struct RHeader header;
128         bool hints_valid;
129         int ai_flags;
130         int ai_family;
131         int ai_socktype;
132         int ai_protocol;
133         size_t node_len, service_len;
134 } AddrInfoRequest;
135
136 typedef struct AddrInfoResponse {
137         struct RHeader header;
138         int ret;
139         int _errno;
140         int _h_errno;
141         /* followed by addrinfo_serialization[] */
142 } AddrInfoResponse;
143
144 typedef struct AddrInfoSerialization {
145         int ai_flags;
146         int ai_family;
147         int ai_socktype;
148         int ai_protocol;
149         size_t ai_addrlen;
150         size_t canonname_len;
151         /* Followed by ai_addr amd ai_canonname with variable lengths */
152 } AddrInfoSerialization;
153
154 typedef struct NameInfoRequest {
155         struct RHeader header;
156         int flags;
157         socklen_t sockaddr_len;
158         bool gethost:1, getserv:1;
159 } NameInfoRequest;
160
161 typedef struct NameInfoResponse {
162         struct RHeader header;
163         size_t hostlen, servlen;
164         int ret;
165         int _errno;
166         int _h_errno;
167 } NameInfoResponse;
168
169 typedef struct ResRequest {
170         struct RHeader header;
171         int class;
172         int type;
173         size_t dname_len;
174 } ResRequest;
175
176 typedef struct ResResponse {
177         struct RHeader header;
178         int ret;
179         int _errno;
180         int _h_errno;
181 } ResResponse;
182
183 typedef union Packet {
184         RHeader rheader;
185         AddrInfoRequest addrinfo_request;
186         AddrInfoResponse addrinfo_response;
187         NameInfoRequest nameinfo_request;
188         NameInfoResponse nameinfo_response;
189         ResRequest res_request;
190         ResResponse res_response;
191 } Packet;
192
/* Per-type completion handlers, invoked from complete_query(). */
static int getaddrinfo_done(sd_resolve_query *q);
static int getnameinfo_done(sd_resolve_query *q);
static int res_query_done(sd_resolve_query *q);

/* Detaches a query from its resolver. */
static void resolve_query_disconnect(sd_resolve_query *q);

/* Take a reference on resolve that is dropped automatically when the
 * enclosing scope is left, so the object survives callbacks that release it. */
#define RESOLVE_DONT_DESTROY(resolve) \
        _cleanup_resolve_unref_ _unused_ sd_resolve *_dont_destroy_##resolve = sd_resolve_ref(resolve)
201
202 static int send_died(int out_fd) {
203
204         RHeader rh = {
205                 .type = RESPONSE_DIED,
206                 .length = sizeof(RHeader),
207         };
208
209         assert(out_fd >= 0);
210
211         if (send(out_fd, &rh, rh.length, MSG_NOSIGNAL) < 0)
212                 return -errno;
213
214         return 0;
215 }
216
217 static void *serialize_addrinfo(void *p, const struct addrinfo *ai, size_t *length, size_t maxlength) {
218         AddrInfoSerialization s;
219         size_t cnl, l;
220
221         assert(p);
222         assert(ai);
223         assert(length);
224         assert(*length <= maxlength);
225
226         cnl = ai->ai_canonname ? strlen(ai->ai_canonname)+1 : 0;
227         l = sizeof(AddrInfoSerialization) + ai->ai_addrlen + cnl;
228
229         if (*length + l > maxlength)
230                 return NULL;
231
232         s.ai_flags = ai->ai_flags;
233         s.ai_family = ai->ai_family;
234         s.ai_socktype = ai->ai_socktype;
235         s.ai_protocol = ai->ai_protocol;
236         s.ai_addrlen = ai->ai_addrlen;
237         s.canonname_len = cnl;
238
239         memcpy((uint8_t*) p, &s, sizeof(AddrInfoSerialization));
240         memcpy((uint8_t*) p + sizeof(AddrInfoSerialization), ai->ai_addr, ai->ai_addrlen);
241
242         if (ai->ai_canonname)
243                 memcpy((char*) p + sizeof(AddrInfoSerialization) + ai->ai_addrlen, ai->ai_canonname, cnl);
244
245         *length += l;
246         return (uint8_t*) p + l;
247 }
248
249 static int send_addrinfo_reply(
250                 int out_fd,
251                 unsigned id,
252                 int ret,
253                 struct addrinfo *ai,
254                 int _errno,
255                 int _h_errno) {
256
257         AddrInfoResponse resp = {
258                 .header.type = RESPONSE_ADDRINFO,
259                 .header.id = id,
260                 .header.length = sizeof(AddrInfoResponse),
261                 .ret = ret,
262                 ._errno = _errno,
263                 ._h_errno = _h_errno,
264         };
265
266         struct msghdr mh = {};
267         struct iovec iov[2];
268         union {
269                 AddrInfoSerialization ais;
270                 uint8_t space[BUFSIZE];
271         } buffer;
272
273         assert(out_fd >= 0);
274
275         if (ret == 0 && ai) {
276                 void *p = &buffer;
277                 struct addrinfo *k;
278
279                 for (k = ai; k; k = k->ai_next) {
280                         p = serialize_addrinfo(p, k, &resp.header.length, (uint8_t*) &buffer + BUFSIZE - (uint8_t*) p);
281                         if (!p) {
282                                 freeaddrinfo(ai);
283                                 return -ENOBUFS;
284                         }
285                 }
286         }
287
288         if (ai)
289                 freeaddrinfo(ai);
290
291         iov[0] = (struct iovec) { .iov_base = &resp, .iov_len = sizeof(AddrInfoResponse) };
292         iov[1] = (struct iovec) { .iov_base = &buffer, .iov_len = resp.header.length - sizeof(AddrInfoResponse) };
293
294         mh.msg_iov = iov;
295         mh.msg_iovlen = ELEMENTSOF(iov);
296
297         if (sendmsg(out_fd, &mh, MSG_NOSIGNAL) < 0)
298                 return -errno;
299
300         return 0;
301 }
302
303 static int send_nameinfo_reply(
304                 int out_fd,
305                 unsigned id,
306                 int ret,
307                 const char *host,
308                 const char *serv,
309                 int _errno,
310                 int _h_errno) {
311
312         NameInfoResponse resp = {
313                 .header.type = RESPONSE_NAMEINFO,
314                 .header.id = id,
315                 .ret = ret,
316                 ._errno = _errno,
317                 ._h_errno = _h_errno,
318         };
319
320         struct msghdr mh = {};
321         struct iovec iov[3];
322         size_t hl, sl;
323
324         assert(out_fd >= 0);
325
326         sl = serv ? strlen(serv)+1 : 0;
327         hl = host ? strlen(host)+1 : 0;
328
329         resp.header.length = sizeof(NameInfoResponse) + hl + sl;
330         resp.hostlen = hl;
331         resp.servlen = sl;
332
333         iov[0] = (struct iovec) { .iov_base = &resp, .iov_len = sizeof(NameInfoResponse) };
334         iov[1] = (struct iovec) { .iov_base = (void*) host, .iov_len = hl };
335         iov[2] = (struct iovec) { .iov_base = (void*) serv, .iov_len = sl };
336
337         mh.msg_iov = iov;
338         mh.msg_iovlen = ELEMENTSOF(iov);
339
340         if (sendmsg(out_fd, &mh, MSG_NOSIGNAL) < 0)
341                 return -errno;
342
343         return 0;
344 }
345
346 static int send_res_reply(int out_fd, unsigned id, const unsigned char *answer, int ret, int _errno, int _h_errno) {
347
348         ResResponse resp = {
349                 .header.type = RESPONSE_RES,
350                 .header.id = id,
351                 .ret = ret,
352                 ._errno = _errno,
353                 ._h_errno = _h_errno,
354         };
355
356         struct msghdr mh = {};
357         struct iovec iov[2];
358         size_t l;
359
360         assert(out_fd >= 0);
361
362         l = ret > 0 ? (size_t) ret : 0;
363
364         resp.header.length = sizeof(ResResponse) + l;
365
366         iov[0] = (struct iovec) { .iov_base = &resp, .iov_len = sizeof(ResResponse) };
367         iov[1] = (struct iovec) { .iov_base = (void*) answer, .iov_len = l };
368
369         mh.msg_iov = iov;
370         mh.msg_iovlen = ELEMENTSOF(iov);
371
372         if (sendmsg(out_fd, &mh, MSG_NOSIGNAL) < 0)
373                 return -errno;
374
375         return 0;
376 }
377
/* Execute one request in a worker thread and send the reply to out_fd.
 *
 * packet/length describe a complete datagram read from the request
 * socketpair created by sd_resolve_new(); its layout is checked with
 * assertions.
 *
 * Returns the result of the reply helper (0, or a negative errno if sending
 * failed), or -ECONNRESET for REQUEST_TERMINATE, which makes thread_worker()
 * leave its loop. */
static int handle_request(int out_fd, const Packet *packet, size_t length) {
        const RHeader *req;

        assert(out_fd >= 0);
        assert(packet);

        req = &packet->rheader;

        assert(length >= sizeof(RHeader));
        assert(length == req->length);

        switch (req->type) {

        case REQUEST_ADDRINFO: {
                const AddrInfoRequest *ai_req = &packet->addrinfo_request;
                struct addrinfo hints = {}, *result = NULL;
                const char *node, *service;
                int ret;

                assert(length >= sizeof(AddrInfoRequest));
                assert(length == sizeof(AddrInfoRequest) + ai_req->node_len + ai_req->service_len);

                hints.ai_flags = ai_req->ai_flags;
                hints.ai_family = ai_req->ai_family;
                hints.ai_socktype = ai_req->ai_socktype;
                hints.ai_protocol = ai_req->ai_protocol;

                /* A length of 0 encodes a NULL node/service. */
                node = ai_req->node_len ? (const char*) ai_req + sizeof(AddrInfoRequest) : NULL;
                service = ai_req->service_len ? (const char*) ai_req + sizeof(AddrInfoRequest) + ai_req->node_len : NULL;

                ret = getaddrinfo(
                                node, service,
                                ai_req->hints_valid ? &hints : NULL,
                                &result);

                /* send_addrinfo_reply() frees result */
                return send_addrinfo_reply(out_fd, req->id, ret, result, errno, h_errno);
        }

        case REQUEST_NAMEINFO: {
                const NameInfoRequest *ni_req = &packet->nameinfo_request;
                char hostbuf[NI_MAXHOST], servbuf[NI_MAXSERV];
                union sockaddr_union sa;
                int ret;

                assert(length >= sizeof(NameInfoRequest));
                assert(length == sizeof(NameInfoRequest) + ni_req->sockaddr_len);
                assert(sizeof(sa) >= ni_req->sockaddr_len);

                memcpy(&sa, (const uint8_t *) ni_req + sizeof(NameInfoRequest), ni_req->sockaddr_len);

                ret = getnameinfo(&sa.sa, ni_req->sockaddr_len,
                                ni_req->gethost ? hostbuf : NULL, ni_req->gethost ? sizeof(hostbuf) : 0,
                                ni_req->getserv ? servbuf : NULL, ni_req->getserv ? sizeof(servbuf) : 0,
                                ni_req->flags);

                return send_nameinfo_reply(out_fd, req->id, ret,
                                ret == 0 && ni_req->gethost ? hostbuf : NULL,
                                ret == 0 && ni_req->getserv ? servbuf : NULL,
                                errno, h_errno);
        }

        case REQUEST_RES_QUERY:
        case REQUEST_RES_SEARCH: {
                const ResRequest *res_req = &packet->res_request;
                union {
                        HEADER header;
                        uint8_t space[BUFSIZE];
                } answer;
                /* The reply datagram (ResResponse header plus answer) has to
                 * fit into the BUFSIZE receive buffer of sd_resolve_process(),
                 * otherwise it is truncated there and the length assertion in
                 * handle_response() fires. Limit the answer accordingly. */
                const int answer_max = (int) (BUFSIZE - sizeof(ResResponse));
                const char *dname;
                int ret, saved_errno, saved_h_errno;

                assert(length >= sizeof(ResRequest));
                assert(length == sizeof(ResRequest) + res_req->dname_len);

                dname = (const char *) res_req + sizeof(ResRequest);

                if (req->type == REQUEST_RES_QUERY)
                        ret = res_query(dname, res_req->class, res_req->type, (unsigned char *) &answer, answer_max);
                else
                        ret = res_search(dname, res_req->class, res_req->type, (unsigned char *) &answer, answer_max);

                saved_errno = errno;
                saved_h_errno = h_errno;

                /* For a response that was truncated to fit the buffer, the
                 * resolver may report the full, untruncated length. Never
                 * send (and thus read) more bytes than the buffer holds. */
                if (ret > answer_max)
                        ret = answer_max;

                return send_res_reply(out_fd, req->id, (unsigned char *) &answer, ret, saved_errno, saved_h_errno);
        }

        case REQUEST_TERMINATE:
                /* Quit */
                return -ECONNRESET;

        default:
                assert_not_reached("Unknown request");
        }

        return 0;
}
473
474 static void* thread_worker(void *p) {
475         sd_resolve *resolve = p;
476         sigset_t fullset;
477
478         /* No signals in this thread please */
479         assert_se(sigfillset(&fullset) == 0);
480         assert_se(pthread_sigmask(SIG_BLOCK, &fullset, NULL) == 0);
481
482         /* Assign a pretty name to this thread */
483         prctl(PR_SET_NAME, (unsigned long) "sd-resolve");
484
485         while (!resolve->dead) {
486                 union {
487                         Packet packet;
488                         uint8_t space[BUFSIZE];
489                 } buf;
490                 ssize_t length;
491
492                 length = recv(resolve->fds[REQUEST_RECV_FD], &buf, sizeof(buf), 0);
493                 if (length < 0) {
494                         if (errno == EINTR)
495                                 continue;
496
497                         break;
498                 }
499                 if (length == 0)
500                         break;
501
502                 if (resolve->dead)
503                         break;
504
505                 if (handle_request(resolve->fds[RESPONSE_SEND_FD], &buf.packet, (size_t) length) < 0)
506                         break;
507         }
508
509         send_died(resolve->fds[RESPONSE_SEND_FD]);
510
511         return NULL;
512 }
513
514 static int start_threads(sd_resolve *resolve, unsigned extra) {
515         unsigned n;
516         int r;
517
518         n = resolve->n_outstanding + extra;
519         n = CLAMP(n, WORKERS_MIN, WORKERS_MAX);
520
521         while (resolve->n_valid_workers < n) {
522
523                 r = pthread_create(&resolve->workers[resolve->n_valid_workers], NULL, thread_worker, resolve);
524                 if (r != 0)
525                         return -r;
526
527                 resolve->n_valid_workers ++;
528         }
529
530         return 0;
531 }
532
533 static bool resolve_pid_changed(sd_resolve *r) {
534         assert(r);
535
536         /* We don't support people creating a resolver and keeping it
537          * around after fork(). Let's complain. */
538
539         return r->original_pid != getpid();
540 }
541
542 _public_ int sd_resolve_new(sd_resolve **ret) {
543         sd_resolve *resolve = NULL;
544         int i, r;
545
546         assert_return(ret, -EINVAL);
547
548         resolve = new0(sd_resolve, 1);
549         if (!resolve)
550                 return -ENOMEM;
551
552         resolve->n_ref = 1;
553         resolve->original_pid = getpid();
554
555         for (i = 0; i < _FD_MAX; i++)
556                 resolve->fds[i] = -1;
557
558         r = socketpair(PF_UNIX, SOCK_DGRAM|SOCK_CLOEXEC, 0, resolve->fds + REQUEST_RECV_FD);
559         if (r < 0) {
560                 r = -errno;
561                 goto fail;
562         }
563
564         r = socketpair(PF_UNIX, SOCK_DGRAM|SOCK_CLOEXEC, 0, resolve->fds + RESPONSE_RECV_FD);
565         if (r < 0) {
566                 r = -errno;
567                 goto fail;
568         }
569
570         fd_inc_sndbuf(resolve->fds[REQUEST_SEND_FD], QUERIES_MAX * BUFSIZE);
571         fd_inc_rcvbuf(resolve->fds[REQUEST_RECV_FD], QUERIES_MAX * BUFSIZE);
572         fd_inc_sndbuf(resolve->fds[RESPONSE_SEND_FD], QUERIES_MAX * BUFSIZE);
573         fd_inc_rcvbuf(resolve->fds[RESPONSE_RECV_FD], QUERIES_MAX * BUFSIZE);
574
575         fd_nonblock(resolve->fds[RESPONSE_RECV_FD], true);
576
577         *ret = resolve;
578         return 0;
579
580 fail:
581         sd_resolve_unref(resolve);
582         return r;
583 }
584
585 _public_ int sd_resolve_default(sd_resolve **ret) {
586
587         static thread_local sd_resolve *default_resolve = NULL;
588         sd_resolve *e = NULL;
589         int r;
590
591         if (!ret)
592                 return !!default_resolve;
593
594         if (default_resolve) {
595                 *ret = sd_resolve_ref(default_resolve);
596                 return 0;
597         }
598
599         r = sd_resolve_new(&e);
600         if (r < 0)
601                 return r;
602
603         e->default_resolve_ptr = &default_resolve;
604         e->tid = gettid();
605         default_resolve = e;
606
607         *ret = e;
608         return 1;
609 }
610
611 _public_ int sd_resolve_get_tid(sd_resolve *resolve, pid_t *tid) {
612         assert_return(resolve, -EINVAL);
613         assert_return(tid, -EINVAL);
614         assert_return(!resolve_pid_changed(resolve), -ECHILD);
615
616         if (resolve->tid != 0) {
617                 *tid = resolve->tid;
618                 return 0;
619         }
620
621         if (resolve->event)
622                 return sd_event_get_tid(resolve->event, tid);
623
624         return -ENXIO;
625 }
626
627 static void resolve_free(sd_resolve *resolve) {
628         PROTECT_ERRNO;
629         sd_resolve_query *q;
630         unsigned i;
631
632         assert(resolve);
633
634         while ((q = resolve->queries)) {
635                 assert(q->floating);
636                 resolve_query_disconnect(q);
637                 sd_resolve_query_unref(q);
638         }
639
640         if (resolve->default_resolve_ptr)
641                 *(resolve->default_resolve_ptr) = NULL;
642
643         resolve->dead = true;
644
645         sd_resolve_detach_event(resolve);
646
647         if (resolve->fds[REQUEST_SEND_FD] >= 0) {
648
649                 RHeader req = {
650                         .type = REQUEST_TERMINATE,
651                         .length = sizeof(req)
652                 };
653
654                 /* Send one termination packet for each worker */
655                 for (i = 0; i < resolve->n_valid_workers; i++)
656                         (void) send(resolve->fds[REQUEST_SEND_FD], &req, req.length, MSG_NOSIGNAL);
657         }
658
659         /* Now terminate them and wait until they are gone. */
660         for (i = 0; i < resolve->n_valid_workers; i++) {
661                 for (;;) {
662                         if (pthread_join(resolve->workers[i], NULL) != EINTR)
663                                 break;
664                 }
665         }
666
667         /* Close all communication channels */
668         for (i = 0; i < _FD_MAX; i++)
669                 safe_close(resolve->fds[i]);
670
671         free(resolve);
672 }
673
674 _public_ sd_resolve* sd_resolve_ref(sd_resolve *resolve) {
675         assert_return(resolve, NULL);
676
677         assert(resolve->n_ref >= 1);
678         resolve->n_ref++;
679
680         return resolve;
681 }
682
683 _public_ sd_resolve* sd_resolve_unref(sd_resolve *resolve) {
684
685         if (!resolve)
686                 return NULL;
687
688         assert(resolve->n_ref >= 1);
689         resolve->n_ref--;
690
691         if (resolve->n_ref <= 0)
692                 resolve_free(resolve);
693
694         return NULL;
695 }
696
697 _public_ int sd_resolve_get_fd(sd_resolve *resolve) {
698         assert_return(resolve, -EINVAL);
699         assert_return(!resolve_pid_changed(resolve), -ECHILD);
700
701         return resolve->fds[RESPONSE_RECV_FD];
702 }
703
704 _public_ int sd_resolve_get_events(sd_resolve *resolve) {
705         assert_return(resolve, -EINVAL);
706         assert_return(!resolve_pid_changed(resolve), -ECHILD);
707
708         return resolve->n_queries > resolve->n_done ? POLLIN : 0;
709 }
710
711 _public_ int sd_resolve_get_timeout(sd_resolve *resolve, uint64_t *usec) {
712         assert_return(resolve, -EINVAL);
713         assert_return(usec, -EINVAL);
714         assert_return(!resolve_pid_changed(resolve), -ECHILD);
715
716         *usec = (uint64_t) -1;
717         return 0;
718 }
719
720 static sd_resolve_query *lookup_query(sd_resolve *resolve, unsigned id) {
721         sd_resolve_query *q;
722
723         assert(resolve);
724
725         q = resolve->query_array[id % QUERIES_MAX];
726         if (q)
727                 if (q->id == id)
728                         return q;
729
730         return NULL;
731 }
732
733 static int complete_query(sd_resolve *resolve, sd_resolve_query *q) {
734         int r;
735
736         assert(q);
737         assert(!q->done);
738         assert(q->resolve == resolve);
739
740         q->done = true;
741         resolve->n_done ++;
742
743         resolve->current = sd_resolve_query_ref(q);
744
745         switch (q->type) {
746
747         case REQUEST_ADDRINFO:
748                 r = getaddrinfo_done(q);
749                 break;
750
751         case REQUEST_NAMEINFO:
752                 r = getnameinfo_done(q);
753                 break;
754
755         case REQUEST_RES_QUERY:
756         case REQUEST_RES_SEARCH:
757                 r = res_query_done(q);
758                 break;
759
760         default:
761                 assert_not_reached("Cannot complete unknown query type");
762         }
763
764         resolve->current = NULL;
765
766         if (q->floating) {
767                 resolve_query_disconnect(q);
768                 sd_resolve_query_unref(q);
769         }
770
771         sd_resolve_query_unref(q);
772
773         return r;
774 }
775
776 static int unserialize_addrinfo(const void **p, size_t *length, struct addrinfo **ret_ai) {
777         AddrInfoSerialization s;
778         size_t l;
779         struct addrinfo *ai;
780
781         assert(p);
782         assert(*p);
783         assert(ret_ai);
784         assert(length);
785
786         if (*length < sizeof(AddrInfoSerialization))
787                 return -EBADMSG;
788
789         memcpy(&s, *p, sizeof(s));
790
791         l = sizeof(AddrInfoSerialization) + s.ai_addrlen + s.canonname_len;
792         if (*length < l)
793                 return -EBADMSG;
794
795         ai = new0(struct addrinfo, 1);
796         if (!ai)
797                 return -ENOMEM;
798
799         ai->ai_flags = s.ai_flags;
800         ai->ai_family = s.ai_family;
801         ai->ai_socktype = s.ai_socktype;
802         ai->ai_protocol = s.ai_protocol;
803         ai->ai_addrlen = s.ai_addrlen;
804
805         if (s.ai_addrlen > 0) {
806                 ai->ai_addr = memdup((const uint8_t*) *p + sizeof(AddrInfoSerialization), s.ai_addrlen);
807                 if (!ai->ai_addr) {
808                         free(ai);
809                         return -ENOMEM;
810                 }
811         }
812
813         if (s.canonname_len > 0) {
814                 ai->ai_canonname = memdup((const uint8_t*) *p + sizeof(AddrInfoSerialization) + s.ai_addrlen, s.canonname_len);
815                 if (!ai->ai_canonname) {
816                         free(ai->ai_addr);
817                         free(ai);
818                         return -ENOMEM;
819                 }
820         }
821
822         *length -= l;
823         *ret_ai = ai;
824         *p = ((const uint8_t*) *p) + l;
825
826         return 0;
827 }
828
829 static int handle_response(sd_resolve *resolve, const Packet *packet, size_t length) {
830         const RHeader *resp;
831         sd_resolve_query *q;
832         int r;
833
834         assert(resolve);
835
836         resp = &packet->rheader;
837         assert(resp);
838         assert(length >= sizeof(RHeader));
839         assert(length == resp->length);
840
841         if (resp->type == RESPONSE_DIED) {
842                 resolve->dead = true;
843                 return 0;
844         }
845
846         assert(resolve->n_outstanding > 0);
847         resolve->n_outstanding--;
848
849         q = lookup_query(resolve, resp->id);
850         if (!q)
851                 return 0;
852
853         switch (resp->type) {
854
855         case RESPONSE_ADDRINFO: {
856                 const AddrInfoResponse *ai_resp = &packet->addrinfo_response;
857                 const void *p;
858                 size_t l;
859                 struct addrinfo *prev = NULL;
860
861                 assert(length >= sizeof(AddrInfoResponse));
862                 assert(q->type == REQUEST_ADDRINFO);
863
864                 q->ret = ai_resp->ret;
865                 q->_errno = ai_resp->_errno;
866                 q->_h_errno = ai_resp->_h_errno;
867
868                 l = length - sizeof(AddrInfoResponse);
869                 p = (const uint8_t*) resp + sizeof(AddrInfoResponse);
870
871                 while (l > 0 && p) {
872                         struct addrinfo *ai = NULL;
873
874                         r = unserialize_addrinfo(&p, &l, &ai);
875                         if (r < 0) {
876                                 q->ret = EAI_SYSTEM;
877                                 q->_errno = -r;
878                                 q->_h_errno = 0;
879                                 freeaddrinfo(q->addrinfo);
880                                 q->addrinfo = NULL;
881                                 break;
882                         }
883
884                         if (prev)
885                                 prev->ai_next = ai;
886                         else
887                                 q->addrinfo = ai;
888
889                         prev = ai;
890                 }
891
892                 return complete_query(resolve, q);
893         }
894
895         case RESPONSE_NAMEINFO: {
896                 const NameInfoResponse *ni_resp = &packet->nameinfo_response;
897
898                 assert(length >= sizeof(NameInfoResponse));
899                 assert(q->type == REQUEST_NAMEINFO);
900
901                 q->ret = ni_resp->ret;
902                 q->_errno = ni_resp->_errno;
903                 q->_h_errno = ni_resp->_h_errno;
904
905                 if (ni_resp->hostlen > 0) {
906                         q->host = strndup((const char*) ni_resp + sizeof(NameInfoResponse), ni_resp->hostlen-1);
907                         if (!q->host) {
908                                 q->ret = EAI_MEMORY;
909                                 q->_errno = ENOMEM;
910                                 q->_h_errno = 0;
911                         }
912                 }
913
914                 if (ni_resp->servlen > 0) {
915                         q->serv = strndup((const char*) ni_resp + sizeof(NameInfoResponse) + ni_resp->hostlen, ni_resp->servlen-1);
916                         if (!q->serv) {
917                                 q->ret = EAI_MEMORY;
918                                 q->_errno = ENOMEM;
919                                 q->_h_errno = 0;
920                         }
921                 }
922
923                 return complete_query(resolve, q);
924         }
925
926         case RESPONSE_RES: {
927                 const ResResponse *res_resp = &packet->res_response;
928
929                 assert(length >= sizeof(ResResponse));
930                 assert(q->type == REQUEST_RES_QUERY || q->type == REQUEST_RES_SEARCH);
931
932                 q->ret = res_resp->ret;
933                 q->_errno = res_resp->_errno;
934                 q->_h_errno = res_resp->_h_errno;
935
936                 if (res_resp->ret >= 0)  {
937                         q->answer = memdup((const char *)resp + sizeof(ResResponse), res_resp->ret);
938                         if (!q->answer) {
939                                 q->ret = -1;
940                                 q->_errno = ENOMEM;
941                                 q->_h_errno = 0;
942                         }
943                 }
944
945                 return complete_query(resolve, q);
946         }
947
948         default:
949                 return 0;
950         }
951 }
952
953 _public_ int sd_resolve_process(sd_resolve *resolve) {
954         RESOLVE_DONT_DESTROY(resolve);
955
956         union {
957                 Packet packet;
958                 uint8_t space[BUFSIZE];
959         } buf;
960         ssize_t l;
961         int r;
962
963         assert_return(resolve, -EINVAL);
964         assert_return(!resolve_pid_changed(resolve), -ECHILD);
965
966         /* We don't allow recursively invoking sd_resolve_process(). */
967         assert_return(!resolve->current, -EBUSY);
968
969         l = recv(resolve->fds[RESPONSE_RECV_FD], &buf, sizeof(buf), 0);
970         if (l < 0) {
971                 if (errno == EAGAIN)
972                         return 0;
973
974                 return -errno;
975         }
976         if (l == 0)
977                 return -ECONNREFUSED;
978
979         r = handle_response(resolve, &buf.packet, (size_t) l);
980         if (r < 0)
981                 return r;
982
983         return 1;
984 }
985
986 _public_ int sd_resolve_wait(sd_resolve *resolve, uint64_t timeout_usec) {
987         int r;
988
989         assert_return(resolve, -EINVAL);
990         assert_return(!resolve_pid_changed(resolve), -ECHILD);
991
992         if (resolve->n_done >= resolve->n_queries)
993                 return 0;
994
995         do {
996                 r = fd_wait_for_event(resolve->fds[RESPONSE_RECV_FD], POLLIN, timeout_usec);
997         } while (r == -EINTR);
998
999         if (r < 0)
1000                 return r;
1001
1002         return sd_resolve_process(resolve);
1003 }
1004
1005 static int alloc_query(sd_resolve *resolve, bool floating, sd_resolve_query **_q) {
1006         sd_resolve_query *q;
1007         int r;
1008
1009         assert(resolve);
1010         assert(_q);
1011
1012         if (resolve->n_queries >= QUERIES_MAX)
1013                 return -ENOBUFS;
1014
1015         r = start_threads(resolve, 1);
1016         if (r < 0)
1017                 return r;
1018
1019         while (resolve->query_array[resolve->current_id % QUERIES_MAX])
1020                 resolve->current_id++;
1021
1022         q = resolve->query_array[resolve->current_id % QUERIES_MAX] = new0(sd_resolve_query, 1);
1023         if (!q)
1024                 return -ENOMEM;
1025
1026         q->n_ref = 1;
1027         q->resolve = resolve;
1028         q->floating = floating;
1029         q->id = resolve->current_id++;
1030
1031         if (!floating)
1032                 sd_resolve_ref(resolve);
1033
1034         LIST_PREPEND(queries, resolve->queries, q);
1035         resolve->n_queries++;
1036
1037         *_q = q;
1038         return 0;
1039 }
1040
1041 _public_ int sd_resolve_getaddrinfo(
1042                 sd_resolve *resolve,
1043                 sd_resolve_query **_q,
1044                 const char *node, const char *service,
1045                 const struct addrinfo *hints,
1046                 sd_resolve_getaddrinfo_handler_t callback, void *userdata) {
1047
1048         AddrInfoRequest req = {};
1049         struct msghdr mh = {};
1050         struct iovec iov[3];
1051         sd_resolve_query *q;
1052         int r;
1053
1054         assert_return(resolve, -EINVAL);
1055         assert_return(node || service, -EINVAL);
1056         assert_return(callback, -EINVAL);
1057         assert_return(!resolve_pid_changed(resolve), -ECHILD);
1058
1059         r = alloc_query(resolve, !_q, &q);
1060         if (r < 0)
1061                 return r;
1062
1063         q->type = REQUEST_ADDRINFO;
1064         q->getaddrinfo_handler = callback;
1065         q->userdata = userdata;
1066
1067         req.node_len = node ? strlen(node)+1 : 0;
1068         req.service_len = service ? strlen(service)+1 : 0;
1069
1070         req.header.id = q->id;
1071         req.header.type = REQUEST_ADDRINFO;
1072         req.header.length = sizeof(AddrInfoRequest) + req.node_len + req.service_len;
1073
1074         if (hints) {
1075                 req.hints_valid = true;
1076                 req.ai_flags = hints->ai_flags;
1077                 req.ai_family = hints->ai_family;
1078                 req.ai_socktype = hints->ai_socktype;
1079                 req.ai_protocol = hints->ai_protocol;
1080         }
1081
1082         iov[mh.msg_iovlen++] = (struct iovec) { .iov_base = &req, .iov_len = sizeof(AddrInfoRequest) };
1083         if (node)
1084                 iov[mh.msg_iovlen++] = (struct iovec) { .iov_base = (void*) node, .iov_len = req.node_len };
1085         if (service)
1086                 iov[mh.msg_iovlen++] = (struct iovec) { .iov_base = (void*) service, .iov_len = req.service_len };
1087         mh.msg_iov = iov;
1088
1089         if (sendmsg(resolve->fds[REQUEST_SEND_FD], &mh, MSG_NOSIGNAL) < 0) {
1090                 sd_resolve_query_unref(q);
1091                 return -errno;
1092         }
1093
1094         resolve->n_outstanding++;
1095
1096         if (_q)
1097                 *_q = q;
1098
1099         return 0;
1100 }
1101
1102 static int getaddrinfo_done(sd_resolve_query* q) {
1103         assert(q);
1104         assert(q->done);
1105         assert(q->getaddrinfo_handler);
1106
1107         errno = q->_errno;
1108         h_errno = q->_h_errno;
1109
1110         return q->getaddrinfo_handler(q, q->ret, q->addrinfo, q->userdata);
1111 }
1112
1113 _public_ int sd_resolve_getnameinfo(
1114                 sd_resolve *resolve,
1115                 sd_resolve_query**_q,
1116                 const struct sockaddr *sa, socklen_t salen,
1117                 int flags,
1118                 uint64_t get,
1119                 sd_resolve_getnameinfo_handler_t callback,
1120                 void *userdata) {
1121
1122         NameInfoRequest req = {};
1123         struct msghdr mh = {};
1124         struct iovec iov[2];
1125         sd_resolve_query *q;
1126         int r;
1127
1128         assert_return(resolve, -EINVAL);
1129         assert_return(sa, -EINVAL);
1130         assert_return(salen >= sizeof(struct sockaddr), -EINVAL);
1131         assert_return(salen <= sizeof(union sockaddr_union), -EINVAL);
1132         assert_return((get & ~SD_RESOLVE_GET_BOTH) == 0, -EINVAL);
1133         assert_return(callback, -EINVAL);
1134         assert_return(!resolve_pid_changed(resolve), -ECHILD);
1135
1136         r = alloc_query(resolve, !_q, &q);
1137         if (r < 0)
1138                 return r;
1139
1140         q->type = REQUEST_NAMEINFO;
1141         q->getnameinfo_handler = callback;
1142         q->userdata = userdata;
1143
1144         req.header.id = q->id;
1145         req.header.type = REQUEST_NAMEINFO;
1146         req.header.length = sizeof(NameInfoRequest) + salen;
1147
1148         req.flags = flags;
1149         req.sockaddr_len = salen;
1150         req.gethost = !!(get & SD_RESOLVE_GET_HOST);
1151         req.getserv = !!(get & SD_RESOLVE_GET_SERVICE);
1152
1153         iov[0] = (struct iovec) { .iov_base = &req, .iov_len = sizeof(NameInfoRequest) };
1154         iov[1] = (struct iovec) { .iov_base = (void*) sa, .iov_len = salen };
1155
1156         mh.msg_iov = iov;
1157         mh.msg_iovlen = 2;
1158
1159         if (sendmsg(resolve->fds[REQUEST_SEND_FD], &mh, MSG_NOSIGNAL) < 0) {
1160                 sd_resolve_query_unref(q);
1161                 return -errno;
1162         }
1163
1164         resolve->n_outstanding++;
1165
1166         if (_q)
1167                 *_q = q;
1168
1169         return 0;
1170 }
1171
1172 static int getnameinfo_done(sd_resolve_query *q) {
1173
1174         assert(q);
1175         assert(q->done);
1176         assert(q->getnameinfo_handler);
1177
1178         errno = q->_errno;
1179         h_errno= q->_h_errno;
1180
1181         return q->getnameinfo_handler(q, q->ret, q->host, q->serv, q->userdata);
1182 }
1183
1184 static int resolve_res(
1185                 sd_resolve *resolve,
1186                 sd_resolve_query **_q,
1187                 QueryType qtype,
1188                 const char *dname,
1189                 int class, int type,
1190                 sd_resolve_res_handler_t callback, void *userdata) {
1191
1192         struct msghdr mh = {};
1193         struct iovec iov[2];
1194         ResRequest req = {};
1195         sd_resolve_query *q;
1196         int r;
1197
1198         assert_return(resolve, -EINVAL);
1199         assert_return(dname, -EINVAL);
1200         assert_return(callback, -EINVAL);
1201         assert_return(!resolve_pid_changed(resolve), -ECHILD);
1202
1203         r = alloc_query(resolve, !_q, &q);
1204         if (r < 0)
1205                 return r;
1206
1207         q->type = qtype;
1208         q->res_handler = callback;
1209         q->userdata = userdata;
1210
1211         req.dname_len = strlen(dname) + 1;
1212         req.class = class;
1213         req.type = type;
1214
1215         req.header.id = q->id;
1216         req.header.type = qtype;
1217         req.header.length = sizeof(ResRequest) + req.dname_len;
1218
1219         iov[0] = (struct iovec) { .iov_base = &req, .iov_len = sizeof(ResRequest) };
1220         iov[1] = (struct iovec) { .iov_base = (void*) dname, .iov_len = req.dname_len };
1221
1222         mh.msg_iov = iov;
1223         mh.msg_iovlen = 2;
1224
1225         if (sendmsg(resolve->fds[REQUEST_SEND_FD], &mh, MSG_NOSIGNAL) < 0) {
1226                 sd_resolve_query_unref(q);
1227                 return -errno;
1228         }
1229
1230         resolve->n_outstanding++;
1231
1232         if (_q)
1233                 *_q = q;
1234
1235         return 0;
1236 }
1237
1238 _public_ int sd_resolve_res_query(sd_resolve *resolve, sd_resolve_query** q, const char *dname, int class, int type, sd_resolve_res_handler_t callback, void *userdata) {
1239         return resolve_res(resolve, q, REQUEST_RES_QUERY, dname, class, type, callback, userdata);
1240 }
1241
1242 _public_ int sd_resolve_res_search(sd_resolve *resolve, sd_resolve_query** q, const char *dname, int class, int type, sd_resolve_res_handler_t callback, void *userdata) {
1243         return resolve_res(resolve, q, REQUEST_RES_SEARCH, dname, class, type, callback, userdata);
1244 }
1245
1246 static int res_query_done(sd_resolve_query* q) {
1247         assert(q);
1248         assert(q->done);
1249         assert(q->res_handler);
1250
1251         errno = q->_errno;
1252         h_errno = q->_h_errno;
1253
1254         return q->res_handler(q, q->ret, q->answer, q->userdata);
1255 }
1256
1257 _public_ sd_resolve_query* sd_resolve_query_ref(sd_resolve_query *q) {
1258         assert_return(q, NULL);
1259
1260         assert(q->n_ref >= 1);
1261         q->n_ref++;
1262
1263         return q;
1264 }
1265
1266 static void resolve_freeaddrinfo(struct addrinfo *ai) {
1267         while (ai) {
1268                 struct addrinfo *next = ai->ai_next;
1269
1270                 free(ai->ai_addr);
1271                 free(ai->ai_canonname);
1272                 free(ai);
1273                 ai = next;
1274         }
1275 }
1276
1277 static void resolve_query_disconnect(sd_resolve_query *q) {
1278         sd_resolve *resolve;
1279         unsigned i;
1280
1281         assert(q);
1282
1283         if (!q->resolve)
1284                 return;
1285
1286         resolve = q->resolve;
1287         assert(resolve->n_queries > 0);
1288
1289         if (q->done) {
1290                 assert(resolve->n_done > 0);
1291                 resolve->n_done--;
1292         }
1293
1294         i = q->id % QUERIES_MAX;
1295         assert(resolve->query_array[i] == q);
1296         resolve->query_array[i] = NULL;
1297         LIST_REMOVE(queries, resolve->queries, q);
1298         resolve->n_queries--;
1299
1300         q->resolve = NULL;
1301         if (!q->floating)
1302                 sd_resolve_unref(resolve);
1303 }
1304
1305 static void resolve_query_free(sd_resolve_query *q) {
1306         assert(q);
1307
1308         resolve_query_disconnect(q);
1309
1310         resolve_freeaddrinfo(q->addrinfo);
1311         free(q->host);
1312         free(q->serv);
1313         free(q->answer);
1314         free(q);
1315 }
1316
1317 _public_ sd_resolve_query* sd_resolve_query_unref(sd_resolve_query* q) {
1318         if (!q)
1319                 return NULL;
1320
1321         assert(q->n_ref >= 1);
1322         q->n_ref--;
1323
1324         if (q->n_ref <= 0)
1325                 resolve_query_free(q);
1326
1327         return NULL;
1328 }
1329
1330 _public_ int sd_resolve_query_is_done(sd_resolve_query *q) {
1331         assert_return(q, -EINVAL);
1332         assert_return(!resolve_pid_changed(q->resolve), -ECHILD);
1333
1334         return q->done;
1335 }
1336
1337 _public_ void* sd_resolve_query_set_userdata(sd_resolve_query *q, void *userdata) {
1338         void *ret;
1339
1340         assert_return(q, NULL);
1341         assert_return(!resolve_pid_changed(q->resolve), NULL);
1342
1343         ret = q->userdata;
1344         q->userdata = userdata;
1345
1346         return ret;
1347 }
1348
1349 _public_ void* sd_resolve_query_get_userdata(sd_resolve_query *q) {
1350         assert_return(q, NULL);
1351         assert_return(!resolve_pid_changed(q->resolve), NULL);
1352
1353         return q->userdata;
1354 }
1355
1356 _public_ sd_resolve *sd_resolve_query_get_resolve(sd_resolve_query *q) {
1357         assert_return(q, NULL);
1358         assert_return(!resolve_pid_changed(q->resolve), NULL);
1359
1360         return q->resolve;
1361 }
1362
1363 static int io_callback(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
1364         sd_resolve *resolve = userdata;
1365         int r;
1366
1367         assert(resolve);
1368
1369         r = sd_resolve_process(resolve);
1370         if (r < 0)
1371                 return r;
1372
1373         return 1;
1374 }
1375
1376 _public_ int sd_resolve_attach_event(sd_resolve *resolve, sd_event *event, int priority) {
1377         int r;
1378
1379         assert_return(resolve, -EINVAL);
1380         assert_return(!resolve->event, -EBUSY);
1381
1382         assert(!resolve->event_source);
1383
1384         if (event)
1385                 resolve->event = sd_event_ref(event);
1386         else {
1387                 r = sd_event_default(&resolve->event);
1388                 if (r < 0)
1389                         return r;
1390         }
1391
1392         r = sd_event_add_io(resolve->event, &resolve->event_source, resolve->fds[RESPONSE_RECV_FD], POLLIN, io_callback, resolve);
1393         if (r < 0)
1394                 goto fail;
1395
1396         r = sd_event_source_set_priority(resolve->event_source, priority);
1397         if (r < 0)
1398                 goto fail;
1399
1400         return 0;
1401
1402 fail:
1403         sd_resolve_detach_event(resolve);
1404         return r;
1405 }
1406
1407 _public_  int sd_resolve_detach_event(sd_resolve *resolve) {
1408         assert_return(resolve, -EINVAL);
1409
1410         if (!resolve->event)
1411                 return 0;
1412
1413         if (resolve->event_source) {
1414                 sd_event_source_set_enabled(resolve->event_source, SD_EVENT_OFF);
1415                 resolve->event_source = sd_event_source_unref(resolve->event_source);
1416         }
1417
1418         resolve->event = sd_event_unref(resolve->event);
1419         return 1;
1420 }
1421
1422 _public_ sd_event *sd_resolve_get_event(sd_resolve *resolve) {
1423         assert_return(resolve, NULL);
1424
1425         return resolve->event;
1426 }