chiark / gitweb /
cb8e34e3688e85ba53df331d9b7d329facd7f34b
[elogind.git] / src / libsystemd / sd-resolve / sd-resolve.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2005-2008 Lennart Poettering
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20 ***/
21
22 #include <assert.h>
23 #include <fcntl.h>
24 #include <signal.h>
25 #include <unistd.h>
26 #include <sys/select.h>
27 #include <stdio.h>
28 #include <string.h>
29 #include <stdlib.h>
30 #include <errno.h>
31 #include <sys/wait.h>
32 #include <sys/types.h>
33 #include <pwd.h>
34 #include <netinet/in.h>
35 #include <arpa/nameser.h>
36 #include <resolv.h>
37 #include <dirent.h>
38 #include <sys/time.h>
39 #include <sys/resource.h>
40 #include <stdint.h>
41 #include <pthread.h>
42 #include <sys/prctl.h>
43 #include <sys/poll.h>
44
45 #include "util.h"
46 #include "list.h"
47 #include "socket-util.h"
48 #include "missing.h"
49 #include "resolve-util.h"
50 #include "sd-resolve.h"
51
/* Bounds on the worker-thread pool: start_threads() clamps the desired
 * number of workers into [WORKERS_MIN, WORKERS_MAX]. */
#define WORKERS_MIN 1U
#define WORKERS_MAX 16U
/* Maximum number of queries allocated at once; also the size of the
 * sd_resolve.queries[] slot table, indexed by query id % QUERIES_MAX. */
#define QUERIES_MAX 256U
/* Size of the buffers used for one request/response datagram and for the
 * res_query()/res_search() answer. */
#define BUFSIZE 10240U

/* Packet types exchanged with the worker threads. REQUEST_* packets are
 * sent to the workers, RESPONSE_* packets come back. The REQUEST_* value
 * is also stored in sd_resolve_query.type. The numeric values are shared
 * by both ends of the socketpairs, so keep the order stable. */
typedef enum {
        REQUEST_ADDRINFO,
        RESPONSE_ADDRINFO,
        REQUEST_NAMEINFO,
        RESPONSE_NAMEINFO,
        REQUEST_RES_QUERY,
        REQUEST_RES_SEARCH,
        RESPONSE_RES,
        REQUEST_TERMINATE,      /* asks one worker thread to exit */
        RESPONSE_DIED           /* a worker thread has exited */
} QueryType;

/* Indexes into sd_resolve.fds[]: one socketpair carries requests to the
 * workers, the other carries responses back. */
enum {
        REQUEST_RECV_FD,
        REQUEST_SEND_FD,
        RESPONSE_RECV_FD,
        RESPONSE_SEND_FD,
        _FD_MAX
};
76
/* Resolver object: owns the worker threads, the two socketpairs used to
 * talk to them and the table of allocated queries. Reference counted. */
struct sd_resolve {
        unsigned n_ref;

        /* Set by resolve_free() and when a worker reports RESPONSE_DIED;
         * worker threads check it to leave their receive loop. */
        bool dead:1;
        /* PID of the creating process; public calls fail with -ECHILD
         * when used from another process (e.g. after fork()). */
        pid_t original_pid;

        /* Socket ends, indexed by REQUEST_RECV_FD ... RESPONSE_SEND_FD. */
        int fds[_FD_MAX];

        /* Worker threads; only the first n_valid_workers entries are valid. */
        pthread_t workers[WORKERS_MAX];
        unsigned n_valid_workers;

        /* Next candidate query id and its slot in queries[];
         * alloc_query() advances both together. */
        unsigned current_id, current_index;
        sd_resolve_query* queries[QUERIES_MAX];
        /* Incremented by alloc_query() and complete_query() respectively. */
        unsigned n_queries, n_done;

        sd_event_source *event_source;
        sd_event *event;

        /* Query whose completion handler is running, NULL otherwise; used
         * to reject recursive sd_resolve_process() calls. */
        sd_resolve_query *current;

        /* For the per-thread default object: the cache slot used by
         * sd_resolve_default(), cleared again by resolve_free(). */
        sd_resolve **default_resolve_ptr;
        /* Thread id recorded by sd_resolve_default(); 0 unless set there. */
        pid_t tid;
};
100
/* One asynchronous lookup. Reference counted; holds a reference on the
 * resolver it was allocated from (see alloc_query()). */
struct sd_resolve_query {
        unsigned n_ref;

        sd_resolve *resolve;

        /* REQUEST_* type, whether the response has been processed, and the
         * id used to match worker responses to this query. */
        QueryType type:4;
        bool done:1;
        unsigned id;

        /* Results copied out of the worker's response packet. */
        int ret;
        int _errno;
        int _h_errno;
        struct addrinfo *addrinfo;
        char *serv, *host;
        unsigned char *answer;

        /* Completion callback; the valid member depends on type. */
        union {
                sd_resolve_getaddrinfo_handler_t getaddrinfo_handler;
                sd_resolve_getnameinfo_handler_t getnameinfo_handler;
                sd_resolve_res_handler_t res_handler;
        };

        void *userdata;
};
125
/* Wire format of the datagrams exchanged over the socketpairs. Every
 * packet starts with an RHeader whose length is the size of the whole
 * datagram, header included; variable-length data follows the fixed-size
 * struct of each packet type. */
typedef struct RHeader {
        QueryType type;
        unsigned id;            /* query id, used to match responses to queries */
        size_t length;
} RHeader;

/* REQUEST_ADDRINFO: followed by node_len bytes of node and then
 * service_len bytes of service; a length of 0 stands for NULL. */
typedef struct AddrInfoRequest {
        struct RHeader header;
        bool hints_valid;       /* whether the ai_* fields are to be used as hints */
        int ai_flags;
        int ai_family;
        int ai_socktype;
        int ai_protocol;
        size_t node_len, service_len;
} AddrInfoRequest;

/* RESPONSE_ADDRINFO: getaddrinfo() return value plus errno/h_errno. */
typedef struct AddrInfoResponse {
        struct RHeader header;
        int ret;
        int _errno;
        int _h_errno;
        /* followed by addrinfo_serialization[] */
} AddrInfoResponse;

/* One serialized struct addrinfo inside a RESPONSE_ADDRINFO packet. */
typedef struct AddrInfoSerialization {
        int ai_flags;
        int ai_family;
        int ai_socktype;
        int ai_protocol;
        size_t ai_addrlen;
        size_t canonname_len;   /* includes the trailing NUL; 0 if there is none */
        /* Followed by ai_addr and ai_canonname with variable lengths */
} AddrInfoSerialization;

/* REQUEST_NAMEINFO: followed by sockaddr_len bytes of socket address. */
typedef struct NameInfoRequest {
        struct RHeader header;
        int flags;
        socklen_t sockaddr_len;
        bool gethost:1, getserv:1;      /* which names to look up */
} NameInfoRequest;

/* RESPONSE_NAMEINFO: followed by hostlen bytes of host name and then
 * servlen bytes of service name, each including its NUL (0 if absent). */
typedef struct NameInfoResponse {
        struct RHeader header;
        size_t hostlen, servlen;
        int ret;
        int _errno;
        int _h_errno;
} NameInfoResponse;

/* REQUEST_RES_QUERY / REQUEST_RES_SEARCH: followed by dname_len bytes of
 * the domain name to look up. */
typedef struct ResRequest {
        struct RHeader header;
        int class;
        int type;
        size_t dname_len;
} ResRequest;

/* RESPONSE_RES: followed by ret bytes of DNS answer when ret > 0. */
typedef struct ResResponse {
        struct RHeader header;
        int ret;
        int _errno;
        int _h_errno;
} ResResponse;

/* Overlay of all packet layouts, used to interpret a received datagram. */
typedef union Packet {
        RHeader rheader;
        AddrInfoRequest addrinfo_request;
        AddrInfoResponse addrinfo_response;
        NameInfoRequest nameinfo_request;
        NameInfoResponse nameinfo_response;
        ResRequest res_request;
        ResResponse res_response;
} Packet;
198
/* Per-type completion handlers used by complete_query(); being static,
 * they are defined later in this file. */
static int getaddrinfo_done(sd_resolve_query* q);
static int getnameinfo_done(sd_resolve_query *q);
static int res_query_done(sd_resolve_query* q);

/* Take a reference on resolve that is released (via the
 * _cleanup_resolve_unref_ attribute) when the enclosing scope ends, so the
 * object stays alive while code in that scope runs callbacks. */
#define RESOLVE_DONT_DESTROY(resolve) \
        _cleanup_resolve_unref_ _unused_ sd_resolve *_dont_destroy_##resolve = sd_resolve_ref(resolve)
205
206 static int send_died(int out_fd) {
207
208         RHeader rh = {
209                 .type = RESPONSE_DIED,
210                 .length = sizeof(RHeader),
211         };
212
213         assert(out_fd >= 0);
214
215         if (send(out_fd, &rh, rh.length, MSG_NOSIGNAL) < 0)
216                 return -errno;
217
218         return 0;
219 }
220
221 static void *serialize_addrinfo(void *p, const struct addrinfo *ai, size_t *length, size_t maxlength) {
222         AddrInfoSerialization s;
223         size_t cnl, l;
224
225         assert(p);
226         assert(ai);
227         assert(length);
228         assert(*length <= maxlength);
229
230         cnl = ai->ai_canonname ? strlen(ai->ai_canonname)+1 : 0;
231         l = sizeof(AddrInfoSerialization) + ai->ai_addrlen + cnl;
232
233         if (*length + l > maxlength)
234                 return NULL;
235
236         s.ai_flags = ai->ai_flags;
237         s.ai_family = ai->ai_family;
238         s.ai_socktype = ai->ai_socktype;
239         s.ai_protocol = ai->ai_protocol;
240         s.ai_addrlen = ai->ai_addrlen;
241         s.canonname_len = cnl;
242
243         memcpy((uint8_t*) p, &s, sizeof(AddrInfoSerialization));
244         memcpy((uint8_t*) p + sizeof(AddrInfoSerialization), ai->ai_addr, ai->ai_addrlen);
245
246         if (ai->ai_canonname)
247                 memcpy((char*) p + sizeof(AddrInfoSerialization) + ai->ai_addrlen, ai->ai_canonname, cnl);
248
249         *length += l;
250         return (uint8_t*) p + l;
251 }
252
253 static int send_addrinfo_reply(
254                 int out_fd,
255                 unsigned id,
256                 int ret,
257                 struct addrinfo *ai,
258                 int _errno,
259                 int _h_errno) {
260
261         AddrInfoResponse resp = {
262                 .header.type = RESPONSE_ADDRINFO,
263                 .header.id = id,
264                 .header.length = sizeof(AddrInfoResponse),
265                 .ret = ret,
266                 ._errno = _errno,
267                 ._h_errno = _h_errno,
268         };
269
270         struct msghdr mh = {};
271         struct iovec iov[2];
272         union {
273                 AddrInfoSerialization ais;
274                 uint8_t space[BUFSIZE];
275         } buffer;
276
277         assert(out_fd >= 0);
278
279         if (ret == 0 && ai) {
280                 void *p = &buffer;
281                 struct addrinfo *k;
282
283                 for (k = ai; k; k = k->ai_next) {
284                         p = serialize_addrinfo(p, k, &resp.header.length, (uint8_t*) &buffer + BUFSIZE - (uint8_t*) p);
285                         if (!p) {
286                                 freeaddrinfo(ai);
287                                 return -ENOBUFS;
288                         }
289                 }
290         }
291
292         if (ai)
293                 freeaddrinfo(ai);
294
295         iov[0] = (struct iovec) { .iov_base = &resp, .iov_len = sizeof(AddrInfoResponse) };
296         iov[1] = (struct iovec) { .iov_base = &buffer, .iov_len = resp.header.length - sizeof(AddrInfoResponse) };
297
298         mh.msg_iov = iov;
299         mh.msg_iovlen = ELEMENTSOF(iov);
300
301         if (sendmsg(out_fd, &mh, MSG_NOSIGNAL) < 0)
302                 return -errno;
303
304         return 0;
305 }
306
307 static int send_nameinfo_reply(
308                 int out_fd,
309                 unsigned id,
310                 int ret,
311                 const char *host,
312                 const char *serv,
313                 int _errno,
314                 int _h_errno) {
315
316         NameInfoResponse resp = {
317                 .header.type = RESPONSE_NAMEINFO,
318                 .header.id = id,
319                 .ret = ret,
320                 ._errno = _errno,
321                 ._h_errno = _h_errno,
322         };
323
324         struct msghdr mh = {};
325         struct iovec iov[3];
326         size_t hl, sl;
327
328         assert(out_fd >= 0);
329
330         sl = serv ? strlen(serv)+1 : 0;
331         hl = host ? strlen(host)+1 : 0;
332
333         resp.header.length = sizeof(NameInfoResponse) + hl + sl;
334         resp.hostlen = hl;
335         resp.servlen = sl;
336
337         iov[0] = (struct iovec) { .iov_base = &resp, .iov_len = sizeof(NameInfoResponse) };
338         iov[1] = (struct iovec) { .iov_base = (void*) host, .iov_len = hl };
339         iov[2] = (struct iovec) { .iov_base = (void*) serv, .iov_len = sl };
340
341         mh.msg_iov = iov;
342         mh.msg_iovlen = ELEMENTSOF(iov);
343
344         if (sendmsg(out_fd, &mh, MSG_NOSIGNAL) < 0)
345                 return -errno;
346
347         return 0;
348 }
349
350 static int send_res_reply(int out_fd, unsigned id, const unsigned char *answer, int ret, int _errno, int _h_errno) {
351
352         ResResponse resp = {
353                 .header.type = RESPONSE_RES,
354                 .header.id = id,
355                 .ret = ret,
356                 ._errno = _errno,
357                 ._h_errno = _h_errno,
358         };
359
360         struct msghdr mh = {};
361         struct iovec iov[2];
362         size_t l;
363
364         assert(out_fd >= 0);
365
366         l = ret > 0 ? (size_t) ret : 0;
367
368         resp.header.length = sizeof(ResResponse) + l;
369
370         iov[0] = (struct iovec) { .iov_base = &resp, .iov_len = sizeof(ResResponse) };
371         iov[1] = (struct iovec) { .iov_base = (void*) answer, .iov_len = l };
372
373         mh.msg_iov = iov;
374         mh.msg_iovlen = ELEMENTSOF(iov);
375
376         if (sendmsg(out_fd, &mh, MSG_NOSIGNAL) < 0)
377                 return -errno;
378
379         return 0;
380 }
381
/* Execute one request datagram inside a worker thread and send the reply
 * to out_fd.
 *
 * packet points to a received datagram of length bytes; the header's
 * length field must match. Returns the result of the reply sender (0 or
 * -errno), or -ECONNRESET for REQUEST_TERMINATE, which makes the calling
 * worker leave its loop. */
static int handle_request(int out_fd, const Packet *packet, size_t length) {
        const RHeader *req;

        assert(out_fd >= 0);
        assert(packet);

        req = &packet->rheader;

        assert(length >= sizeof(RHeader));
        assert(length == req->length);

        switch (req->type) {

        case REQUEST_ADDRINFO: {
                const AddrInfoRequest *ai_req = &packet->addrinfo_request;
                struct addrinfo hints = {}, *result = NULL;
                const char *node, *service;
                int ret;

                assert(length >= sizeof(AddrInfoRequest));
                assert(length == sizeof(AddrInfoRequest) + ai_req->node_len + ai_req->service_len);

                hints.ai_flags = ai_req->ai_flags;
                hints.ai_family = ai_req->ai_family;
                hints.ai_socktype = ai_req->ai_socktype;
                hints.ai_protocol = ai_req->ai_protocol;

                /* A zero length encodes a NULL string. */
                node = ai_req->node_len ? (const char*) ai_req + sizeof(AddrInfoRequest) : NULL;
                service = ai_req->service_len ? (const char*) ai_req + sizeof(AddrInfoRequest) + ai_req->node_len : NULL;

                ret = getaddrinfo(
                                node, service,
                                ai_req->hints_valid ? &hints : NULL,
                                &result);

                /* send_addrinfo_reply() frees result */
                return send_addrinfo_reply(out_fd, req->id, ret, result, errno, h_errno);
        }

        case REQUEST_NAMEINFO: {
                const NameInfoRequest *ni_req = &packet->nameinfo_request;
                char hostbuf[NI_MAXHOST], servbuf[NI_MAXSERV];
                union sockaddr_union sa;
                int ret;

                assert(length >= sizeof(NameInfoRequest));
                assert(length == sizeof(NameInfoRequest) + ni_req->sockaddr_len);
                assert(sizeof(sa) >= ni_req->sockaddr_len);

                memcpy(&sa, (const uint8_t *) ni_req + sizeof(NameInfoRequest), ni_req->sockaddr_len);

                ret = getnameinfo(&sa.sa, ni_req->sockaddr_len,
                                ni_req->gethost ? hostbuf : NULL, ni_req->gethost ? sizeof(hostbuf) : 0,
                                ni_req->getserv ? servbuf : NULL, ni_req->getserv ? sizeof(servbuf) : 0,
                                ni_req->flags);

                return send_nameinfo_reply(out_fd, req->id, ret,
                                ret == 0 && ni_req->gethost ? hostbuf : NULL,
                                ret == 0 && ni_req->getserv ? servbuf : NULL,
                                errno, h_errno);
        }

        case REQUEST_RES_QUERY:
        case REQUEST_RES_SEARCH: {
                const ResRequest *res_req = &packet->res_request;
                union {
                        HEADER header;
                        uint8_t space[BUFSIZE];
                } answer;
                /* The reply datagram (ResResponse header + answer) has to fit
                 * into the BUFSIZE bytes sd_resolve_process() receives into;
                 * a longer datagram would be truncated there. So only this
                 * much of the answer buffer is offered to the resolver. */
                const int answer_max = (int) (BUFSIZE - sizeof(ResResponse));
                const char *dname;
                int ret;

                assert(length >= sizeof(ResRequest));
                assert(length == sizeof(ResRequest) + res_req->dname_len);

                dname = (const char *) req + sizeof(ResRequest);

                if (req->type == REQUEST_RES_QUERY)
                        ret = res_query(dname, res_req->class, res_req->type, (unsigned char *) &answer, answer_max);
                else
                        ret = res_search(dname, res_req->class, res_req->type, (unsigned char *) &answer, answer_max);

                /* Never send more bytes than the buffer was allowed to hold,
                 * even if the resolver reports a larger answer length. */
                if (ret > answer_max)
                        ret = answer_max;

                return send_res_reply(out_fd, req->id, (unsigned char *) &answer, ret, errno, h_errno);
        }

        case REQUEST_TERMINATE:
                /* Quit */
                return -ECONNRESET;

        default:
                assert_not_reached("Unknown request");
        }

        return 0;
}
477
478 static void* thread_worker(void *p) {
479         sd_resolve *resolve = p;
480         sigset_t fullset;
481
482         /* No signals in this thread please */
483         assert_se(sigfillset(&fullset) == 0);
484         assert_se(pthread_sigmask(SIG_BLOCK, &fullset, NULL) == 0);
485
486         /* Assign a pretty name to this thread */
487         prctl(PR_SET_NAME, (unsigned long) "sd-resolve");
488
489         while (!resolve->dead) {
490                 union {
491                         Packet packet;
492                         uint8_t space[BUFSIZE];
493                 } buf;
494                 ssize_t length;
495
496                 length = recv(resolve->fds[REQUEST_RECV_FD], &buf, sizeof(buf), 0);
497                 if (length < 0) {
498                         if (errno == EINTR)
499                                 continue;
500
501                         break;
502                 }
503                 if (length == 0)
504                         break;
505
506                 if (resolve->dead)
507                         break;
508
509                 if (handle_request(resolve->fds[RESPONSE_SEND_FD], &buf.packet, (size_t) length) < 0)
510                         break;
511         }
512
513         send_died(resolve->fds[RESPONSE_SEND_FD]);
514
515         return NULL;
516 }
517
518 static int start_threads(sd_resolve *resolve, unsigned extra) {
519         unsigned n;
520         int r;
521
522         n = resolve->n_queries + extra - resolve->n_done;
523         n = CLAMP(n, WORKERS_MIN, WORKERS_MAX);
524
525         while (resolve->n_valid_workers < n) {
526
527                 r = pthread_create(&resolve->workers[resolve->n_valid_workers], NULL, thread_worker, resolve);
528                 if (r != 0)
529                         return -r;
530
531                 resolve->n_valid_workers ++;
532         }
533
534         return 0;
535 }
536
537 static bool resolve_pid_changed(sd_resolve *r) {
538         assert(r);
539
540         /* We don't support people creating a resolver and keeping it
541          * around after fork(). Let's complain. */
542
543         return r->original_pid != getpid();
544 }
545
546 _public_ int sd_resolve_new(sd_resolve **ret) {
547         sd_resolve *resolve = NULL;
548         int i, r;
549
550         assert_return(ret, -EINVAL);
551
552         resolve = new0(sd_resolve, 1);
553         if (!resolve)
554                 return -ENOMEM;
555
556         resolve->n_ref = 1;
557         resolve->original_pid = getpid();
558
559         for (i = 0; i < _FD_MAX; i++)
560                 resolve->fds[i] = -1;
561
562         r = socketpair(PF_UNIX, SOCK_DGRAM|SOCK_CLOEXEC, 0, resolve->fds + REQUEST_RECV_FD);
563         if (r < 0) {
564                 r = -errno;
565                 goto fail;
566         }
567
568         r = socketpair(PF_UNIX, SOCK_DGRAM|SOCK_CLOEXEC, 0, resolve->fds + RESPONSE_RECV_FD);
569         if (r < 0) {
570                 r = -errno;
571                 goto fail;
572         }
573
574         fd_inc_sndbuf(resolve->fds[REQUEST_SEND_FD], QUERIES_MAX * BUFSIZE);
575         fd_inc_rcvbuf(resolve->fds[REQUEST_RECV_FD], QUERIES_MAX * BUFSIZE);
576         fd_inc_sndbuf(resolve->fds[RESPONSE_SEND_FD], QUERIES_MAX * BUFSIZE);
577         fd_inc_rcvbuf(resolve->fds[RESPONSE_RECV_FD], QUERIES_MAX * BUFSIZE);
578
579         fd_nonblock(resolve->fds[RESPONSE_RECV_FD], true);
580
581         *ret = resolve;
582         return 0;
583
584 fail:
585         sd_resolve_unref(resolve);
586         return r;
587 }
588
589 _public_ int sd_resolve_default(sd_resolve **ret) {
590
591         static thread_local sd_resolve *default_resolve = NULL;
592         sd_resolve *e = NULL;
593         int r;
594
595         if (!ret)
596                 return !!default_resolve;
597
598         if (default_resolve) {
599                 *ret = sd_resolve_ref(default_resolve);
600                 return 0;
601         }
602
603         r = sd_resolve_new(&e);
604         if (r < 0)
605                 return r;
606
607         e->default_resolve_ptr = &default_resolve;
608         e->tid = gettid();
609         default_resolve = e;
610
611         *ret = e;
612         return 1;
613 }
614
615 _public_ int sd_resolve_get_tid(sd_resolve *resolve, pid_t *tid) {
616         assert_return(resolve, -EINVAL);
617         assert_return(tid, -EINVAL);
618         assert_return(!resolve_pid_changed(resolve), -ECHILD);
619
620         if (resolve->tid != 0) {
621                 *tid = resolve->tid;
622                 return 0;
623         }
624
625         if (resolve->event)
626                 return sd_event_get_tid(resolve->event, tid);
627
628         return -ENXIO;
629 }
630
/* Tear down a resolver whose last reference is gone. The steps run in this
 * order:
 *  1. forget it as the thread's default object;
 *  2. mark it dead and detach it from its event loop;
 *  3. ask every worker thread to exit and join them;
 *  4. close all sockets and free the object. */
static void resolve_free(sd_resolve *resolve) {
        /* Presumably saves and restores the caller's errno around cleanup. */
        PROTECT_ERRNO;
        unsigned i;

        assert(resolve);

        /* If this is the per-thread default object, clear the cached pointer
         * so sd_resolve_default() does not hand out a freed object. */
        if (resolve->default_resolve_ptr)
                *(resolve->default_resolve_ptr) = NULL;

        /* Workers check this flag and leave their receive loop. */
        resolve->dead = true;

        sd_resolve_detach_event(resolve);

        if (resolve->fds[REQUEST_SEND_FD] >= 0) {

                RHeader req = {
                        .type = REQUEST_TERMINATE,
                        .length = sizeof(req)
                };

                /* Send one termination packet for each worker, so every worker
                 * blocked in recv() wakes up; handle_request() answers
                 * REQUEST_TERMINATE with -ECONNRESET, which ends the loop. */
                for (i = 0; i < resolve->n_valid_workers; i++)
                        send(resolve->fds[REQUEST_SEND_FD], &req, req.length, MSG_NOSIGNAL);
        }

        /* Now terminate them and wait until they are gone. The sockets must
         * stay open until then, since the workers still use them. */
        for (i = 0; i < resolve->n_valid_workers; i++) {
                for (;;) {
                        if (pthread_join(resolve->workers[i], NULL) != EINTR)
                                break;
                }
        }

        /* Close all communication channels (entries may still be -1 if
         * sd_resolve_new() failed before creating them). */
        for (i = 0; i < _FD_MAX; i++)
                safe_close(resolve->fds[i]);

        free(resolve);
}
670
671 _public_ sd_resolve* sd_resolve_ref(sd_resolve *resolve) {
672         assert_return(resolve, NULL);
673
674         assert(resolve->n_ref >= 1);
675         resolve->n_ref++;
676
677         return resolve;
678 }
679
680 _public_ sd_resolve* sd_resolve_unref(sd_resolve *resolve) {
681
682         if (!resolve)
683                 return NULL;
684
685         assert(resolve->n_ref >= 1);
686         resolve->n_ref--;
687
688         if (resolve->n_ref <= 0)
689                 resolve_free(resolve);
690
691         return NULL;
692 }
693
694 _public_ int sd_resolve_get_fd(sd_resolve *resolve) {
695         assert_return(resolve, -EINVAL);
696         assert_return(!resolve_pid_changed(resolve), -ECHILD);
697
698         return resolve->fds[RESPONSE_RECV_FD];
699 }
700
701 _public_ int sd_resolve_get_events(sd_resolve *resolve) {
702         assert_return(resolve, -EINVAL);
703         assert_return(!resolve_pid_changed(resolve), -ECHILD);
704
705         return resolve->n_queries > resolve->n_done ? POLLIN : 0;
706 }
707
708 _public_ int sd_resolve_get_timeout(sd_resolve *resolve, uint64_t *usec) {
709         assert_return(resolve, -EINVAL);
710         assert_return(usec, -EINVAL);
711         assert_return(!resolve_pid_changed(resolve), -ECHILD);
712
713         *usec = (uint64_t) -1;
714         return 0;
715 }
716
717 static sd_resolve_query *lookup_query(sd_resolve *resolve, unsigned id) {
718         sd_resolve_query *q;
719
720         assert(resolve);
721
722         q = resolve->queries[id % QUERIES_MAX];
723         if (q)
724                 if (q->id == id)
725                         return q;
726
727         return NULL;
728 }
729
730 static int complete_query(sd_resolve *resolve, sd_resolve_query *q) {
731         int r;
732
733         assert(q);
734         assert(!q->done);
735         assert(q->resolve == resolve);
736
737         q->done = true;
738         resolve->n_done ++;
739
740         resolve->current = q;
741
742         switch (q->type) {
743
744         case REQUEST_ADDRINFO:
745                 r = getaddrinfo_done(q);
746                 break;
747
748         case REQUEST_NAMEINFO:
749                 r = getnameinfo_done(q);
750                 break;
751
752         case REQUEST_RES_QUERY:
753         case REQUEST_RES_SEARCH:
754                 r = res_query_done(q);
755                 break;
756
757         default:
758                 assert_not_reached("Cannot complete unknown query type");
759         }
760
761         resolve->current = NULL;
762
763         return r;
764 }
765
766 static int unserialize_addrinfo(const void **p, size_t *length, struct addrinfo **ret_ai) {
767         AddrInfoSerialization s;
768         size_t l;
769         struct addrinfo *ai;
770
771         assert(p);
772         assert(*p);
773         assert(ret_ai);
774         assert(length);
775
776         if (*length < sizeof(AddrInfoSerialization))
777                 return -EBADMSG;
778
779         memcpy(&s, *p, sizeof(s));
780
781         l = sizeof(AddrInfoSerialization) + s.ai_addrlen + s.canonname_len;
782         if (*length < l)
783                 return -EBADMSG;
784
785         ai = new0(struct addrinfo, 1);
786         if (!ai)
787                 return -ENOMEM;
788
789         ai->ai_flags = s.ai_flags;
790         ai->ai_family = s.ai_family;
791         ai->ai_socktype = s.ai_socktype;
792         ai->ai_protocol = s.ai_protocol;
793         ai->ai_addrlen = s.ai_addrlen;
794
795         if (s.ai_addrlen > 0) {
796                 ai->ai_addr = memdup((const uint8_t*) *p + sizeof(AddrInfoSerialization), s.ai_addrlen);
797                 if (!ai->ai_addr) {
798                         free(ai);
799                         return -ENOMEM;
800                 }
801         }
802
803         if (s.canonname_len > 0) {
804                 ai->ai_canonname = memdup((const uint8_t*) *p + sizeof(AddrInfoSerialization) + s.ai_addrlen, s.canonname_len);
805                 if (!ai->ai_canonname) {
806                         free(ai->ai_addr);
807                         free(ai);
808                         return -ENOMEM;
809                 }
810         }
811
812         *length -= l;
813         *ret_ai = ai;
814         *p = ((const uint8_t*) *p) + l;
815
816         return 0;
817 }
818
819 static int handle_response(sd_resolve *resolve, const Packet *packet, size_t length) {
820         const RHeader *resp;
821         sd_resolve_query *q;
822         int r;
823
824         assert(resolve);
825
826         resp = &packet->rheader;
827         assert(resp);
828         assert(length >= sizeof(RHeader));
829         assert(length == resp->length);
830
831         if (resp->type == RESPONSE_DIED) {
832                 resolve->dead = true;
833                 return 0;
834         }
835
836         q = lookup_query(resolve, resp->id);
837         if (!q)
838                 return 0;
839
840         switch (resp->type) {
841
842         case RESPONSE_ADDRINFO: {
843                 const AddrInfoResponse *ai_resp = &packet->addrinfo_response;
844                 const void *p;
845                 size_t l;
846                 struct addrinfo *prev = NULL;
847
848                 assert(length >= sizeof(AddrInfoResponse));
849                 assert(q->type == REQUEST_ADDRINFO);
850
851                 q->ret = ai_resp->ret;
852                 q->_errno = ai_resp->_errno;
853                 q->_h_errno = ai_resp->_h_errno;
854
855                 l = length - sizeof(AddrInfoResponse);
856                 p = (const uint8_t*) resp + sizeof(AddrInfoResponse);
857
858                 while (l > 0 && p) {
859                         struct addrinfo *ai = NULL;
860
861                         r = unserialize_addrinfo(&p, &l, &ai);
862                         if (r < 0) {
863                                 q->ret = EAI_SYSTEM;
864                                 q->_errno = -r;
865                                 q->_h_errno = 0;
866                                 freeaddrinfo(q->addrinfo);
867                                 q->addrinfo = NULL;
868                                 break;
869                         }
870
871                         if (prev)
872                                 prev->ai_next = ai;
873                         else
874                                 q->addrinfo = ai;
875
876                         prev = ai;
877                 }
878
879                 return complete_query(resolve, q);
880         }
881
882         case RESPONSE_NAMEINFO: {
883                 const NameInfoResponse *ni_resp = &packet->nameinfo_response;
884
885                 assert(length >= sizeof(NameInfoResponse));
886                 assert(q->type == REQUEST_NAMEINFO);
887
888                 q->ret = ni_resp->ret;
889                 q->_errno = ni_resp->_errno;
890                 q->_h_errno = ni_resp->_h_errno;
891
892                 if (ni_resp->hostlen > 0) {
893                         q->host = strndup((const char*) ni_resp + sizeof(NameInfoResponse), ni_resp->hostlen-1);
894                         if (!q->host) {
895                                 q->ret = EAI_MEMORY;
896                                 q->_errno = ENOMEM;
897                                 q->_h_errno = 0;
898                         }
899                 }
900
901                 if (ni_resp->servlen > 0) {
902                         q->serv = strndup((const char*) ni_resp + sizeof(NameInfoResponse) + ni_resp->hostlen, ni_resp->servlen-1);
903                         if (!q->serv) {
904                                 q->ret = EAI_MEMORY;
905                                 q->_errno = ENOMEM;
906                                 q->_h_errno = 0;
907                         }
908                 }
909
910                 return complete_query(resolve, q);
911         }
912
913         case RESPONSE_RES: {
914                 const ResResponse *res_resp = &packet->res_response;
915
916                 assert(length >= sizeof(ResResponse));
917                 assert(q->type == REQUEST_RES_QUERY || q->type == REQUEST_RES_SEARCH);
918
919                 q->ret = res_resp->ret;
920                 q->_errno = res_resp->_errno;
921                 q->_h_errno = res_resp->_h_errno;
922
923                 if (res_resp->ret >= 0)  {
924                         q->answer = memdup((const char *)resp + sizeof(ResResponse), res_resp->ret);
925                         if (!q->answer) {
926                                 q->ret = -1;
927                                 q->_errno = ENOMEM;
928                                 q->_h_errno = 0;
929                         }
930                 }
931
932                 return complete_query(resolve, q);
933         }
934
935         default:
936                 return 0;
937         }
938 }
939
940 _public_ int sd_resolve_process(sd_resolve *resolve) {
941         RESOLVE_DONT_DESTROY(resolve);
942
943         union {
944                 Packet packet;
945                 uint8_t space[BUFSIZE];
946         } buf;
947         ssize_t l;
948         int r;
949
950         assert_return(resolve, -EINVAL);
951         assert_return(!resolve_pid_changed(resolve), -ECHILD);
952
953         /* We don't allow recursively invoking sd_resolve_process(). */
954         assert_return(!resolve->current, -EBUSY);
955
956         l = recv(resolve->fds[RESPONSE_RECV_FD], &buf, sizeof(buf), 0);
957         if (l < 0) {
958                 if (errno == EAGAIN)
959                         return 0;
960
961                 return -errno;
962         }
963         if (l == 0)
964                 return -ECONNREFUSED;
965
966         r = handle_response(resolve, &buf.packet, (size_t) l);
967         if (r < 0)
968                 return r;
969
970         return 1;
971 }
972
973 _public_ int sd_resolve_wait(sd_resolve *resolve, uint64_t timeout_usec) {
974         int r;
975
976         assert_return(resolve, -EINVAL);
977         assert_return(!resolve_pid_changed(resolve), -ECHILD);
978
979         if (resolve->n_done >= resolve->n_queries)
980                 return 0;
981
982         do {
983                 r = fd_wait_for_event(resolve->fds[RESPONSE_RECV_FD], POLLIN, timeout_usec);
984         } while (r == -EINTR);
985
986         if (r < 0)
987                 return r;
988
989         return sd_resolve_process(resolve);
990 }
991
992 static int alloc_query(sd_resolve *resolve, sd_resolve_query **_q) {
993         sd_resolve_query *q;
994         int r;
995
996         assert(resolve);
997         assert(_q);
998
999         if (resolve->n_queries >= QUERIES_MAX)
1000                 return -ENOBUFS;
1001
1002         r = start_threads(resolve, 1);
1003         if (r < 0)
1004                 return r;
1005
1006         while (resolve->queries[resolve->current_index]) {
1007                 resolve->current_index++;
1008                 resolve->current_id++;
1009
1010                 resolve->current_index %= QUERIES_MAX;
1011         }
1012
1013         q = resolve->queries[resolve->current_index] = new0(sd_resolve_query, 1);
1014         if (!q)
1015                 return -ENOMEM;
1016
1017         q->n_ref = 1;
1018         q->resolve = sd_resolve_ref(resolve);
1019         q->id = resolve->current_id;
1020
1021         resolve->n_queries++;
1022
1023         *_q = q;
1024         return 0;
1025 }
1026
/* Submits an asynchronous getaddrinfo request for node/service to the worker
 * threads. At least one of node and service must be non-NULL. hints may be
 * NULL, in which case hints_valid stays unset in the request.
 *
 * The request is one datagram: the AddrInfoRequest header followed by the
 * NUL-terminated node and/or service strings (each only when non-NULL).
 * callback is stored with userdata on the query.
 *
 * On success stores the new query in *_q and returns 0. On failure returns a
 * negative errno-style error and leaves *_q untouched. */
_public_ int sd_resolve_getaddrinfo(
                sd_resolve *resolve,
                sd_resolve_query **_q,
                const char *node, const char *service,
                const struct addrinfo *hints,
                sd_resolve_getaddrinfo_handler_t callback, void *userdata) {

        AddrInfoRequest req = {};
        struct msghdr mh = {};
        struct iovec iov[3];
        sd_resolve_query *q;
        int r;

        assert_return(resolve, -EINVAL);
        assert_return(_q, -EINVAL);
        assert_return(node || service, -EINVAL);
        assert_return(callback, -EINVAL);
        assert_return(!resolve_pid_changed(resolve), -ECHILD);

        r = alloc_query(resolve, &q);
        if (r < 0)
                return r;

        q->type = REQUEST_ADDRINFO;
        q->getaddrinfo_handler = callback;
        q->userdata = userdata;

        /* String lengths include the terminating NUL that is sent along. */
        req.node_len = node ? strlen(node)+1 : 0;
        req.service_len = service ? strlen(service)+1 : 0;

        req.header.id = q->id;
        req.header.type = REQUEST_ADDRINFO;
        req.header.length = sizeof(AddrInfoRequest) + req.node_len + req.service_len;

        if (hints) {
                req.hints_valid = true;
                req.ai_flags = hints->ai_flags;
                req.ai_family = hints->ai_family;
                req.ai_socktype = hints->ai_socktype;
                req.ai_protocol = hints->ai_protocol;
        }

        iov[mh.msg_iovlen++] = (struct iovec) { .iov_base = &req, .iov_len = sizeof(AddrInfoRequest) };
        if (node)
                iov[mh.msg_iovlen++] = (struct iovec) { .iov_base = (void*) node, .iov_len = req.node_len };
        if (service)
                iov[mh.msg_iovlen++] = (struct iovec) { .iov_base = (void*) service, .iov_len = req.service_len };
        mh.msg_iov = iov;

        if (sendmsg(resolve->fds[REQUEST_SEND_FD], &mh, MSG_NOSIGNAL) < 0) {
                /* Capture errno before dropping the query. Freeing it calls free()
                 * and sd_resolve_unref(), either of which may overwrite errno. */
                r = -errno;
                sd_resolve_query_unref(q);
                return r;
        }

        *_q = q;
        return 0;
}
1084
/* Delivers a finished getaddrinfo query to its handler. errno and h_errno are
 * first set to the values stored in the query (_errno, _h_errno), so the
 * handler can inspect them. Returns the handler's return value. */
static int getaddrinfo_done(sd_resolve_query* q) {
        sd_resolve_getaddrinfo_handler_t handler;

        assert(q);
        assert(q->done);

        handler = q->getaddrinfo_handler;
        assert(handler);

        errno = q->_errno;
        h_errno = q->_h_errno;

        return handler(q, q->ret, q->addrinfo, q->userdata);
}
1095
1096 _public_ int sd_resolve_getnameinfo(
1097                 sd_resolve *resolve,
1098                 sd_resolve_query**_q,
1099                 const struct sockaddr *sa, socklen_t salen,
1100                 int flags,
1101                 uint64_t get,
1102                 sd_resolve_getnameinfo_handler_t callback,
1103                 void *userdata) {
1104
1105         NameInfoRequest req = {};
1106         struct msghdr mh = {};
1107         struct iovec iov[2];
1108         sd_resolve_query *q;
1109         int r;
1110
1111         assert_return(resolve, -EINVAL);
1112         assert_return(_q, -EINVAL);
1113         assert_return(sa, -EINVAL);
1114         assert_return(salen >= sizeof(struct sockaddr), -EINVAL);
1115         assert_return(salen <= sizeof(union sockaddr_union), -EINVAL);
1116         assert_return((get & ~SD_RESOLVE_GET_BOTH) == 0, -EINVAL);
1117         assert_return(callback, -EINVAL);
1118         assert_return(!resolve_pid_changed(resolve), -ECHILD);
1119
1120         r = alloc_query(resolve, &q);
1121         if (r < 0)
1122                 return r;
1123
1124         q->type = REQUEST_NAMEINFO;
1125         q->getnameinfo_handler = callback;
1126         q->userdata = userdata;
1127
1128         req.header.id = q->id;
1129         req.header.type = REQUEST_NAMEINFO;
1130         req.header.length = sizeof(NameInfoRequest) + salen;
1131
1132         req.flags = flags;
1133         req.sockaddr_len = salen;
1134         req.gethost = !!(get & SD_RESOLVE_GET_HOST);
1135         req.getserv = !!(get & SD_RESOLVE_GET_SERVICE);
1136
1137         iov[0] = (struct iovec) { .iov_base = &req, .iov_len = sizeof(NameInfoRequest) };
1138         iov[1] = (struct iovec) { .iov_base = (void*) sa, .iov_len = salen };
1139
1140         mh.msg_iov = iov;
1141         mh.msg_iovlen = 2;
1142
1143         if (sendmsg(resolve->fds[REQUEST_SEND_FD], &mh, MSG_NOSIGNAL) < 0) {
1144                 sd_resolve_query_unref(q);
1145                 return -errno;
1146         }
1147
1148         *_q = q;
1149         return 0;
1150 }
1151
1152 static int getnameinfo_done(sd_resolve_query *q) {
1153
1154         assert(q);
1155         assert(q->done);
1156         assert(q->getnameinfo_handler);
1157
1158         errno = q->_errno;
1159         h_errno= q->_h_errno;
1160
1161         return q->getnameinfo_handler(q, q->ret, q->host, q->serv, q->userdata);
1162 }
1163
1164 static int resolve_res(
1165                 sd_resolve *resolve,
1166                 sd_resolve_query **_q,
1167                 QueryType qtype,
1168                 const char *dname,
1169                 int class, int type,
1170                 sd_resolve_res_handler_t callback, void *userdata) {
1171
1172         struct msghdr mh = {};
1173         struct iovec iov[2];
1174         ResRequest req = {};
1175         sd_resolve_query *q;
1176         int r;
1177
1178         assert_return(resolve, -EINVAL);
1179         assert_return(_q, -EINVAL);
1180         assert_return(dname, -EINVAL);
1181         assert_return(callback, -EINVAL);
1182         assert_return(!resolve_pid_changed(resolve), -ECHILD);
1183
1184         r = alloc_query(resolve, &q);
1185         if (r < 0)
1186                 return r;
1187
1188         q->type = qtype;
1189         q->res_handler = callback;
1190         q->userdata = userdata;
1191
1192         req.dname_len = strlen(dname) + 1;
1193         req.class = class;
1194         req.type = type;
1195
1196         req.header.id = q->id;
1197         req.header.type = qtype;
1198         req.header.length = sizeof(ResRequest) + req.dname_len;
1199
1200         iov[0] = (struct iovec) { .iov_base = &req, .iov_len = sizeof(ResRequest) };
1201         iov[1] = (struct iovec) { .iov_base = (void*) dname, .iov_len = req.dname_len };
1202
1203         mh.msg_iov = iov;
1204         mh.msg_iovlen = 2;
1205
1206         if (sendmsg(resolve->fds[REQUEST_SEND_FD], &mh, MSG_NOSIGNAL) < 0) {
1207                 sd_resolve_query_unref(q);
1208                 return -errno;
1209         }
1210
1211         *_q = q;
1212         return 0;
1213 }
1214
1215 _public_ int sd_resolve_res_query(sd_resolve *resolve, sd_resolve_query** q, const char *dname, int class, int type, sd_resolve_res_handler_t callback, void *userdata) {
1216         return resolve_res(resolve, q, REQUEST_RES_QUERY, dname, class, type, callback, userdata);
1217 }
1218
1219 _public_ int sd_resolve_res_search(sd_resolve *resolve, sd_resolve_query** q, const char *dname, int class, int type, sd_resolve_res_handler_t callback, void *userdata) {
1220         return resolve_res(resolve, q, REQUEST_RES_SEARCH, dname, class, type, callback, userdata);
1221 }
1222
1223 static int res_query_done(sd_resolve_query* q) {
1224         assert(q);
1225         assert(q->done);
1226         assert(q->res_handler);
1227
1228         errno = q->_errno;
1229         h_errno = q->_h_errno;
1230
1231         return q->res_handler(q, q->ret, q->answer, q->userdata);
1232 }
1233
1234 _public_ sd_resolve_query* sd_resolve_query_ref(sd_resolve_query *q) {
1235         assert_return(q, NULL);
1236
1237         assert(q->n_ref >= 1);
1238         q->n_ref++;
1239
1240         return q;
1241 }
1242
1243 static void resolve_freeaddrinfo(struct addrinfo *ai) {
1244         while (ai) {
1245                 struct addrinfo *next = ai->ai_next;
1246
1247                 free(ai->ai_addr);
1248                 free(ai->ai_canonname);
1249                 free(ai);
1250                 ai = next;
1251         }
1252 }
1253
1254 static void resolve_query_free(sd_resolve_query *q) {
1255         unsigned i;
1256
1257         assert(q);
1258         assert(q->resolve);
1259         assert(q->resolve->n_queries > 0);
1260
1261         if (q->done) {
1262                 assert(q->resolve->n_done > 0);
1263                 q->resolve->n_done--;
1264         }
1265
1266         i = q->id % QUERIES_MAX;
1267         assert(q->resolve->queries[i] == q);
1268         q->resolve->queries[i] = NULL;
1269         q->resolve->n_queries--;
1270         sd_resolve_unref(q->resolve);
1271
1272         resolve_freeaddrinfo(q->addrinfo);
1273         free(q->host);
1274         free(q->serv);
1275         free(q->answer);
1276         free(q);
1277 }
1278
1279 _public_ sd_resolve_query* sd_resolve_query_unref(sd_resolve_query* q) {
1280         if (!q)
1281                 return NULL;
1282
1283         assert(q->n_ref >= 1);
1284         q->n_ref--;
1285
1286         if (q->n_ref <= 0)
1287                 resolve_query_free(q);
1288
1289         return NULL;
1290 }
1291
1292 _public_ int sd_resolve_query_is_done(sd_resolve_query *q) {
1293         assert_return(q, -EINVAL);
1294         assert_return(!resolve_pid_changed(q->resolve), -ECHILD);
1295
1296         return q->done;
1297 }
1298
1299 _public_ void* sd_resolve_query_set_userdata(sd_resolve_query *q, void *userdata) {
1300         void *ret;
1301
1302         assert_return(q, NULL);
1303         assert_return(!resolve_pid_changed(q->resolve), NULL);
1304
1305         ret = q->userdata;
1306         q->userdata = userdata;
1307
1308         return ret;
1309 }
1310
1311 _public_ void* sd_resolve_query_get_userdata(sd_resolve_query *q) {
1312         assert_return(q, NULL);
1313         assert_return(!resolve_pid_changed(q->resolve), NULL);
1314
1315         return q->userdata;
1316 }
1317
1318 _public_ sd_resolve *sd_resolve_query_get_resolve(sd_resolve_query *q) {
1319         assert_return(q, NULL);
1320         assert_return(!resolve_pid_changed(q->resolve), NULL);
1321
1322         return q->resolve;
1323 }
1324
1325 static int io_callback(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
1326         sd_resolve *resolve = userdata;
1327         int r;
1328
1329         assert(resolve);
1330
1331         r = sd_resolve_process(resolve);
1332         if (r < 0)
1333                 return r;
1334
1335         return 1;
1336 }
1337
1338 _public_ int sd_resolve_attach_event(sd_resolve *resolve, sd_event *event, int priority) {
1339         int r;
1340
1341         assert_return(resolve, -EINVAL);
1342         assert_return(!resolve->event, -EBUSY);
1343
1344         assert(!resolve->event_source);
1345
1346         if (event)
1347                 resolve->event = sd_event_ref(event);
1348         else {
1349                 r = sd_event_default(&resolve->event);
1350                 if (r < 0)
1351                         return r;
1352         }
1353
1354         r = sd_event_add_io(resolve->event, &resolve->event_source, resolve->fds[RESPONSE_RECV_FD], POLLIN, io_callback, resolve);
1355         if (r < 0)
1356                 goto fail;
1357
1358         r = sd_event_source_set_priority(resolve->event_source, priority);
1359         if (r < 0)
1360                 goto fail;
1361
1362         return 0;
1363
1364 fail:
1365         sd_resolve_detach_event(resolve);
1366         return r;
1367 }
1368
1369 _public_  int sd_resolve_detach_event(sd_resolve *resolve) {
1370         assert_return(resolve, -EINVAL);
1371
1372         if (!resolve->event)
1373                 return 0;
1374
1375         if (resolve->event_source) {
1376                 sd_event_source_set_enabled(resolve->event_source, SD_EVENT_OFF);
1377                 resolve->event_source = sd_event_source_unref(resolve->event_source);
1378         }
1379
1380         resolve->event = sd_event_unref(resolve->event);
1381         return 1;
1382 }
1383
1384 _public_ sd_event *sd_resolve_get_event(sd_resolve *resolve) {
1385         assert_return(resolve, NULL);
1386
1387         return resolve->event;
1388 }