chiark / gitweb /
sd-resolve: scale number of threads by queries currently being processed, rather...
[elogind.git] / src / libsystemd / sd-resolve / sd-resolve.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2005-2008 Lennart Poettering
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20 ***/
21
22 #include <assert.h>
23 #include <fcntl.h>
24 #include <signal.h>
25 #include <unistd.h>
26 #include <sys/select.h>
27 #include <stdio.h>
28 #include <string.h>
29 #include <stdlib.h>
30 #include <errno.h>
31 #include <sys/wait.h>
32 #include <sys/types.h>
33 #include <pwd.h>
34 #include <netinet/in.h>
35 #include <arpa/nameser.h>
36 #include <resolv.h>
37 #include <dirent.h>
38 #include <sys/time.h>
39 #include <sys/resource.h>
40 #include <stdint.h>
41 #include <pthread.h>
42 #include <sys/prctl.h>
43 #include <sys/poll.h>
44
45 #include "util.h"
46 #include "list.h"
47 #include "socket-util.h"
48 #include "missing.h"
49 #include "resolve-util.h"
50 #include "sd-resolve.h"
51
/* Bounds on the number of worker threads start_threads() keeps running. */
#define WORKERS_MIN 1U
#define WORKERS_MAX 16U

/* Number of slots in the per-resolver query table; alloc_query() refuses
 * new queries once that many are allocated. */
#define QUERIES_MAX 256U

/* Size in bytes of the buffers used to build and to receive packets. */
#define BUFSIZE 10240U
56
/* Packet type stored in every RHeader. REQUEST_* packets travel from the
 * resolver to the worker threads, RESPONSE_* packets travel back. The
 * values are spelled out; they match the implicit numbering used before. */
typedef enum {
        REQUEST_ADDRINFO   = 0,
        RESPONSE_ADDRINFO  = 1,
        REQUEST_NAMEINFO   = 2,
        RESPONSE_NAMEINFO  = 3,
        REQUEST_RES_QUERY  = 4,
        REQUEST_RES_SEARCH = 5,
        RESPONSE_RES       = 6,
        REQUEST_TERMINATE  = 7, /* asks a worker to leave its loop */
        RESPONSE_DIED      = 8  /* sent by a worker right before it exits */
} QueryType;
68
/* Indices into sd_resolve.fds[]. Each RECV/SEND pair must stay adjacent and
 * in this order, because sd_resolve_new() hands the address of the RECV
 * entry to socketpair(), which fills in two consecutive descriptors. */
enum {
        REQUEST_RECV_FD  = 0,
        REQUEST_SEND_FD  = 1,
        RESPONSE_RECV_FD = 2,
        RESPONSE_SEND_FD = 3,
        _FD_MAX          = 4
};
76
77 struct sd_resolve {
78         unsigned n_ref;
79
80         bool dead:1;
81         pid_t original_pid;
82
83         int fds[_FD_MAX];
84
85         pthread_t workers[WORKERS_MAX];
86         unsigned n_valid_workers;
87
88         unsigned current_id;
89         sd_resolve_query* query_array[QUERIES_MAX];
90         unsigned n_queries, n_done, n_outstanding;
91
92         sd_event_source *event_source;
93         sd_event *event;
94
95         sd_resolve_query *current;
96
97         sd_resolve **default_resolve_ptr;
98         pid_t tid;
99
100         LIST_HEAD(sd_resolve_query, queries);
101 };
102
/* A single lookup handed to the worker threads, together with the result
 * that came back and the callback that is invoked once it is complete. */
struct sd_resolve_query {
        /* Reference count */
        unsigned n_ref;

        /* Resolver this query was allocated from */
        sd_resolve *resolve;

        /* Request type (one of the REQUEST_* values) */
        QueryType type:4;

        /* Set by complete_query() once the response was processed */
        bool done:1;

        /* Copied from alloc_query()'s 'floating' argument. Non-floating
         * queries take a reference on the resolver; floating ones are
         * disconnected and unreferenced by the resolver itself. */
        bool floating:1;

        /* Identifier; id % QUERIES_MAX is the slot in query_array[] */
        unsigned id;

        /* Status values copied out of the response packet */
        int ret;
        int _errno;
        int _h_errno;

        /* Result payload; which member is used depends on the query type */
        struct addrinfo *addrinfo;
        char *serv;
        char *host;
        unsigned char *answer;

        /* Completion callback matching the query type */
        union {
                sd_resolve_getaddrinfo_handler_t getaddrinfo_handler;
                sd_resolve_getnameinfo_handler_t getnameinfo_handler;
                sd_resolve_res_handler_t res_handler;
        };

        /* Opaque pointer stored on behalf of the caller */
        void *userdata;

        /* Linkage in sd_resolve.queries */
        LIST_FIELDS(sd_resolve_query, queries);
};
130
131 typedef struct RHeader {
132         QueryType type;
133         unsigned id;
134         size_t length;
135 } RHeader;
136
137 typedef struct AddrInfoRequest {
138         struct RHeader header;
139         bool hints_valid;
140         int ai_flags;
141         int ai_family;
142         int ai_socktype;
143         int ai_protocol;
144         size_t node_len, service_len;
145 } AddrInfoRequest;
146
147 typedef struct AddrInfoResponse {
148         struct RHeader header;
149         int ret;
150         int _errno;
151         int _h_errno;
152         /* followed by addrinfo_serialization[] */
153 } AddrInfoResponse;
154
/* Fixed-size part of one serialized struct addrinfo list entry. */
typedef struct AddrInfoSerialization {
        int ai_flags;
        int ai_family;
        int ai_socktype;
        int ai_protocol;
        size_t ai_addrlen;    /* size of the socket address that follows */
        size_t canonname_len; /* size of the canonical name incl. NUL, 0 if none */
        /* Followed by ai_addr and ai_canonname with variable lengths */
} AddrInfoSerialization;
164
165 typedef struct NameInfoRequest {
166         struct RHeader header;
167         int flags;
168         socklen_t sockaddr_len;
169         bool gethost:1, getserv:1;
170 } NameInfoRequest;
171
172 typedef struct NameInfoResponse {
173         struct RHeader header;
174         size_t hostlen, servlen;
175         int ret;
176         int _errno;
177         int _h_errno;
178 } NameInfoResponse;
179
180 typedef struct ResRequest {
181         struct RHeader header;
182         int class;
183         int type;
184         size_t dname_len;
185 } ResRequest;
186
187 typedef struct ResResponse {
188         struct RHeader header;
189         int ret;
190         int _errno;
191         int _h_errno;
192 } ResResponse;
193
194 typedef union Packet {
195         RHeader rheader;
196         AddrInfoRequest addrinfo_request;
197         AddrInfoResponse addrinfo_response;
198         NameInfoRequest nameinfo_request;
199         NameInfoResponse nameinfo_response;
200         ResRequest res_request;
201         ResResponse res_response;
202 } Packet;
203
/* Per-type completion routines, called from complete_query(); they are
 * defined further down in the file. */
static int getaddrinfo_done(sd_resolve_query *q);
static int getnameinfo_done(sd_resolve_query *q);
static int res_query_done(sd_resolve_query *q);

/* Detaches a query from its resolver; also defined further down. */
static void resolve_query_disconnect(sd_resolve_query *q);
209
/* Takes a reference on 'resolve' and stores it in a local variable tagged
 * with the _cleanup_resolve_unref_ attribute, so the resolver stays
 * referenced until the enclosing scope is left. */
#define RESOLVE_DONT_DESTROY(resolve)                                   \
        _cleanup_resolve_unref_ _unused_ sd_resolve *_dont_destroy_##resolve = sd_resolve_ref(resolve)
212
213 static int send_died(int out_fd) {
214
215         RHeader rh = {
216                 .type = RESPONSE_DIED,
217                 .length = sizeof(RHeader),
218         };
219
220         assert(out_fd >= 0);
221
222         if (send(out_fd, &rh, rh.length, MSG_NOSIGNAL) < 0)
223                 return -errno;
224
225         return 0;
226 }
227
228 static void *serialize_addrinfo(void *p, const struct addrinfo *ai, size_t *length, size_t maxlength) {
229         AddrInfoSerialization s;
230         size_t cnl, l;
231
232         assert(p);
233         assert(ai);
234         assert(length);
235         assert(*length <= maxlength);
236
237         cnl = ai->ai_canonname ? strlen(ai->ai_canonname)+1 : 0;
238         l = sizeof(AddrInfoSerialization) + ai->ai_addrlen + cnl;
239
240         if (*length + l > maxlength)
241                 return NULL;
242
243         s.ai_flags = ai->ai_flags;
244         s.ai_family = ai->ai_family;
245         s.ai_socktype = ai->ai_socktype;
246         s.ai_protocol = ai->ai_protocol;
247         s.ai_addrlen = ai->ai_addrlen;
248         s.canonname_len = cnl;
249
250         memcpy((uint8_t*) p, &s, sizeof(AddrInfoSerialization));
251         memcpy((uint8_t*) p + sizeof(AddrInfoSerialization), ai->ai_addr, ai->ai_addrlen);
252
253         if (ai->ai_canonname)
254                 memcpy((char*) p + sizeof(AddrInfoSerialization) + ai->ai_addrlen, ai->ai_canonname, cnl);
255
256         *length += l;
257         return (uint8_t*) p + l;
258 }
259
260 static int send_addrinfo_reply(
261                 int out_fd,
262                 unsigned id,
263                 int ret,
264                 struct addrinfo *ai,
265                 int _errno,
266                 int _h_errno) {
267
268         AddrInfoResponse resp = {
269                 .header.type = RESPONSE_ADDRINFO,
270                 .header.id = id,
271                 .header.length = sizeof(AddrInfoResponse),
272                 .ret = ret,
273                 ._errno = _errno,
274                 ._h_errno = _h_errno,
275         };
276
277         struct msghdr mh = {};
278         struct iovec iov[2];
279         union {
280                 AddrInfoSerialization ais;
281                 uint8_t space[BUFSIZE];
282         } buffer;
283
284         assert(out_fd >= 0);
285
286         if (ret == 0 && ai) {
287                 void *p = &buffer;
288                 struct addrinfo *k;
289
290                 for (k = ai; k; k = k->ai_next) {
291                         p = serialize_addrinfo(p, k, &resp.header.length, (uint8_t*) &buffer + BUFSIZE - (uint8_t*) p);
292                         if (!p) {
293                                 freeaddrinfo(ai);
294                                 return -ENOBUFS;
295                         }
296                 }
297         }
298
299         if (ai)
300                 freeaddrinfo(ai);
301
302         iov[0] = (struct iovec) { .iov_base = &resp, .iov_len = sizeof(AddrInfoResponse) };
303         iov[1] = (struct iovec) { .iov_base = &buffer, .iov_len = resp.header.length - sizeof(AddrInfoResponse) };
304
305         mh.msg_iov = iov;
306         mh.msg_iovlen = ELEMENTSOF(iov);
307
308         if (sendmsg(out_fd, &mh, MSG_NOSIGNAL) < 0)
309                 return -errno;
310
311         return 0;
312 }
313
314 static int send_nameinfo_reply(
315                 int out_fd,
316                 unsigned id,
317                 int ret,
318                 const char *host,
319                 const char *serv,
320                 int _errno,
321                 int _h_errno) {
322
323         NameInfoResponse resp = {
324                 .header.type = RESPONSE_NAMEINFO,
325                 .header.id = id,
326                 .ret = ret,
327                 ._errno = _errno,
328                 ._h_errno = _h_errno,
329         };
330
331         struct msghdr mh = {};
332         struct iovec iov[3];
333         size_t hl, sl;
334
335         assert(out_fd >= 0);
336
337         sl = serv ? strlen(serv)+1 : 0;
338         hl = host ? strlen(host)+1 : 0;
339
340         resp.header.length = sizeof(NameInfoResponse) + hl + sl;
341         resp.hostlen = hl;
342         resp.servlen = sl;
343
344         iov[0] = (struct iovec) { .iov_base = &resp, .iov_len = sizeof(NameInfoResponse) };
345         iov[1] = (struct iovec) { .iov_base = (void*) host, .iov_len = hl };
346         iov[2] = (struct iovec) { .iov_base = (void*) serv, .iov_len = sl };
347
348         mh.msg_iov = iov;
349         mh.msg_iovlen = ELEMENTSOF(iov);
350
351         if (sendmsg(out_fd, &mh, MSG_NOSIGNAL) < 0)
352                 return -errno;
353
354         return 0;
355 }
356
357 static int send_res_reply(int out_fd, unsigned id, const unsigned char *answer, int ret, int _errno, int _h_errno) {
358
359         ResResponse resp = {
360                 .header.type = RESPONSE_RES,
361                 .header.id = id,
362                 .ret = ret,
363                 ._errno = _errno,
364                 ._h_errno = _h_errno,
365         };
366
367         struct msghdr mh = {};
368         struct iovec iov[2];
369         size_t l;
370
371         assert(out_fd >= 0);
372
373         l = ret > 0 ? (size_t) ret : 0;
374
375         resp.header.length = sizeof(ResResponse) + l;
376
377         iov[0] = (struct iovec) { .iov_base = &resp, .iov_len = sizeof(ResResponse) };
378         iov[1] = (struct iovec) { .iov_base = (void*) answer, .iov_len = l };
379
380         mh.msg_iov = iov;
381         mh.msg_iovlen = ELEMENTSOF(iov);
382
383         if (sendmsg(out_fd, &mh, MSG_NOSIGNAL) < 0)
384                 return -errno;
385
386         return 0;
387 }
388
/* Executes one request packet inside a worker thread and sends the matching
 * response to out_fd.
 *
 * 'length' is the number of bytes received; it must equal the length stored
 * in the packet header. Malformed packets trip assertions.
 *
 * Returns the result of the send_*_reply() helper (0 or a negative errno
 * value), or -ECONNRESET for REQUEST_TERMINATE so that the caller leaves
 * its loop. */
static int handle_request(int out_fd, const Packet *packet, size_t length) {
        const RHeader *hdr;

        assert(out_fd >= 0);
        assert(packet);

        hdr = &packet->rheader;

        assert(length >= sizeof(RHeader));
        assert(length == hdr->length);

        switch (hdr->type) {

        case REQUEST_TERMINATE:
                /* Quit */
                return -ECONNRESET;

        case REQUEST_ADDRINFO: {
                const AddrInfoRequest *rq = &packet->addrinfo_request;
                /* node and service strings are stored back to back after the struct */
                const char *strings = (const char*) rq + sizeof(AddrInfoRequest);
                const char *node = NULL, *service = NULL;
                struct addrinfo *result = NULL;
                struct addrinfo hints = {
                        .ai_flags = rq->ai_flags,
                        .ai_family = rq->ai_family,
                        .ai_socktype = rq->ai_socktype,
                        .ai_protocol = rq->ai_protocol,
                };
                int ret;

                assert(length >= sizeof(AddrInfoRequest));
                assert(length == sizeof(AddrInfoRequest) + rq->node_len + rq->service_len);

                if (rq->node_len)
                        node = strings;
                if (rq->service_len)
                        service = strings + rq->node_len;

                ret = getaddrinfo(node, service, rq->hints_valid ? &hints : NULL, &result);

                /* send_addrinfo_reply() frees result */
                return send_addrinfo_reply(out_fd, hdr->id, ret, result, errno, h_errno);
        }

        case REQUEST_NAMEINFO: {
                const NameInfoRequest *rq = &packet->nameinfo_request;
                char hostbuf[NI_MAXHOST], servbuf[NI_MAXSERV];
                char *host = NULL, *serv = NULL;
                size_t host_size = 0, serv_size = 0;
                union sockaddr_union sa;
                int ret;

                assert(length >= sizeof(NameInfoRequest));
                assert(length == sizeof(NameInfoRequest) + rq->sockaddr_len);
                assert(sizeof(sa) >= rq->sockaddr_len);

                memcpy(&sa, (const uint8_t *) rq + sizeof(NameInfoRequest), rq->sockaddr_len);

                /* Only hand out the buffers the requester asked for */
                if (rq->gethost) {
                        host = hostbuf;
                        host_size = sizeof(hostbuf);
                }
                if (rq->getserv) {
                        serv = servbuf;
                        serv_size = sizeof(servbuf);
                }

                ret = getnameinfo(&sa.sa, rq->sockaddr_len,
                                  host, host_size,
                                  serv, serv_size,
                                  rq->flags);

                /* The buffers only hold valid strings if the call succeeded */
                return send_nameinfo_reply(out_fd, hdr->id, ret,
                                           ret == 0 ? host : NULL,
                                           ret == 0 ? serv : NULL,
                                           errno, h_errno);
        }

        case REQUEST_RES_QUERY:
        case REQUEST_RES_SEARCH: {
                const ResRequest *rq = &packet->res_request;
                union {
                        HEADER header;
                        uint8_t space[BUFSIZE];
                } answer;
                unsigned char *answer_bytes = (unsigned char *) &answer;
                const char *dname;
                int ret;

                assert(length >= sizeof(ResRequest));
                assert(length == sizeof(ResRequest) + rq->dname_len);

                /* The domain name follows the request struct */
                dname = (const char *) hdr + sizeof(ResRequest);

                if (hdr->type == REQUEST_RES_SEARCH)
                        ret = res_search(dname, rq->class, rq->type, answer_bytes, BUFSIZE);
                else
                        ret = res_query(dname, rq->class, rq->type, answer_bytes, BUFSIZE);

                return send_res_reply(out_fd, hdr->id, answer_bytes, ret, errno, h_errno);
        }

        default:
                assert_not_reached("Unknown request");
        }

        return 0;
}
484
485 static void* thread_worker(void *p) {
486         sd_resolve *resolve = p;
487         sigset_t fullset;
488
489         /* No signals in this thread please */
490         assert_se(sigfillset(&fullset) == 0);
491         assert_se(pthread_sigmask(SIG_BLOCK, &fullset, NULL) == 0);
492
493         /* Assign a pretty name to this thread */
494         prctl(PR_SET_NAME, (unsigned long) "sd-resolve");
495
496         while (!resolve->dead) {
497                 union {
498                         Packet packet;
499                         uint8_t space[BUFSIZE];
500                 } buf;
501                 ssize_t length;
502
503                 length = recv(resolve->fds[REQUEST_RECV_FD], &buf, sizeof(buf), 0);
504                 if (length < 0) {
505                         if (errno == EINTR)
506                                 continue;
507
508                         break;
509                 }
510                 if (length == 0)
511                         break;
512
513                 if (resolve->dead)
514                         break;
515
516                 if (handle_request(resolve->fds[RESPONSE_SEND_FD], &buf.packet, (size_t) length) < 0)
517                         break;
518         }
519
520         send_died(resolve->fds[RESPONSE_SEND_FD]);
521
522         return NULL;
523 }
524
525 static int start_threads(sd_resolve *resolve, unsigned extra) {
526         unsigned n;
527         int r;
528
529         n = resolve->n_outstanding + extra;
530         n = CLAMP(n, WORKERS_MIN, WORKERS_MAX);
531
532         while (resolve->n_valid_workers < n) {
533
534                 r = pthread_create(&resolve->workers[resolve->n_valid_workers], NULL, thread_worker, resolve);
535                 if (r != 0)
536                         return -r;
537
538                 resolve->n_valid_workers ++;
539         }
540
541         return 0;
542 }
543
544 static bool resolve_pid_changed(sd_resolve *r) {
545         assert(r);
546
547         /* We don't support people creating a resolver and keeping it
548          * around after fork(). Let's complain. */
549
550         return r->original_pid != getpid();
551 }
552
553 _public_ int sd_resolve_new(sd_resolve **ret) {
554         sd_resolve *resolve = NULL;
555         int i, r;
556
557         assert_return(ret, -EINVAL);
558
559         resolve = new0(sd_resolve, 1);
560         if (!resolve)
561                 return -ENOMEM;
562
563         resolve->n_ref = 1;
564         resolve->original_pid = getpid();
565
566         for (i = 0; i < _FD_MAX; i++)
567                 resolve->fds[i] = -1;
568
569         r = socketpair(PF_UNIX, SOCK_DGRAM|SOCK_CLOEXEC, 0, resolve->fds + REQUEST_RECV_FD);
570         if (r < 0) {
571                 r = -errno;
572                 goto fail;
573         }
574
575         r = socketpair(PF_UNIX, SOCK_DGRAM|SOCK_CLOEXEC, 0, resolve->fds + RESPONSE_RECV_FD);
576         if (r < 0) {
577                 r = -errno;
578                 goto fail;
579         }
580
581         fd_inc_sndbuf(resolve->fds[REQUEST_SEND_FD], QUERIES_MAX * BUFSIZE);
582         fd_inc_rcvbuf(resolve->fds[REQUEST_RECV_FD], QUERIES_MAX * BUFSIZE);
583         fd_inc_sndbuf(resolve->fds[RESPONSE_SEND_FD], QUERIES_MAX * BUFSIZE);
584         fd_inc_rcvbuf(resolve->fds[RESPONSE_RECV_FD], QUERIES_MAX * BUFSIZE);
585
586         fd_nonblock(resolve->fds[RESPONSE_RECV_FD], true);
587
588         *ret = resolve;
589         return 0;
590
591 fail:
592         sd_resolve_unref(resolve);
593         return r;
594 }
595
596 _public_ int sd_resolve_default(sd_resolve **ret) {
597
598         static thread_local sd_resolve *default_resolve = NULL;
599         sd_resolve *e = NULL;
600         int r;
601
602         if (!ret)
603                 return !!default_resolve;
604
605         if (default_resolve) {
606                 *ret = sd_resolve_ref(default_resolve);
607                 return 0;
608         }
609
610         r = sd_resolve_new(&e);
611         if (r < 0)
612                 return r;
613
614         e->default_resolve_ptr = &default_resolve;
615         e->tid = gettid();
616         default_resolve = e;
617
618         *ret = e;
619         return 1;
620 }
621
622 _public_ int sd_resolve_get_tid(sd_resolve *resolve, pid_t *tid) {
623         assert_return(resolve, -EINVAL);
624         assert_return(tid, -EINVAL);
625         assert_return(!resolve_pid_changed(resolve), -ECHILD);
626
627         if (resolve->tid != 0) {
628                 *tid = resolve->tid;
629                 return 0;
630         }
631
632         if (resolve->event)
633                 return sd_event_get_tid(resolve->event, tid);
634
635         return -ENXIO;
636 }
637
638 static void resolve_free(sd_resolve *resolve) {
639         PROTECT_ERRNO;
640         sd_resolve_query *q;
641         unsigned i;
642
643         assert(resolve);
644
645         while ((q = resolve->queries)) {
646                 assert(q->floating);
647                 resolve_query_disconnect(q);
648                 sd_resolve_query_unref(q);
649         }
650
651         if (resolve->default_resolve_ptr)
652                 *(resolve->default_resolve_ptr) = NULL;
653
654         resolve->dead = true;
655
656         sd_resolve_detach_event(resolve);
657
658         if (resolve->fds[REQUEST_SEND_FD] >= 0) {
659
660                 RHeader req = {
661                         .type = REQUEST_TERMINATE,
662                         .length = sizeof(req)
663                 };
664
665                 /* Send one termination packet for each worker */
666                 for (i = 0; i < resolve->n_valid_workers; i++)
667                         send(resolve->fds[REQUEST_SEND_FD], &req, req.length, MSG_NOSIGNAL);
668         }
669
670         /* Now terminate them and wait until they are gone. */
671         for (i = 0; i < resolve->n_valid_workers; i++) {
672                 for (;;) {
673                         if (pthread_join(resolve->workers[i], NULL) != EINTR)
674                                 break;
675                 }
676         }
677
678         /* Close all communication channels */
679         for (i = 0; i < _FD_MAX; i++)
680                 safe_close(resolve->fds[i]);
681
682         free(resolve);
683 }
684
685 _public_ sd_resolve* sd_resolve_ref(sd_resolve *resolve) {
686         assert_return(resolve, NULL);
687
688         assert(resolve->n_ref >= 1);
689         resolve->n_ref++;
690
691         return resolve;
692 }
693
694 _public_ sd_resolve* sd_resolve_unref(sd_resolve *resolve) {
695
696         if (!resolve)
697                 return NULL;
698
699         assert(resolve->n_ref >= 1);
700         resolve->n_ref--;
701
702         if (resolve->n_ref <= 0)
703                 resolve_free(resolve);
704
705         return NULL;
706 }
707
708 _public_ int sd_resolve_get_fd(sd_resolve *resolve) {
709         assert_return(resolve, -EINVAL);
710         assert_return(!resolve_pid_changed(resolve), -ECHILD);
711
712         return resolve->fds[RESPONSE_RECV_FD];
713 }
714
715 _public_ int sd_resolve_get_events(sd_resolve *resolve) {
716         assert_return(resolve, -EINVAL);
717         assert_return(!resolve_pid_changed(resolve), -ECHILD);
718
719         return resolve->n_queries > resolve->n_done ? POLLIN : 0;
720 }
721
722 _public_ int sd_resolve_get_timeout(sd_resolve *resolve, uint64_t *usec) {
723         assert_return(resolve, -EINVAL);
724         assert_return(usec, -EINVAL);
725         assert_return(!resolve_pid_changed(resolve), -ECHILD);
726
727         *usec = (uint64_t) -1;
728         return 0;
729 }
730
731 static sd_resolve_query *lookup_query(sd_resolve *resolve, unsigned id) {
732         sd_resolve_query *q;
733
734         assert(resolve);
735
736         q = resolve->query_array[id % QUERIES_MAX];
737         if (q)
738                 if (q->id == id)
739                         return q;
740
741         return NULL;
742 }
743
744 static int complete_query(sd_resolve *resolve, sd_resolve_query *q) {
745         int r;
746
747         assert(q);
748         assert(!q->done);
749         assert(q->resolve == resolve);
750
751         q->done = true;
752         resolve->n_done ++;
753
754         resolve->current = sd_resolve_query_ref(q);
755
756         switch (q->type) {
757
758         case REQUEST_ADDRINFO:
759                 r = getaddrinfo_done(q);
760                 break;
761
762         case REQUEST_NAMEINFO:
763                 r = getnameinfo_done(q);
764                 break;
765
766         case REQUEST_RES_QUERY:
767         case REQUEST_RES_SEARCH:
768                 r = res_query_done(q);
769                 break;
770
771         default:
772                 assert_not_reached("Cannot complete unknown query type");
773         }
774
775         resolve->current = NULL;
776
777         if (q->floating) {
778                 resolve_query_disconnect(q);
779                 sd_resolve_query_unref(q);
780         }
781
782         sd_resolve_query_unref(q);
783
784         return r;
785 }
786
787 static int unserialize_addrinfo(const void **p, size_t *length, struct addrinfo **ret_ai) {
788         AddrInfoSerialization s;
789         size_t l;
790         struct addrinfo *ai;
791
792         assert(p);
793         assert(*p);
794         assert(ret_ai);
795         assert(length);
796
797         if (*length < sizeof(AddrInfoSerialization))
798                 return -EBADMSG;
799
800         memcpy(&s, *p, sizeof(s));
801
802         l = sizeof(AddrInfoSerialization) + s.ai_addrlen + s.canonname_len;
803         if (*length < l)
804                 return -EBADMSG;
805
806         ai = new0(struct addrinfo, 1);
807         if (!ai)
808                 return -ENOMEM;
809
810         ai->ai_flags = s.ai_flags;
811         ai->ai_family = s.ai_family;
812         ai->ai_socktype = s.ai_socktype;
813         ai->ai_protocol = s.ai_protocol;
814         ai->ai_addrlen = s.ai_addrlen;
815
816         if (s.ai_addrlen > 0) {
817                 ai->ai_addr = memdup((const uint8_t*) *p + sizeof(AddrInfoSerialization), s.ai_addrlen);
818                 if (!ai->ai_addr) {
819                         free(ai);
820                         return -ENOMEM;
821                 }
822         }
823
824         if (s.canonname_len > 0) {
825                 ai->ai_canonname = memdup((const uint8_t*) *p + sizeof(AddrInfoSerialization) + s.ai_addrlen, s.canonname_len);
826                 if (!ai->ai_canonname) {
827                         free(ai->ai_addr);
828                         free(ai);
829                         return -ENOMEM;
830                 }
831         }
832
833         *length -= l;
834         *ret_ai = ai;
835         *p = ((const uint8_t*) *p) + l;
836
837         return 0;
838 }
839
840 static int handle_response(sd_resolve *resolve, const Packet *packet, size_t length) {
841         const RHeader *resp;
842         sd_resolve_query *q;
843         int r;
844
845         assert(resolve);
846
847         resp = &packet->rheader;
848         assert(resp);
849         assert(length >= sizeof(RHeader));
850         assert(length == resp->length);
851
852         if (resp->type == RESPONSE_DIED) {
853                 resolve->dead = true;
854                 return 0;
855         }
856
857         assert(resolve->n_outstanding > 0);
858         resolve->n_outstanding--;
859
860         q = lookup_query(resolve, resp->id);
861         if (!q)
862                 return 0;
863
864         switch (resp->type) {
865
866         case RESPONSE_ADDRINFO: {
867                 const AddrInfoResponse *ai_resp = &packet->addrinfo_response;
868                 const void *p;
869                 size_t l;
870                 struct addrinfo *prev = NULL;
871
872                 assert(length >= sizeof(AddrInfoResponse));
873                 assert(q->type == REQUEST_ADDRINFO);
874
875                 q->ret = ai_resp->ret;
876                 q->_errno = ai_resp->_errno;
877                 q->_h_errno = ai_resp->_h_errno;
878
879                 l = length - sizeof(AddrInfoResponse);
880                 p = (const uint8_t*) resp + sizeof(AddrInfoResponse);
881
882                 while (l > 0 && p) {
883                         struct addrinfo *ai = NULL;
884
885                         r = unserialize_addrinfo(&p, &l, &ai);
886                         if (r < 0) {
887                                 q->ret = EAI_SYSTEM;
888                                 q->_errno = -r;
889                                 q->_h_errno = 0;
890                                 freeaddrinfo(q->addrinfo);
891                                 q->addrinfo = NULL;
892                                 break;
893                         }
894
895                         if (prev)
896                                 prev->ai_next = ai;
897                         else
898                                 q->addrinfo = ai;
899
900                         prev = ai;
901                 }
902
903                 return complete_query(resolve, q);
904         }
905
906         case RESPONSE_NAMEINFO: {
907                 const NameInfoResponse *ni_resp = &packet->nameinfo_response;
908
909                 assert(length >= sizeof(NameInfoResponse));
910                 assert(q->type == REQUEST_NAMEINFO);
911
912                 q->ret = ni_resp->ret;
913                 q->_errno = ni_resp->_errno;
914                 q->_h_errno = ni_resp->_h_errno;
915
916                 if (ni_resp->hostlen > 0) {
917                         q->host = strndup((const char*) ni_resp + sizeof(NameInfoResponse), ni_resp->hostlen-1);
918                         if (!q->host) {
919                                 q->ret = EAI_MEMORY;
920                                 q->_errno = ENOMEM;
921                                 q->_h_errno = 0;
922                         }
923                 }
924
925                 if (ni_resp->servlen > 0) {
926                         q->serv = strndup((const char*) ni_resp + sizeof(NameInfoResponse) + ni_resp->hostlen, ni_resp->servlen-1);
927                         if (!q->serv) {
928                                 q->ret = EAI_MEMORY;
929                                 q->_errno = ENOMEM;
930                                 q->_h_errno = 0;
931                         }
932                 }
933
934                 return complete_query(resolve, q);
935         }
936
937         case RESPONSE_RES: {
938                 const ResResponse *res_resp = &packet->res_response;
939
940                 assert(length >= sizeof(ResResponse));
941                 assert(q->type == REQUEST_RES_QUERY || q->type == REQUEST_RES_SEARCH);
942
943                 q->ret = res_resp->ret;
944                 q->_errno = res_resp->_errno;
945                 q->_h_errno = res_resp->_h_errno;
946
947                 if (res_resp->ret >= 0)  {
948                         q->answer = memdup((const char *)resp + sizeof(ResResponse), res_resp->ret);
949                         if (!q->answer) {
950                                 q->ret = -1;
951                                 q->_errno = ENOMEM;
952                                 q->_h_errno = 0;
953                         }
954                 }
955
956                 return complete_query(resolve, q);
957         }
958
959         default:
960                 return 0;
961         }
962 }
963
964 _public_ int sd_resolve_process(sd_resolve *resolve) {
965         RESOLVE_DONT_DESTROY(resolve);
966
967         union {
968                 Packet packet;
969                 uint8_t space[BUFSIZE];
970         } buf;
971         ssize_t l;
972         int r;
973
974         assert_return(resolve, -EINVAL);
975         assert_return(!resolve_pid_changed(resolve), -ECHILD);
976
977         /* We don't allow recursively invoking sd_resolve_process(). */
978         assert_return(!resolve->current, -EBUSY);
979
980         l = recv(resolve->fds[RESPONSE_RECV_FD], &buf, sizeof(buf), 0);
981         if (l < 0) {
982                 if (errno == EAGAIN)
983                         return 0;
984
985                 return -errno;
986         }
987         if (l == 0)
988                 return -ECONNREFUSED;
989
990         r = handle_response(resolve, &buf.packet, (size_t) l);
991         if (r < 0)
992                 return r;
993
994         return 1;
995 }
996
997 _public_ int sd_resolve_wait(sd_resolve *resolve, uint64_t timeout_usec) {
998         int r;
999
1000         assert_return(resolve, -EINVAL);
1001         assert_return(!resolve_pid_changed(resolve), -ECHILD);
1002
1003         if (resolve->n_done >= resolve->n_queries)
1004                 return 0;
1005
1006         do {
1007                 r = fd_wait_for_event(resolve->fds[RESPONSE_RECV_FD], POLLIN, timeout_usec);
1008         } while (r == -EINTR);
1009
1010         if (r < 0)
1011                 return r;
1012
1013         return sd_resolve_process(resolve);
1014 }
1015
1016 static int alloc_query(sd_resolve *resolve, bool floating, sd_resolve_query **_q) {
1017         sd_resolve_query *q;
1018         int r;
1019
1020         assert(resolve);
1021         assert(_q);
1022
1023         if (resolve->n_queries >= QUERIES_MAX)
1024                 return -ENOBUFS;
1025
1026         r = start_threads(resolve, 1);
1027         if (r < 0)
1028                 return r;
1029
1030         while (resolve->query_array[resolve->current_id % QUERIES_MAX])
1031                 resolve->current_id++;
1032
1033         q = resolve->query_array[resolve->current_id % QUERIES_MAX] = new0(sd_resolve_query, 1);
1034         if (!q)
1035                 return -ENOMEM;
1036
1037         q->n_ref = 1;
1038         q->resolve = resolve;
1039         q->floating = floating;
1040         q->id = resolve->current_id++;
1041
1042         if (!floating)
1043                 sd_resolve_ref(resolve);
1044
1045         LIST_PREPEND(queries, resolve->queries, q);
1046         resolve->n_queries++;
1047
1048         *_q = q;
1049         return 0;
1050 }
1051
/* Queues an asynchronous getaddrinfo()-style lookup.
 *
 * resolve:  resolver to submit the request to.
 * _q:       if non-NULL, receives the query object; if NULL, the query is
 *           allocated as "floating".
 * node, service: strings to look up; at least one must be non-NULL.
 * hints:    optional; when given, ai_flags, ai_family, ai_socktype and
 *           ai_protocol are forwarded in the request.
 * callback: handler stored in the query; must be non-NULL.
 * userdata: opaque pointer stored in the query.
 *
 * Returns 0 on success, -EINVAL on invalid arguments, -ECHILD when
 * resolve_pid_changed() reports a change, the negative result of
 * alloc_query(), or the negative errno of a failed sendmsg(). */
_public_ int sd_resolve_getaddrinfo(
                sd_resolve *resolve,
                sd_resolve_query **_q,
                const char *node, const char *service,
                const struct addrinfo *hints,
                sd_resolve_getaddrinfo_handler_t callback, void *userdata) {

        AddrInfoRequest req = {};
        struct msghdr mh = {};
        struct iovec iov[3];
        sd_resolve_query *q;
        int r;

        assert_return(resolve, -EINVAL);
        assert_return(node || service, -EINVAL);
        assert_return(callback, -EINVAL);
        assert_return(!resolve_pid_changed(resolve), -ECHILD);

        r = alloc_query(resolve, !_q, &q);
        if (r < 0)
                return r;

        q->type = REQUEST_ADDRINFO;
        q->getaddrinfo_handler = callback;
        q->userdata = userdata;

        /* The lengths include the trailing NUL; 0 marks an absent string. */
        req.node_len = node ? strlen(node)+1 : 0;
        req.service_len = service ? strlen(service)+1 : 0;

        req.header.id = q->id;
        req.header.type = REQUEST_ADDRINFO;
        req.header.length = sizeof(AddrInfoRequest) + req.node_len + req.service_len;

        if (hints) {
                req.hints_valid = true;
                req.ai_flags = hints->ai_flags;
                req.ai_family = hints->ai_family;
                req.ai_socktype = hints->ai_socktype;
                req.ai_protocol = hints->ai_protocol;
        }

        /* Fixed-size request first, followed by whichever strings are present. */
        iov[mh.msg_iovlen++] = (struct iovec) { .iov_base = &req, .iov_len = sizeof(AddrInfoRequest) };
        if (node)
                iov[mh.msg_iovlen++] = (struct iovec) { .iov_base = (void*) node, .iov_len = req.node_len };
        if (service)
                iov[mh.msg_iovlen++] = (struct iovec) { .iov_base = (void*) service, .iov_len = req.service_len };
        mh.msg_iov = iov;

        if (sendmsg(resolve->fds[REQUEST_SEND_FD], &mh, MSG_NOSIGNAL) < 0) {
                /* Save errno before cleaning up: dropping the query frees
                 * memory, and that may overwrite errno. */
                r = -errno;
                sd_resolve_query_unref(q);
                return r;
        }

        resolve->n_outstanding++;

        if (_q)
                *_q = q;

        return 0;
}
1112
/* Invokes the getaddrinfo handler of a finished query, after copying the
 * stored _errno and _h_errno values into errno and h_errno. Returns whatever
 * the handler returns. */
static int getaddrinfo_done(sd_resolve_query* q) {
        sd_resolve_getaddrinfo_handler_t handler;

        assert(q);
        assert(q->done);
        assert(q->getaddrinfo_handler);

        handler = q->getaddrinfo_handler;

        /* Expose the recorded error state to the handler. */
        errno = q->_errno;
        h_errno = q->_h_errno;

        return handler(q, q->ret, q->addrinfo, q->userdata);
}
1123
1124 _public_ int sd_resolve_getnameinfo(
1125                 sd_resolve *resolve,
1126                 sd_resolve_query**_q,
1127                 const struct sockaddr *sa, socklen_t salen,
1128                 int flags,
1129                 uint64_t get,
1130                 sd_resolve_getnameinfo_handler_t callback,
1131                 void *userdata) {
1132
1133         NameInfoRequest req = {};
1134         struct msghdr mh = {};
1135         struct iovec iov[2];
1136         sd_resolve_query *q;
1137         int r;
1138
1139         assert_return(resolve, -EINVAL);
1140         assert_return(sa, -EINVAL);
1141         assert_return(salen >= sizeof(struct sockaddr), -EINVAL);
1142         assert_return(salen <= sizeof(union sockaddr_union), -EINVAL);
1143         assert_return((get & ~SD_RESOLVE_GET_BOTH) == 0, -EINVAL);
1144         assert_return(callback, -EINVAL);
1145         assert_return(!resolve_pid_changed(resolve), -ECHILD);
1146
1147         r = alloc_query(resolve, !_q, &q);
1148         if (r < 0)
1149                 return r;
1150
1151         q->type = REQUEST_NAMEINFO;
1152         q->getnameinfo_handler = callback;
1153         q->userdata = userdata;
1154
1155         req.header.id = q->id;
1156         req.header.type = REQUEST_NAMEINFO;
1157         req.header.length = sizeof(NameInfoRequest) + salen;
1158
1159         req.flags = flags;
1160         req.sockaddr_len = salen;
1161         req.gethost = !!(get & SD_RESOLVE_GET_HOST);
1162         req.getserv = !!(get & SD_RESOLVE_GET_SERVICE);
1163
1164         iov[0] = (struct iovec) { .iov_base = &req, .iov_len = sizeof(NameInfoRequest) };
1165         iov[1] = (struct iovec) { .iov_base = (void*) sa, .iov_len = salen };
1166
1167         mh.msg_iov = iov;
1168         mh.msg_iovlen = 2;
1169
1170         if (sendmsg(resolve->fds[REQUEST_SEND_FD], &mh, MSG_NOSIGNAL) < 0) {
1171                 sd_resolve_query_unref(q);
1172                 return -errno;
1173         }
1174
1175         resolve->n_outstanding++;
1176
1177         if (_q)
1178                 *_q = q;
1179
1180         return 0;
1181 }
1182
1183 static int getnameinfo_done(sd_resolve_query *q) {
1184
1185         assert(q);
1186         assert(q->done);
1187         assert(q->getnameinfo_handler);
1188
1189         errno = q->_errno;
1190         h_errno= q->_h_errno;
1191
1192         return q->getnameinfo_handler(q, q->ret, q->host, q->serv, q->userdata);
1193 }
1194
1195 static int resolve_res(
1196                 sd_resolve *resolve,
1197                 sd_resolve_query **_q,
1198                 QueryType qtype,
1199                 const char *dname,
1200                 int class, int type,
1201                 sd_resolve_res_handler_t callback, void *userdata) {
1202
1203         struct msghdr mh = {};
1204         struct iovec iov[2];
1205         ResRequest req = {};
1206         sd_resolve_query *q;
1207         int r;
1208
1209         assert_return(resolve, -EINVAL);
1210         assert_return(dname, -EINVAL);
1211         assert_return(callback, -EINVAL);
1212         assert_return(!resolve_pid_changed(resolve), -ECHILD);
1213
1214         r = alloc_query(resolve, !_q, &q);
1215         if (r < 0)
1216                 return r;
1217
1218         q->type = qtype;
1219         q->res_handler = callback;
1220         q->userdata = userdata;
1221
1222         req.dname_len = strlen(dname) + 1;
1223         req.class = class;
1224         req.type = type;
1225
1226         req.header.id = q->id;
1227         req.header.type = qtype;
1228         req.header.length = sizeof(ResRequest) + req.dname_len;
1229
1230         iov[0] = (struct iovec) { .iov_base = &req, .iov_len = sizeof(ResRequest) };
1231         iov[1] = (struct iovec) { .iov_base = (void*) dname, .iov_len = req.dname_len };
1232
1233         mh.msg_iov = iov;
1234         mh.msg_iovlen = 2;
1235
1236         if (sendmsg(resolve->fds[REQUEST_SEND_FD], &mh, MSG_NOSIGNAL) < 0) {
1237                 sd_resolve_query_unref(q);
1238                 return -errno;
1239         }
1240
1241         resolve->n_outstanding++;
1242
1243         if (_q)
1244                 *_q = q;
1245
1246         return 0;
1247 }
1248
1249 _public_ int sd_resolve_res_query(sd_resolve *resolve, sd_resolve_query** q, const char *dname, int class, int type, sd_resolve_res_handler_t callback, void *userdata) {
1250         return resolve_res(resolve, q, REQUEST_RES_QUERY, dname, class, type, callback, userdata);
1251 }
1252
1253 _public_ int sd_resolve_res_search(sd_resolve *resolve, sd_resolve_query** q, const char *dname, int class, int type, sd_resolve_res_handler_t callback, void *userdata) {
1254         return resolve_res(resolve, q, REQUEST_RES_SEARCH, dname, class, type, callback, userdata);
1255 }
1256
1257 static int res_query_done(sd_resolve_query* q) {
1258         assert(q);
1259         assert(q->done);
1260         assert(q->res_handler);
1261
1262         errno = q->_errno;
1263         h_errno = q->_h_errno;
1264
1265         return q->res_handler(q, q->ret, q->answer, q->userdata);
1266 }
1267
1268 _public_ sd_resolve_query* sd_resolve_query_ref(sd_resolve_query *q) {
1269         assert_return(q, NULL);
1270
1271         assert(q->n_ref >= 1);
1272         q->n_ref++;
1273
1274         return q;
1275 }
1276
1277 static void resolve_freeaddrinfo(struct addrinfo *ai) {
1278         while (ai) {
1279                 struct addrinfo *next = ai->ai_next;
1280
1281                 free(ai->ai_addr);
1282                 free(ai->ai_canonname);
1283                 free(ai);
1284                 ai = next;
1285         }
1286 }
1287
1288 static void resolve_query_disconnect(sd_resolve_query *q) {
1289         sd_resolve *resolve;
1290         unsigned i;
1291
1292         assert(q);
1293
1294         if (!q->resolve)
1295                 return;
1296
1297         resolve = q->resolve;
1298         assert(resolve->n_queries > 0);
1299
1300         if (q->done) {
1301                 assert(resolve->n_done > 0);
1302                 resolve->n_done--;
1303         }
1304
1305         i = q->id % QUERIES_MAX;
1306         assert(resolve->query_array[i] == q);
1307         resolve->query_array[i] = NULL;
1308         LIST_REMOVE(queries, resolve->queries, q);
1309         resolve->n_queries--;
1310
1311         q->resolve = NULL;
1312         if (!q->floating)
1313                 sd_resolve_unref(resolve);
1314 }
1315
1316 static void resolve_query_free(sd_resolve_query *q) {
1317         assert(q);
1318
1319         resolve_query_disconnect(q);
1320
1321         resolve_freeaddrinfo(q->addrinfo);
1322         free(q->host);
1323         free(q->serv);
1324         free(q->answer);
1325         free(q);
1326 }
1327
1328 _public_ sd_resolve_query* sd_resolve_query_unref(sd_resolve_query* q) {
1329         if (!q)
1330                 return NULL;
1331
1332         assert(q->n_ref >= 1);
1333         q->n_ref--;
1334
1335         if (q->n_ref <= 0)
1336                 resolve_query_free(q);
1337
1338         return NULL;
1339 }
1340
1341 _public_ int sd_resolve_query_is_done(sd_resolve_query *q) {
1342         assert_return(q, -EINVAL);
1343         assert_return(!resolve_pid_changed(q->resolve), -ECHILD);
1344
1345         return q->done;
1346 }
1347
1348 _public_ void* sd_resolve_query_set_userdata(sd_resolve_query *q, void *userdata) {
1349         void *ret;
1350
1351         assert_return(q, NULL);
1352         assert_return(!resolve_pid_changed(q->resolve), NULL);
1353
1354         ret = q->userdata;
1355         q->userdata = userdata;
1356
1357         return ret;
1358 }
1359
1360 _public_ void* sd_resolve_query_get_userdata(sd_resolve_query *q) {
1361         assert_return(q, NULL);
1362         assert_return(!resolve_pid_changed(q->resolve), NULL);
1363
1364         return q->userdata;
1365 }
1366
1367 _public_ sd_resolve *sd_resolve_query_get_resolve(sd_resolve_query *q) {
1368         assert_return(q, NULL);
1369         assert_return(!resolve_pid_changed(q->resolve), NULL);
1370
1371         return q->resolve;
1372 }
1373
1374 static int io_callback(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
1375         sd_resolve *resolve = userdata;
1376         int r;
1377
1378         assert(resolve);
1379
1380         r = sd_resolve_process(resolve);
1381         if (r < 0)
1382                 return r;
1383
1384         return 1;
1385 }
1386
1387 _public_ int sd_resolve_attach_event(sd_resolve *resolve, sd_event *event, int priority) {
1388         int r;
1389
1390         assert_return(resolve, -EINVAL);
1391         assert_return(!resolve->event, -EBUSY);
1392
1393         assert(!resolve->event_source);
1394
1395         if (event)
1396                 resolve->event = sd_event_ref(event);
1397         else {
1398                 r = sd_event_default(&resolve->event);
1399                 if (r < 0)
1400                         return r;
1401         }
1402
1403         r = sd_event_add_io(resolve->event, &resolve->event_source, resolve->fds[RESPONSE_RECV_FD], POLLIN, io_callback, resolve);
1404         if (r < 0)
1405                 goto fail;
1406
1407         r = sd_event_source_set_priority(resolve->event_source, priority);
1408         if (r < 0)
1409                 goto fail;
1410
1411         return 0;
1412
1413 fail:
1414         sd_resolve_detach_event(resolve);
1415         return r;
1416 }
1417
1418 _public_  int sd_resolve_detach_event(sd_resolve *resolve) {
1419         assert_return(resolve, -EINVAL);
1420
1421         if (!resolve->event)
1422                 return 0;
1423
1424         if (resolve->event_source) {
1425                 sd_event_source_set_enabled(resolve->event_source, SD_EVENT_OFF);
1426                 resolve->event_source = sd_event_source_unref(resolve->event_source);
1427         }
1428
1429         resolve->event = sd_event_unref(resolve->event);
1430         return 1;
1431 }
1432
1433 _public_ sd_event *sd_resolve_get_event(sd_resolve *resolve) {
1434         assert_return(resolve, NULL);
1435
1436         return resolve->event;
1437 }