chiark / gitweb /
sd-resolve: fix allocation of query ids, never reuse them
[elogind.git] / src / libsystemd / sd-resolve / sd-resolve.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2005-2008 Lennart Poettering
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20 ***/
21
22 #include <assert.h>
23 #include <fcntl.h>
24 #include <signal.h>
25 #include <unistd.h>
26 #include <sys/select.h>
27 #include <stdio.h>
28 #include <string.h>
29 #include <stdlib.h>
30 #include <errno.h>
31 #include <sys/wait.h>
32 #include <sys/types.h>
33 #include <pwd.h>
34 #include <netinet/in.h>
35 #include <arpa/nameser.h>
36 #include <resolv.h>
37 #include <dirent.h>
38 #include <sys/time.h>
39 #include <sys/resource.h>
40 #include <stdint.h>
41 #include <pthread.h>
42 #include <sys/prctl.h>
43 #include <sys/poll.h>
44
45 #include "util.h"
46 #include "list.h"
47 #include "socket-util.h"
48 #include "missing.h"
49 #include "resolve-util.h"
50 #include "sd-resolve.h"
51
/* Lower and upper bound on the number of worker threads start_threads() keeps running. */
#define WORKERS_MIN 1U
#define WORKERS_MAX 16U
/* Number of slots in sd_resolve.query_array; alloc_query() refuses new
 * queries with -ENOBUFS once n_queries reaches this value. */
#define QUERIES_MAX 256U
/* Size of the buffers used to build and receive a single request/response
 * datagram between the caller and the worker threads. */
#define BUFSIZE 10240U
56
/* Message types carried in RHeader.type. REQUEST_* messages are handled by
 * the worker threads (handle_request()), RESPONSE_* messages by the owning
 * side (handle_response()). RESPONSE_DIED is sent by a worker right before
 * it exits. Note: sd_resolve_query stores this in a 4-bit bitfield, so the
 * enum must stay below 16 values. */
typedef enum {
        REQUEST_ADDRINFO,
        RESPONSE_ADDRINFO,
        REQUEST_NAMEINFO,
        RESPONSE_NAMEINFO,
        REQUEST_RES_QUERY,
        REQUEST_RES_SEARCH,
        RESPONSE_RES,
        REQUEST_TERMINATE,
        RESPONSE_DIED
} QueryType;

/* Indices into sd_resolve.fds[]: two socket pairs created in sd_resolve_new().
 * Workers receive requests on REQUEST_RECV_FD and send replies on
 * RESPONSE_SEND_FD; the owning side reads replies from RESPONSE_RECV_FD. */
enum {
        REQUEST_RECV_FD,
        REQUEST_SEND_FD,
        RESPONSE_RECV_FD,
        RESPONSE_SEND_FD,
        _FD_MAX
};
76
/* A resolver instance: the socket pairs to the worker threads, the worker
 * threads themselves and the bookkeeping for outstanding queries. */
struct sd_resolve {
        /* Reference count, see sd_resolve_ref()/sd_resolve_unref(). */
        unsigned n_ref;

        /* Set when a worker reported RESPONSE_DIED or the resolver is being
         * freed; workers stop their receive loop once they see it. */
        bool dead:1;
        /* getpid() at creation, used by resolve_pid_changed() to detect use after fork(). */
        pid_t original_pid;

        /* Socket pair descriptors, indexed by REQUEST_RECV_FD ... RESPONSE_SEND_FD. */
        int fds[_FD_MAX];

        /* Worker threads; the first n_valid_workers entries are running. */
        pthread_t workers[WORKERS_MAX];
        unsigned n_valid_workers;

        /* Next candidate query id; advanced by alloc_query(). */
        unsigned current_id;
        /* Query with id X lives in slot X % QUERIES_MAX (see lookup_query()). */
        sd_resolve_query* query_array[QUERIES_MAX];
        /* Queries allocated and queries completed; n_queries > n_done means
         * responses are still expected. */
        unsigned n_queries, n_done;

        /* Event loop integration (managed outside this chunk). */
        sd_event_source *event_source;
        sd_event *event;

        /* Query whose completion handler is running, or NULL; used to reject
         * recursive sd_resolve_process() calls. */
        sd_resolve_query *current;

        /* For the per-thread default resolver: the thread-local slot to clear
         * on free, and the creating thread's id. */
        sd_resolve **default_resolve_ptr;
        pid_t tid;

        /* All queries currently attached to this resolver. */
        LIST_HEAD(sd_resolve_query, queries);
};
102
/* One outstanding or completed lookup. */
struct sd_resolve_query {
        /* Reference count. */
        unsigned n_ref;

        /* Owning resolver. */
        sd_resolve *resolve;

        /* Request type (REQUEST_ADDRINFO, REQUEST_NAMEINFO, REQUEST_RES_*). */
        QueryType type:4;
        /* Set by complete_query() once the response was processed. */
        bool done:1;
        /* Floating queries take no reference on the resolver (alloc_query())
         * and are released by complete_query() after their handler ran. */
        bool floating:1;
        /* Unique id; also selects the query_array slot (id % QUERIES_MAX). */
        unsigned id;

        /* Result code, errno and h_errno as reported by the worker. */
        int ret;
        int _errno;
        int _h_errno;
        /* Result data, depending on the query type. */
        struct addrinfo *addrinfo;
        char *serv, *host;
        unsigned char *answer;

        /* Completion callback; the member in use depends on 'type'. */
        union {
                sd_resolve_getaddrinfo_handler_t getaddrinfo_handler;
                sd_resolve_getnameinfo_handler_t getnameinfo_handler;
                sd_resolve_res_handler_t res_handler;
        };

        void *userdata;

        /* Linkage in sd_resolve.queries. */
        LIST_FIELDS(sd_resolve_query, queries);
};
130
/* Header at the start of every datagram exchanged with the workers.
 * 'length' is the size of the whole datagram including this header;
 * handle_request() and handle_response() assert it matches what was received. */
typedef struct RHeader {
        QueryType type;
        unsigned id;
        size_t length;
} RHeader;

/* REQUEST_ADDRINFO. Followed by node_len bytes of node name and then
 * service_len bytes of service name; a zero length stands for NULL.
 * The ai_* fields are only used as hints when hints_valid is set. */
typedef struct AddrInfoRequest {
        struct RHeader header;
        bool hints_valid;
        int ai_flags;
        int ai_family;
        int ai_socktype;
        int ai_protocol;
        size_t node_len, service_len;
} AddrInfoRequest;

/* RESPONSE_ADDRINFO: getaddrinfo() result plus errno/h_errno seen by the worker. */
typedef struct AddrInfoResponse {
        struct RHeader header;
        int ret;
        int _errno;
        int _h_errno;
        /* followed by zero or more AddrInfoSerialization entries back to back */
} AddrInfoResponse;

/* One serialized struct addrinfo. canonname_len includes the trailing NUL,
 * or is 0 when there is no canonical name. */
typedef struct AddrInfoSerialization {
        int ai_flags;
        int ai_family;
        int ai_socktype;
        int ai_protocol;
        size_t ai_addrlen;
        size_t canonname_len;
        /* Followed by ai_addr and ai_canonname with variable lengths */
} AddrInfoSerialization;

/* REQUEST_NAMEINFO. Followed by sockaddr_len bytes of socket address;
 * gethost/getserv select which names are looked up. */
typedef struct NameInfoRequest {
        struct RHeader header;
        int flags;
        socklen_t sockaddr_len;
        bool gethost:1, getserv:1;
} NameInfoRequest;

/* RESPONSE_NAMEINFO. Followed by hostlen bytes of host name and then servlen
 * bytes of service name; each length includes the NUL, 0 means absent. */
typedef struct NameInfoResponse {
        struct RHeader header;
        size_t hostlen, servlen;
        int ret;
        int _errno;
        int _h_errno;
} NameInfoResponse;

/* REQUEST_RES_QUERY / REQUEST_RES_SEARCH. Followed by dname_len bytes of
 * domain name; class and type are passed to res_query()/res_search(). */
typedef struct ResRequest {
        struct RHeader header;
        int class;
        int type;
        size_t dname_len;
} ResRequest;

/* RESPONSE_RES. Followed by 'ret' bytes of DNS answer when ret > 0. */
typedef struct ResResponse {
        struct RHeader header;
        int ret;
        int _errno;
        int _h_errno;
} ResResponse;

/* Any message, viewed through its concrete type after checking rheader.type. */
typedef union Packet {
        RHeader rheader;
        AddrInfoRequest addrinfo_request;
        AddrInfoResponse addrinfo_response;
        NameInfoRequest nameinfo_request;
        NameInfoResponse nameinfo_response;
        ResRequest res_request;
        ResResponse res_response;
} Packet;
203
/* Type-specific completion functions dispatched from complete_query();
 * defined further down in this file. */
static int getaddrinfo_done(sd_resolve_query* q);
static int getnameinfo_done(sd_resolve_query *q);
static int res_query_done(sd_resolve_query* q);

/* Detaches a query from its resolver (used by resolve_free() and
 * complete_query()); defined further down in this file. */
static void resolve_query_disconnect(sd_resolve_query *q);
209
/* Take an extra reference on 'resolve' for the rest of the enclosing scope.
 * The _cleanup_resolve_unref_ attribute (presumably wrapping sd_resolve_unref())
 * releases it on scope exit, so the object stays valid while callbacks run. */
#define RESOLVE_DONT_DESTROY(resolve) \
        _cleanup_resolve_unref_ _unused_ sd_resolve *_dont_destroy_##resolve = sd_resolve_ref(resolve)
212
213 static int send_died(int out_fd) {
214
215         RHeader rh = {
216                 .type = RESPONSE_DIED,
217                 .length = sizeof(RHeader),
218         };
219
220         assert(out_fd >= 0);
221
222         if (send(out_fd, &rh, rh.length, MSG_NOSIGNAL) < 0)
223                 return -errno;
224
225         return 0;
226 }
227
228 static void *serialize_addrinfo(void *p, const struct addrinfo *ai, size_t *length, size_t maxlength) {
229         AddrInfoSerialization s;
230         size_t cnl, l;
231
232         assert(p);
233         assert(ai);
234         assert(length);
235         assert(*length <= maxlength);
236
237         cnl = ai->ai_canonname ? strlen(ai->ai_canonname)+1 : 0;
238         l = sizeof(AddrInfoSerialization) + ai->ai_addrlen + cnl;
239
240         if (*length + l > maxlength)
241                 return NULL;
242
243         s.ai_flags = ai->ai_flags;
244         s.ai_family = ai->ai_family;
245         s.ai_socktype = ai->ai_socktype;
246         s.ai_protocol = ai->ai_protocol;
247         s.ai_addrlen = ai->ai_addrlen;
248         s.canonname_len = cnl;
249
250         memcpy((uint8_t*) p, &s, sizeof(AddrInfoSerialization));
251         memcpy((uint8_t*) p + sizeof(AddrInfoSerialization), ai->ai_addr, ai->ai_addrlen);
252
253         if (ai->ai_canonname)
254                 memcpy((char*) p + sizeof(AddrInfoSerialization) + ai->ai_addrlen, ai->ai_canonname, cnl);
255
256         *length += l;
257         return (uint8_t*) p + l;
258 }
259
260 static int send_addrinfo_reply(
261                 int out_fd,
262                 unsigned id,
263                 int ret,
264                 struct addrinfo *ai,
265                 int _errno,
266                 int _h_errno) {
267
268         AddrInfoResponse resp = {
269                 .header.type = RESPONSE_ADDRINFO,
270                 .header.id = id,
271                 .header.length = sizeof(AddrInfoResponse),
272                 .ret = ret,
273                 ._errno = _errno,
274                 ._h_errno = _h_errno,
275         };
276
277         struct msghdr mh = {};
278         struct iovec iov[2];
279         union {
280                 AddrInfoSerialization ais;
281                 uint8_t space[BUFSIZE];
282         } buffer;
283
284         assert(out_fd >= 0);
285
286         if (ret == 0 && ai) {
287                 void *p = &buffer;
288                 struct addrinfo *k;
289
290                 for (k = ai; k; k = k->ai_next) {
291                         p = serialize_addrinfo(p, k, &resp.header.length, (uint8_t*) &buffer + BUFSIZE - (uint8_t*) p);
292                         if (!p) {
293                                 freeaddrinfo(ai);
294                                 return -ENOBUFS;
295                         }
296                 }
297         }
298
299         if (ai)
300                 freeaddrinfo(ai);
301
302         iov[0] = (struct iovec) { .iov_base = &resp, .iov_len = sizeof(AddrInfoResponse) };
303         iov[1] = (struct iovec) { .iov_base = &buffer, .iov_len = resp.header.length - sizeof(AddrInfoResponse) };
304
305         mh.msg_iov = iov;
306         mh.msg_iovlen = ELEMENTSOF(iov);
307
308         if (sendmsg(out_fd, &mh, MSG_NOSIGNAL) < 0)
309                 return -errno;
310
311         return 0;
312 }
313
314 static int send_nameinfo_reply(
315                 int out_fd,
316                 unsigned id,
317                 int ret,
318                 const char *host,
319                 const char *serv,
320                 int _errno,
321                 int _h_errno) {
322
323         NameInfoResponse resp = {
324                 .header.type = RESPONSE_NAMEINFO,
325                 .header.id = id,
326                 .ret = ret,
327                 ._errno = _errno,
328                 ._h_errno = _h_errno,
329         };
330
331         struct msghdr mh = {};
332         struct iovec iov[3];
333         size_t hl, sl;
334
335         assert(out_fd >= 0);
336
337         sl = serv ? strlen(serv)+1 : 0;
338         hl = host ? strlen(host)+1 : 0;
339
340         resp.header.length = sizeof(NameInfoResponse) + hl + sl;
341         resp.hostlen = hl;
342         resp.servlen = sl;
343
344         iov[0] = (struct iovec) { .iov_base = &resp, .iov_len = sizeof(NameInfoResponse) };
345         iov[1] = (struct iovec) { .iov_base = (void*) host, .iov_len = hl };
346         iov[2] = (struct iovec) { .iov_base = (void*) serv, .iov_len = sl };
347
348         mh.msg_iov = iov;
349         mh.msg_iovlen = ELEMENTSOF(iov);
350
351         if (sendmsg(out_fd, &mh, MSG_NOSIGNAL) < 0)
352                 return -errno;
353
354         return 0;
355 }
356
357 static int send_res_reply(int out_fd, unsigned id, const unsigned char *answer, int ret, int _errno, int _h_errno) {
358
359         ResResponse resp = {
360                 .header.type = RESPONSE_RES,
361                 .header.id = id,
362                 .ret = ret,
363                 ._errno = _errno,
364                 ._h_errno = _h_errno,
365         };
366
367         struct msghdr mh = {};
368         struct iovec iov[2];
369         size_t l;
370
371         assert(out_fd >= 0);
372
373         l = ret > 0 ? (size_t) ret : 0;
374
375         resp.header.length = sizeof(ResResponse) + l;
376
377         iov[0] = (struct iovec) { .iov_base = &resp, .iov_len = sizeof(ResResponse) };
378         iov[1] = (struct iovec) { .iov_base = (void*) answer, .iov_len = l };
379
380         mh.msg_iov = iov;
381         mh.msg_iovlen = ELEMENTSOF(iov);
382
383         if (sendmsg(out_fd, &mh, MSG_NOSIGNAL) < 0)
384                 return -errno;
385
386         return 0;
387 }
388
/* Execute one request inside a worker thread and send the reply.
 *
 * 'packet' holds exactly 'length' bytes as received from the request socket;
 * the reply goes to 'out_fd'. Returns the result of the send_*_reply()
 * helper (0 or a negative errno), or -ECONNRESET for REQUEST_TERMINATE,
 * which makes the calling worker leave its loop. */
static int handle_request(int out_fd, const Packet *packet, size_t length) {
        const RHeader *req;

        assert(out_fd >= 0);
        assert(packet);

        req = &packet->rheader;

        assert(length >= sizeof(RHeader));
        assert(length == req->length);

        switch (req->type) {

        case REQUEST_ADDRINFO: {
                const AddrInfoRequest *ai_req = &packet->addrinfo_request;
                struct addrinfo hints = {}, *result = NULL;
                const char *node, *service;
                int ret;

                assert(length >= sizeof(AddrInfoRequest));
                assert(length == sizeof(AddrInfoRequest) + ai_req->node_len + ai_req->service_len);

                hints.ai_flags = ai_req->ai_flags;
                hints.ai_family = ai_req->ai_family;
                hints.ai_socktype = ai_req->ai_socktype;
                hints.ai_protocol = ai_req->ai_protocol;

                /* Node and service follow the fixed part back to back; a zero length encodes NULL. */
                node = ai_req->node_len ? (const char*) ai_req + sizeof(AddrInfoRequest) : NULL;
                service = ai_req->service_len ? (const char*) ai_req + sizeof(AddrInfoRequest) + ai_req->node_len : NULL;

                ret = getaddrinfo(
                                node, service,
                                ai_req->hints_valid ? &hints : NULL,
                                &result);

                /* send_addrinfo_reply() frees result */
                return send_addrinfo_reply(out_fd, req->id, ret, result, errno, h_errno);
        }

        case REQUEST_NAMEINFO: {
                const NameInfoRequest *ni_req = &packet->nameinfo_request;
                char hostbuf[NI_MAXHOST], servbuf[NI_MAXSERV];
                union sockaddr_union sa;
                int ret;

                assert(length >= sizeof(NameInfoRequest));
                assert(length == sizeof(NameInfoRequest) + ni_req->sockaddr_len);
                assert(sizeof(sa) >= ni_req->sockaddr_len);

                memcpy(&sa, (const uint8_t *) ni_req + sizeof(NameInfoRequest), ni_req->sockaddr_len);

                ret = getnameinfo(&sa.sa, ni_req->sockaddr_len,
                                ni_req->gethost ? hostbuf : NULL, ni_req->gethost ? sizeof(hostbuf) : 0,
                                ni_req->getserv ? servbuf : NULL, ni_req->getserv ? sizeof(servbuf) : 0,
                                ni_req->flags);

                return send_nameinfo_reply(out_fd, req->id, ret,
                                ret == 0 && ni_req->gethost ? hostbuf : NULL,
                                ret == 0 && ni_req->getserv ? servbuf : NULL,
                                errno, h_errno);
        }

        case REQUEST_RES_QUERY:
        case REQUEST_RES_SEARCH: {
                const ResRequest *res_req = &packet->res_request;
                /* The reply carries a ResResponse header plus the answer in
                 * one datagram, and sd_resolve_process() receives at most
                 * BUFSIZE bytes, so keep room for the header. */
                union {
                        HEADER header;
                        uint8_t space[BUFSIZE - sizeof(ResResponse)];
                } answer;
                const char *dname;
                int ret;

                assert(length >= sizeof(ResRequest));
                assert(length == sizeof(ResRequest) + res_req->dname_len);

                dname = (const char *) req + sizeof(ResRequest);

                if (req->type == REQUEST_RES_QUERY)
                        ret = res_query(dname, res_req->class, res_req->type, (unsigned char *) &answer, (int) sizeof(answer.space));
                else
                        ret = res_search(dname, res_req->class, res_req->type, (unsigned char *) &answer, (int) sizeof(answer.space));

                /* The resolver may report the full length of an answer that
                 * it truncated to fit the buffer. Never claim (and then send)
                 * more bytes than the buffer actually holds. */
                if (ret > 0 && (size_t) ret > sizeof(answer.space))
                        ret = (int) sizeof(answer.space);

                return send_res_reply(out_fd, req->id, (unsigned char *) &answer, ret, errno, h_errno);
        }

        case REQUEST_TERMINATE:
                /* Quit */
                return -ECONNRESET;

        default:
                assert_not_reached("Unknown request");
        }

        return 0;
}
484
485 static void* thread_worker(void *p) {
486         sd_resolve *resolve = p;
487         sigset_t fullset;
488
489         /* No signals in this thread please */
490         assert_se(sigfillset(&fullset) == 0);
491         assert_se(pthread_sigmask(SIG_BLOCK, &fullset, NULL) == 0);
492
493         /* Assign a pretty name to this thread */
494         prctl(PR_SET_NAME, (unsigned long) "sd-resolve");
495
496         while (!resolve->dead) {
497                 union {
498                         Packet packet;
499                         uint8_t space[BUFSIZE];
500                 } buf;
501                 ssize_t length;
502
503                 length = recv(resolve->fds[REQUEST_RECV_FD], &buf, sizeof(buf), 0);
504                 if (length < 0) {
505                         if (errno == EINTR)
506                                 continue;
507
508                         break;
509                 }
510                 if (length == 0)
511                         break;
512
513                 if (resolve->dead)
514                         break;
515
516                 if (handle_request(resolve->fds[RESPONSE_SEND_FD], &buf.packet, (size_t) length) < 0)
517                         break;
518         }
519
520         send_died(resolve->fds[RESPONSE_SEND_FD]);
521
522         return NULL;
523 }
524
525 static int start_threads(sd_resolve *resolve, unsigned extra) {
526         unsigned n;
527         int r;
528
529         n = resolve->n_queries + extra - resolve->n_done;
530         n = CLAMP(n, WORKERS_MIN, WORKERS_MAX);
531
532         while (resolve->n_valid_workers < n) {
533
534                 r = pthread_create(&resolve->workers[resolve->n_valid_workers], NULL, thread_worker, resolve);
535                 if (r != 0)
536                         return -r;
537
538                 resolve->n_valid_workers ++;
539         }
540
541         return 0;
542 }
543
544 static bool resolve_pid_changed(sd_resolve *r) {
545         assert(r);
546
547         /* We don't support people creating a resolver and keeping it
548          * around after fork(). Let's complain. */
549
550         return r->original_pid != getpid();
551 }
552
553 _public_ int sd_resolve_new(sd_resolve **ret) {
554         sd_resolve *resolve = NULL;
555         int i, r;
556
557         assert_return(ret, -EINVAL);
558
559         resolve = new0(sd_resolve, 1);
560         if (!resolve)
561                 return -ENOMEM;
562
563         resolve->n_ref = 1;
564         resolve->original_pid = getpid();
565
566         for (i = 0; i < _FD_MAX; i++)
567                 resolve->fds[i] = -1;
568
569         r = socketpair(PF_UNIX, SOCK_DGRAM|SOCK_CLOEXEC, 0, resolve->fds + REQUEST_RECV_FD);
570         if (r < 0) {
571                 r = -errno;
572                 goto fail;
573         }
574
575         r = socketpair(PF_UNIX, SOCK_DGRAM|SOCK_CLOEXEC, 0, resolve->fds + RESPONSE_RECV_FD);
576         if (r < 0) {
577                 r = -errno;
578                 goto fail;
579         }
580
581         fd_inc_sndbuf(resolve->fds[REQUEST_SEND_FD], QUERIES_MAX * BUFSIZE);
582         fd_inc_rcvbuf(resolve->fds[REQUEST_RECV_FD], QUERIES_MAX * BUFSIZE);
583         fd_inc_sndbuf(resolve->fds[RESPONSE_SEND_FD], QUERIES_MAX * BUFSIZE);
584         fd_inc_rcvbuf(resolve->fds[RESPONSE_RECV_FD], QUERIES_MAX * BUFSIZE);
585
586         fd_nonblock(resolve->fds[RESPONSE_RECV_FD], true);
587
588         *ret = resolve;
589         return 0;
590
591 fail:
592         sd_resolve_unref(resolve);
593         return r;
594 }
595
596 _public_ int sd_resolve_default(sd_resolve **ret) {
597
598         static thread_local sd_resolve *default_resolve = NULL;
599         sd_resolve *e = NULL;
600         int r;
601
602         if (!ret)
603                 return !!default_resolve;
604
605         if (default_resolve) {
606                 *ret = sd_resolve_ref(default_resolve);
607                 return 0;
608         }
609
610         r = sd_resolve_new(&e);
611         if (r < 0)
612                 return r;
613
614         e->default_resolve_ptr = &default_resolve;
615         e->tid = gettid();
616         default_resolve = e;
617
618         *ret = e;
619         return 1;
620 }
621
622 _public_ int sd_resolve_get_tid(sd_resolve *resolve, pid_t *tid) {
623         assert_return(resolve, -EINVAL);
624         assert_return(tid, -EINVAL);
625         assert_return(!resolve_pid_changed(resolve), -ECHILD);
626
627         if (resolve->tid != 0) {
628                 *tid = resolve->tid;
629                 return 0;
630         }
631
632         if (resolve->event)
633                 return sd_event_get_tid(resolve->event, tid);
634
635         return -ENXIO;
636 }
637
/* Destroy a resolver whose reference count dropped to zero.
 *
 * Steps, in order:
 *  1. release the remaining queries;
 *  2. clear the per-thread default slot;
 *  3. detach from the event loop;
 *  4. ask every worker to terminate and join it;
 *  5. close the socket pairs and free the object.
 * The order matters: 'dead' and the terminate messages must be in place
 * before pthread_join() is called. */
static void resolve_free(sd_resolve *resolve) {
        /* NOTE(review): systemd macro, presumably restores errno on return so
         * callers' error codes survive this cleanup. */
        PROTECT_ERRNO;
        sd_resolve_query *q;
        unsigned i;

        assert(resolve);

        /* Non-floating queries hold a reference on the resolver (alloc_query()),
         * so only floating queries can still be attached here. */
        while ((q = resolve->queries)) {
                assert(q->floating);
                resolve_query_disconnect(q);
                sd_resolve_query_unref(q);
        }

        if (resolve->default_resolve_ptr)
                *(resolve->default_resolve_ptr) = NULL;

        /* Workers check this flag in their receive loop. */
        resolve->dead = true;

        sd_resolve_detach_event(resolve);

        if (resolve->fds[REQUEST_SEND_FD] >= 0) {

                RHeader req = {
                        .type = REQUEST_TERMINATE,
                        .length = sizeof(req)
                };

                /* Send one termination packet for each worker */
                for (i = 0; i < resolve->n_valid_workers; i++)
                        send(resolve->fds[REQUEST_SEND_FD], &req, req.length, MSG_NOSIGNAL);
        }

        /* Now terminate them and wait until they are gone. POSIX says
         * pthread_join() never returns EINTR, so the retry is only defensive. */
        for (i = 0; i < resolve->n_valid_workers; i++) {
                for (;;) {
                        if (pthread_join(resolve->workers[i], NULL) != EINTR)
                                break;
                }
        }

        /* Close all communication channels (unused slots are still -1). */
        for (i = 0; i < _FD_MAX; i++)
                safe_close(resolve->fds[i]);

        free(resolve);
}
684
685 _public_ sd_resolve* sd_resolve_ref(sd_resolve *resolve) {
686         assert_return(resolve, NULL);
687
688         assert(resolve->n_ref >= 1);
689         resolve->n_ref++;
690
691         return resolve;
692 }
693
694 _public_ sd_resolve* sd_resolve_unref(sd_resolve *resolve) {
695
696         if (!resolve)
697                 return NULL;
698
699         assert(resolve->n_ref >= 1);
700         resolve->n_ref--;
701
702         if (resolve->n_ref <= 0)
703                 resolve_free(resolve);
704
705         return NULL;
706 }
707
708 _public_ int sd_resolve_get_fd(sd_resolve *resolve) {
709         assert_return(resolve, -EINVAL);
710         assert_return(!resolve_pid_changed(resolve), -ECHILD);
711
712         return resolve->fds[RESPONSE_RECV_FD];
713 }
714
715 _public_ int sd_resolve_get_events(sd_resolve *resolve) {
716         assert_return(resolve, -EINVAL);
717         assert_return(!resolve_pid_changed(resolve), -ECHILD);
718
719         return resolve->n_queries > resolve->n_done ? POLLIN : 0;
720 }
721
722 _public_ int sd_resolve_get_timeout(sd_resolve *resolve, uint64_t *usec) {
723         assert_return(resolve, -EINVAL);
724         assert_return(usec, -EINVAL);
725         assert_return(!resolve_pid_changed(resolve), -ECHILD);
726
727         *usec = (uint64_t) -1;
728         return 0;
729 }
730
731 static sd_resolve_query *lookup_query(sd_resolve *resolve, unsigned id) {
732         sd_resolve_query *q;
733
734         assert(resolve);
735
736         q = resolve->query_array[id % QUERIES_MAX];
737         if (q)
738                 if (q->id == id)
739                         return q;
740
741         return NULL;
742 }
743
/* Mark 'q' as done and run its type-specific completion function.
 *
 * While that function runs (it presumably invokes the user callback),
 * resolve->current points at the query and holds an extra reference;
 * sd_resolve_process() refuses to recurse while it is set. Floating queries
 * are disconnected and released afterwards.
 * Returns the value of getaddrinfo_done()/getnameinfo_done()/res_query_done(). */
static int complete_query(sd_resolve *resolve, sd_resolve_query *q) {
        int r;

        assert(q);
        assert(!q->done);
        assert(q->resolve == resolve);

        q->done = true;
        resolve->n_done ++;

        /* Keep q alive across the completion function, even if it drops its own reference. */
        resolve->current = sd_resolve_query_ref(q);

        switch (q->type) {

        case REQUEST_ADDRINFO:
                r = getaddrinfo_done(q);
                break;

        case REQUEST_NAMEINFO:
                r = getnameinfo_done(q);
                break;

        case REQUEST_RES_QUERY:
        case REQUEST_RES_SEARCH:
                r = res_query_done(q);
                break;

        default:
                assert_not_reached("Cannot complete unknown query type");
        }

        resolve->current = NULL;

        /* Floating queries hold no reference on the resolver and are
         * released by it once completed. */
        if (q->floating) {
                resolve_query_disconnect(q);
                sd_resolve_query_unref(q);
        }

        /* Drop the reference taken for resolve->current above. */
        sd_resolve_query_unref(q);

        return r;
}
786
787 static int unserialize_addrinfo(const void **p, size_t *length, struct addrinfo **ret_ai) {
788         AddrInfoSerialization s;
789         size_t l;
790         struct addrinfo *ai;
791
792         assert(p);
793         assert(*p);
794         assert(ret_ai);
795         assert(length);
796
797         if (*length < sizeof(AddrInfoSerialization))
798                 return -EBADMSG;
799
800         memcpy(&s, *p, sizeof(s));
801
802         l = sizeof(AddrInfoSerialization) + s.ai_addrlen + s.canonname_len;
803         if (*length < l)
804                 return -EBADMSG;
805
806         ai = new0(struct addrinfo, 1);
807         if (!ai)
808                 return -ENOMEM;
809
810         ai->ai_flags = s.ai_flags;
811         ai->ai_family = s.ai_family;
812         ai->ai_socktype = s.ai_socktype;
813         ai->ai_protocol = s.ai_protocol;
814         ai->ai_addrlen = s.ai_addrlen;
815
816         if (s.ai_addrlen > 0) {
817                 ai->ai_addr = memdup((const uint8_t*) *p + sizeof(AddrInfoSerialization), s.ai_addrlen);
818                 if (!ai->ai_addr) {
819                         free(ai);
820                         return -ENOMEM;
821                 }
822         }
823
824         if (s.canonname_len > 0) {
825                 ai->ai_canonname = memdup((const uint8_t*) *p + sizeof(AddrInfoSerialization) + s.ai_addrlen, s.canonname_len);
826                 if (!ai->ai_canonname) {
827                         free(ai->ai_addr);
828                         free(ai);
829                         return -ENOMEM;
830                 }
831         }
832
833         *length -= l;
834         *ret_ai = ai;
835         *p = ((const uint8_t*) *p) + l;
836
837         return 0;
838 }
839
840 static int handle_response(sd_resolve *resolve, const Packet *packet, size_t length) {
841         const RHeader *resp;
842         sd_resolve_query *q;
843         int r;
844
845         assert(resolve);
846
847         resp = &packet->rheader;
848         assert(resp);
849         assert(length >= sizeof(RHeader));
850         assert(length == resp->length);
851
852         if (resp->type == RESPONSE_DIED) {
853                 resolve->dead = true;
854                 return 0;
855         }
856
857         q = lookup_query(resolve, resp->id);
858         if (!q)
859                 return 0;
860
861         switch (resp->type) {
862
863         case RESPONSE_ADDRINFO: {
864                 const AddrInfoResponse *ai_resp = &packet->addrinfo_response;
865                 const void *p;
866                 size_t l;
867                 struct addrinfo *prev = NULL;
868
869                 assert(length >= sizeof(AddrInfoResponse));
870                 assert(q->type == REQUEST_ADDRINFO);
871
872                 q->ret = ai_resp->ret;
873                 q->_errno = ai_resp->_errno;
874                 q->_h_errno = ai_resp->_h_errno;
875
876                 l = length - sizeof(AddrInfoResponse);
877                 p = (const uint8_t*) resp + sizeof(AddrInfoResponse);
878
879                 while (l > 0 && p) {
880                         struct addrinfo *ai = NULL;
881
882                         r = unserialize_addrinfo(&p, &l, &ai);
883                         if (r < 0) {
884                                 q->ret = EAI_SYSTEM;
885                                 q->_errno = -r;
886                                 q->_h_errno = 0;
887                                 freeaddrinfo(q->addrinfo);
888                                 q->addrinfo = NULL;
889                                 break;
890                         }
891
892                         if (prev)
893                                 prev->ai_next = ai;
894                         else
895                                 q->addrinfo = ai;
896
897                         prev = ai;
898                 }
899
900                 return complete_query(resolve, q);
901         }
902
903         case RESPONSE_NAMEINFO: {
904                 const NameInfoResponse *ni_resp = &packet->nameinfo_response;
905
906                 assert(length >= sizeof(NameInfoResponse));
907                 assert(q->type == REQUEST_NAMEINFO);
908
909                 q->ret = ni_resp->ret;
910                 q->_errno = ni_resp->_errno;
911                 q->_h_errno = ni_resp->_h_errno;
912
913                 if (ni_resp->hostlen > 0) {
914                         q->host = strndup((const char*) ni_resp + sizeof(NameInfoResponse), ni_resp->hostlen-1);
915                         if (!q->host) {
916                                 q->ret = EAI_MEMORY;
917                                 q->_errno = ENOMEM;
918                                 q->_h_errno = 0;
919                         }
920                 }
921
922                 if (ni_resp->servlen > 0) {
923                         q->serv = strndup((const char*) ni_resp + sizeof(NameInfoResponse) + ni_resp->hostlen, ni_resp->servlen-1);
924                         if (!q->serv) {
925                                 q->ret = EAI_MEMORY;
926                                 q->_errno = ENOMEM;
927                                 q->_h_errno = 0;
928                         }
929                 }
930
931                 return complete_query(resolve, q);
932         }
933
934         case RESPONSE_RES: {
935                 const ResResponse *res_resp = &packet->res_response;
936
937                 assert(length >= sizeof(ResResponse));
938                 assert(q->type == REQUEST_RES_QUERY || q->type == REQUEST_RES_SEARCH);
939
940                 q->ret = res_resp->ret;
941                 q->_errno = res_resp->_errno;
942                 q->_h_errno = res_resp->_h_errno;
943
944                 if (res_resp->ret >= 0)  {
945                         q->answer = memdup((const char *)resp + sizeof(ResResponse), res_resp->ret);
946                         if (!q->answer) {
947                                 q->ret = -1;
948                                 q->_errno = ENOMEM;
949                                 q->_h_errno = 0;
950                         }
951                 }
952
953                 return complete_query(resolve, q);
954         }
955
956         default:
957                 return 0;
958         }
959 }
960
961 _public_ int sd_resolve_process(sd_resolve *resolve) {
962         RESOLVE_DONT_DESTROY(resolve);
963
964         union {
965                 Packet packet;
966                 uint8_t space[BUFSIZE];
967         } buf;
968         ssize_t l;
969         int r;
970
971         assert_return(resolve, -EINVAL);
972         assert_return(!resolve_pid_changed(resolve), -ECHILD);
973
974         /* We don't allow recursively invoking sd_resolve_process(). */
975         assert_return(!resolve->current, -EBUSY);
976
977         l = recv(resolve->fds[RESPONSE_RECV_FD], &buf, sizeof(buf), 0);
978         if (l < 0) {
979                 if (errno == EAGAIN)
980                         return 0;
981
982                 return -errno;
983         }
984         if (l == 0)
985                 return -ECONNREFUSED;
986
987         r = handle_response(resolve, &buf.packet, (size_t) l);
988         if (r < 0)
989                 return r;
990
991         return 1;
992 }
993
994 _public_ int sd_resolve_wait(sd_resolve *resolve, uint64_t timeout_usec) {
995         int r;
996
997         assert_return(resolve, -EINVAL);
998         assert_return(!resolve_pid_changed(resolve), -ECHILD);
999
1000         if (resolve->n_done >= resolve->n_queries)
1001                 return 0;
1002
1003         do {
1004                 r = fd_wait_for_event(resolve->fds[RESPONSE_RECV_FD], POLLIN, timeout_usec);
1005         } while (r == -EINTR);
1006
1007         if (r < 0)
1008                 return r;
1009
1010         return sd_resolve_process(resolve);
1011 }
1012
1013 static int alloc_query(sd_resolve *resolve, bool floating, sd_resolve_query **_q) {
1014         sd_resolve_query *q;
1015         int r;
1016
1017         assert(resolve);
1018         assert(_q);
1019
1020         if (resolve->n_queries >= QUERIES_MAX)
1021                 return -ENOBUFS;
1022
1023         r = start_threads(resolve, 1);
1024         if (r < 0)
1025                 return r;
1026
1027         while (resolve->query_array[resolve->current_id % QUERIES_MAX])
1028                 resolve->current_id++;
1029
1030         q = resolve->query_array[resolve->current_id % QUERIES_MAX] = new0(sd_resolve_query, 1);
1031         if (!q)
1032                 return -ENOMEM;
1033
1034         q->n_ref = 1;
1035         q->resolve = resolve;
1036         q->floating = floating;
1037         q->id = resolve->current_id++;
1038
1039         if (!floating)
1040                 sd_resolve_ref(resolve);
1041
1042         LIST_PREPEND(queries, resolve->queries, q);
1043         resolve->n_queries++;
1044
1045         *_q = q;
1046         return 0;
1047 }
1048
/* Queue an asynchronous getaddrinfo() lookup with the worker threads.
 *
 * At least one of node/service must be non-NULL. Of hints (which may be
 * NULL) only ai_flags, ai_family, ai_socktype and ai_protocol are forwarded.
 * callback is invoked with userdata once the response has been processed.
 *
 * If _q is NULL the query is created "floating" (alloc_query() then does not
 * take a reference on resolve); otherwise the new query is stored in *_q.
 *
 * Returns 0 on success or a negative errno-style code. On failure *_q is left
 * untouched and the partially set-up query is released. */
_public_ int sd_resolve_getaddrinfo(
                sd_resolve *resolve,
                sd_resolve_query **_q,
                const char *node, const char *service,
                const struct addrinfo *hints,
                sd_resolve_getaddrinfo_handler_t callback, void *userdata) {

        AddrInfoRequest req = {};
        struct msghdr mh = {};
        struct iovec iov[3];
        sd_resolve_query *q;
        int r;

        assert_return(resolve, -EINVAL);
        assert_return(node || service, -EINVAL);
        assert_return(callback, -EINVAL);
        assert_return(!resolve_pid_changed(resolve), -ECHILD);

        r = alloc_query(resolve, !_q, &q);
        if (r < 0)
                return r;

        q->type = REQUEST_ADDRINFO;
        q->getaddrinfo_handler = callback;
        q->userdata = userdata;

        /* Lengths include the terminating NUL; 0 means "not given". */
        req.node_len = node ? strlen(node)+1 : 0;
        req.service_len = service ? strlen(service)+1 : 0;

        req.header.id = q->id;
        req.header.type = REQUEST_ADDRINFO;
        req.header.length = sizeof(AddrInfoRequest) + req.node_len + req.service_len;

        if (hints) {
                req.hints_valid = true;
                req.ai_flags = hints->ai_flags;
                req.ai_family = hints->ai_family;
                req.ai_socktype = hints->ai_socktype;
                req.ai_protocol = hints->ai_protocol;
        }

        /* A single sendmsg() carries the fixed header followed by the optional strings. */
        iov[mh.msg_iovlen++] = (struct iovec) { .iov_base = &req, .iov_len = sizeof(AddrInfoRequest) };
        if (node)
                iov[mh.msg_iovlen++] = (struct iovec) { .iov_base = (void*) node, .iov_len = req.node_len };
        if (service)
                iov[mh.msg_iovlen++] = (struct iovec) { .iov_base = (void*) service, .iov_len = req.service_len };
        mh.msg_iov = iov;

        /* MSG_NOSIGNAL: a closed peer yields EPIPE instead of raising SIGPIPE. */
        if (sendmsg(resolve->fds[REQUEST_SEND_FD], &mh, MSG_NOSIGNAL) < 0) {
                /* Capture errno before releasing the query: the unref path ends
                 * in free() (and may call sd_resolve_unref()), neither of which
                 * is guaranteed to leave errno untouched. */
                r = -errno;
                sd_resolve_query_unref(q);
                return r;
        }

        if (_q)
                *_q = q;

        return 0;
}
1107
/* Deliver a completed getaddrinfo query to its handler.
 *
 * errno and h_errno are first set to the values recorded for the query, so
 * the handler can inspect them alongside q->ret and q->addrinfo. Returns
 * whatever the handler returns. */
static int getaddrinfo_done(sd_resolve_query* q) {
        sd_resolve_getaddrinfo_handler_t handler;

        assert(q);
        assert(q->done);
        assert(q->getaddrinfo_handler);

        handler = q->getaddrinfo_handler;

        errno = q->_errno;
        h_errno = q->_h_errno;

        return handler(q, q->ret, q->addrinfo, q->userdata);
}
1118
1119 _public_ int sd_resolve_getnameinfo(
1120                 sd_resolve *resolve,
1121                 sd_resolve_query**_q,
1122                 const struct sockaddr *sa, socklen_t salen,
1123                 int flags,
1124                 uint64_t get,
1125                 sd_resolve_getnameinfo_handler_t callback,
1126                 void *userdata) {
1127
1128         NameInfoRequest req = {};
1129         struct msghdr mh = {};
1130         struct iovec iov[2];
1131         sd_resolve_query *q;
1132         int r;
1133
1134         assert_return(resolve, -EINVAL);
1135         assert_return(sa, -EINVAL);
1136         assert_return(salen >= sizeof(struct sockaddr), -EINVAL);
1137         assert_return(salen <= sizeof(union sockaddr_union), -EINVAL);
1138         assert_return((get & ~SD_RESOLVE_GET_BOTH) == 0, -EINVAL);
1139         assert_return(callback, -EINVAL);
1140         assert_return(!resolve_pid_changed(resolve), -ECHILD);
1141
1142         r = alloc_query(resolve, !_q, &q);
1143         if (r < 0)
1144                 return r;
1145
1146         q->type = REQUEST_NAMEINFO;
1147         q->getnameinfo_handler = callback;
1148         q->userdata = userdata;
1149
1150         req.header.id = q->id;
1151         req.header.type = REQUEST_NAMEINFO;
1152         req.header.length = sizeof(NameInfoRequest) + salen;
1153
1154         req.flags = flags;
1155         req.sockaddr_len = salen;
1156         req.gethost = !!(get & SD_RESOLVE_GET_HOST);
1157         req.getserv = !!(get & SD_RESOLVE_GET_SERVICE);
1158
1159         iov[0] = (struct iovec) { .iov_base = &req, .iov_len = sizeof(NameInfoRequest) };
1160         iov[1] = (struct iovec) { .iov_base = (void*) sa, .iov_len = salen };
1161
1162         mh.msg_iov = iov;
1163         mh.msg_iovlen = 2;
1164
1165         if (sendmsg(resolve->fds[REQUEST_SEND_FD], &mh, MSG_NOSIGNAL) < 0) {
1166                 sd_resolve_query_unref(q);
1167                 return -errno;
1168         }
1169
1170         if (_q)
1171                 *_q = q;
1172
1173         return 0;
1174 }
1175
1176 static int getnameinfo_done(sd_resolve_query *q) {
1177
1178         assert(q);
1179         assert(q->done);
1180         assert(q->getnameinfo_handler);
1181
1182         errno = q->_errno;
1183         h_errno= q->_h_errno;
1184
1185         return q->getnameinfo_handler(q, q->ret, q->host, q->serv, q->userdata);
1186 }
1187
1188 static int resolve_res(
1189                 sd_resolve *resolve,
1190                 sd_resolve_query **_q,
1191                 QueryType qtype,
1192                 const char *dname,
1193                 int class, int type,
1194                 sd_resolve_res_handler_t callback, void *userdata) {
1195
1196         struct msghdr mh = {};
1197         struct iovec iov[2];
1198         ResRequest req = {};
1199         sd_resolve_query *q;
1200         int r;
1201
1202         assert_return(resolve, -EINVAL);
1203         assert_return(dname, -EINVAL);
1204         assert_return(callback, -EINVAL);
1205         assert_return(!resolve_pid_changed(resolve), -ECHILD);
1206
1207         r = alloc_query(resolve, !_q, &q);
1208         if (r < 0)
1209                 return r;
1210
1211         q->type = qtype;
1212         q->res_handler = callback;
1213         q->userdata = userdata;
1214
1215         req.dname_len = strlen(dname) + 1;
1216         req.class = class;
1217         req.type = type;
1218
1219         req.header.id = q->id;
1220         req.header.type = qtype;
1221         req.header.length = sizeof(ResRequest) + req.dname_len;
1222
1223         iov[0] = (struct iovec) { .iov_base = &req, .iov_len = sizeof(ResRequest) };
1224         iov[1] = (struct iovec) { .iov_base = (void*) dname, .iov_len = req.dname_len };
1225
1226         mh.msg_iov = iov;
1227         mh.msg_iovlen = 2;
1228
1229         if (sendmsg(resolve->fds[REQUEST_SEND_FD], &mh, MSG_NOSIGNAL) < 0) {
1230                 sd_resolve_query_unref(q);
1231                 return -errno;
1232         }
1233
1234         if (_q)
1235                 *_q = q;
1236
1237         return 0;
1238 }
1239
1240 _public_ int sd_resolve_res_query(sd_resolve *resolve, sd_resolve_query** q, const char *dname, int class, int type, sd_resolve_res_handler_t callback, void *userdata) {
1241         return resolve_res(resolve, q, REQUEST_RES_QUERY, dname, class, type, callback, userdata);
1242 }
1243
1244 _public_ int sd_resolve_res_search(sd_resolve *resolve, sd_resolve_query** q, const char *dname, int class, int type, sd_resolve_res_handler_t callback, void *userdata) {
1245         return resolve_res(resolve, q, REQUEST_RES_SEARCH, dname, class, type, callback, userdata);
1246 }
1247
1248 static int res_query_done(sd_resolve_query* q) {
1249         assert(q);
1250         assert(q->done);
1251         assert(q->res_handler);
1252
1253         errno = q->_errno;
1254         h_errno = q->_h_errno;
1255
1256         return q->res_handler(q, q->ret, q->answer, q->userdata);
1257 }
1258
1259 _public_ sd_resolve_query* sd_resolve_query_ref(sd_resolve_query *q) {
1260         assert_return(q, NULL);
1261
1262         assert(q->n_ref >= 1);
1263         q->n_ref++;
1264
1265         return q;
1266 }
1267
1268 static void resolve_freeaddrinfo(struct addrinfo *ai) {
1269         while (ai) {
1270                 struct addrinfo *next = ai->ai_next;
1271
1272                 free(ai->ai_addr);
1273                 free(ai->ai_canonname);
1274                 free(ai);
1275                 ai = next;
1276         }
1277 }
1278
1279 static void resolve_query_disconnect(sd_resolve_query *q) {
1280         sd_resolve *resolve;
1281         unsigned i;
1282
1283         assert(q);
1284
1285         if (!q->resolve)
1286                 return;
1287
1288         resolve = q->resolve;
1289         assert(resolve->n_queries > 0);
1290
1291         if (q->done) {
1292                 assert(resolve->n_done > 0);
1293                 resolve->n_done--;
1294         }
1295
1296         i = q->id % QUERIES_MAX;
1297         assert(resolve->query_array[i] == q);
1298         resolve->query_array[i] = NULL;
1299         LIST_REMOVE(queries, resolve->queries, q);
1300         resolve->n_queries--;
1301
1302         q->resolve = NULL;
1303         if (!q->floating)
1304                 sd_resolve_unref(resolve);
1305 }
1306
1307 static void resolve_query_free(sd_resolve_query *q) {
1308         assert(q);
1309
1310         resolve_query_disconnect(q);
1311
1312         resolve_freeaddrinfo(q->addrinfo);
1313         free(q->host);
1314         free(q->serv);
1315         free(q->answer);
1316         free(q);
1317 }
1318
1319 _public_ sd_resolve_query* sd_resolve_query_unref(sd_resolve_query* q) {
1320         if (!q)
1321                 return NULL;
1322
1323         assert(q->n_ref >= 1);
1324         q->n_ref--;
1325
1326         if (q->n_ref <= 0)
1327                 resolve_query_free(q);
1328
1329         return NULL;
1330 }
1331
/* Report whether q has completed: returns q->done (presumably set when the
 * worker's response was handled; not visible here), -EINVAL for a NULL q, or
 * -ECHILD when resolve_pid_changed() is true.
 * NOTE(review): resolve_query_disconnect() sets q->resolve to NULL; confirm
 * resolve_pid_changed() tolerates NULL for queries that outlive their
 * sd_resolve. */
_public_ int sd_resolve_query_is_done(sd_resolve_query *q) {
        assert_return(q, -EINVAL);
        assert_return(!resolve_pid_changed(q->resolve), -ECHILD);

        return q->done;
}
1338
1339 _public_ void* sd_resolve_query_set_userdata(sd_resolve_query *q, void *userdata) {
1340         void *ret;
1341
1342         assert_return(q, NULL);
1343         assert_return(!resolve_pid_changed(q->resolve), NULL);
1344
1345         ret = q->userdata;
1346         q->userdata = userdata;
1347
1348         return ret;
1349 }
1350
/* Return the userdata pointer currently associated with q (as given when the
 * query was submitted or via sd_resolve_query_set_userdata()); NULL if q is
 * NULL or resolve_pid_changed() is true. */
_public_ void* sd_resolve_query_get_userdata(sd_resolve_query *q) {
        assert_return(q, NULL);
        assert_return(!resolve_pid_changed(q->resolve), NULL);

        return q->userdata;
}
1357
/* Return the sd_resolve that q belongs to, without taking a reference.
 * q->resolve is cleared by resolve_query_disconnect(), so this yields NULL
 * for a disconnected query.
 * NOTE(review): in that case resolve_pid_changed() is called with NULL;
 * confirm it tolerates that. */
_public_ sd_resolve *sd_resolve_query_get_resolve(sd_resolve_query *q) {
        assert_return(q, NULL);
        assert_return(!resolve_pid_changed(q->resolve), NULL);

        return q->resolve;
}
1364
1365 static int io_callback(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
1366         sd_resolve *resolve = userdata;
1367         int r;
1368
1369         assert(resolve);
1370
1371         r = sd_resolve_process(resolve);
1372         if (r < 0)
1373                 return r;
1374
1375         return 1;
1376 }
1377
1378 _public_ int sd_resolve_attach_event(sd_resolve *resolve, sd_event *event, int priority) {
1379         int r;
1380
1381         assert_return(resolve, -EINVAL);
1382         assert_return(!resolve->event, -EBUSY);
1383
1384         assert(!resolve->event_source);
1385
1386         if (event)
1387                 resolve->event = sd_event_ref(event);
1388         else {
1389                 r = sd_event_default(&resolve->event);
1390                 if (r < 0)
1391                         return r;
1392         }
1393
1394         r = sd_event_add_io(resolve->event, &resolve->event_source, resolve->fds[RESPONSE_RECV_FD], POLLIN, io_callback, resolve);
1395         if (r < 0)
1396                 goto fail;
1397
1398         r = sd_event_source_set_priority(resolve->event_source, priority);
1399         if (r < 0)
1400                 goto fail;
1401
1402         return 0;
1403
1404 fail:
1405         sd_resolve_detach_event(resolve);
1406         return r;
1407 }
1408
1409 _public_  int sd_resolve_detach_event(sd_resolve *resolve) {
1410         assert_return(resolve, -EINVAL);
1411
1412         if (!resolve->event)
1413                 return 0;
1414
1415         if (resolve->event_source) {
1416                 sd_event_source_set_enabled(resolve->event_source, SD_EVENT_OFF);
1417                 resolve->event_source = sd_event_source_unref(resolve->event_source);
1418         }
1419
1420         resolve->event = sd_event_unref(resolve->event);
1421         return 1;
1422 }
1423
/* Return the event loop resolve is attached to (set by
 * sd_resolve_attach_event()), or NULL if none is attached or resolve is NULL.
 * No reference is taken on the returned object. */
_public_ sd_event *sd_resolve_get_event(sd_resolve *resolve) {
        assert_return(resolve, NULL);

        return resolve->event;
}