chiark / gitweb /
sysusers: fix uninitialized warning
[elogind.git] / src / libsystemd / sd-resolve / sd-resolve.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2005-2008 Lennart Poettering
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20 ***/
21
22 #include <assert.h>
23 #include <fcntl.h>
24 #include <signal.h>
25 #include <unistd.h>
26 #include <sys/select.h>
27 #include <stdio.h>
28 #include <string.h>
29 #include <stdlib.h>
30 #include <errno.h>
31 #include <sys/wait.h>
32 #include <sys/types.h>
33 #include <pwd.h>
34 #include <netinet/in.h>
35 #include <arpa/nameser.h>
36 #include <resolv.h>
37 #include <dirent.h>
38 #include <sys/time.h>
39 #include <sys/resource.h>
40 #include <stdint.h>
41 #include <pthread.h>
42 #include <sys/prctl.h>
43 #include <sys/poll.h>
44
45 #include "util.h"
46 #include "list.h"
47 #include "socket-util.h"
48 #include "missing.h"
49 #include "resolve-util.h"
50 #include "sd-resolve.h"
51
/* Lower and upper bound applied to the number of worker threads in start_threads(). */
#define WORKERS_MIN 1U
#define WORKERS_MAX 16U
/* Number of slots in sd_resolve.query_array; alloc_query() refuses more outstanding queries. */
#define QUERIES_MAX 256U
/* Size in bytes of the packet, reply and answer buffers used on both ends of the sockets. */
#define BUFSIZE 10240U
56
/* Type tag carried in every RHeader exchanged between the resolver and its worker threads. */
typedef enum {
        REQUEST_ADDRINFO,
        RESPONSE_ADDRINFO,
        REQUEST_NAMEINFO,
        RESPONSE_NAMEINFO,
        REQUEST_RES_QUERY,
        REQUEST_RES_SEARCH,
        RESPONSE_RES,
        REQUEST_TERMINATE,      /* sent by resolve_free(), makes a worker leave its loop */
        RESPONSE_DIED           /* sent by a worker thread right before it exits */
} QueryType;
68
/* Indexes into sd_resolve.fds. sd_resolve_new() fills the REQUEST_* pair with
 * one socketpair() call and the RESPONSE_* pair with another. */
enum {
        REQUEST_RECV_FD,
        REQUEST_SEND_FD,
        RESPONSE_RECV_FD,
        RESPONSE_SEND_FD,
        _FD_MAX
};
76
/* State of one asynchronous resolver instance. */
struct sd_resolve {
        unsigned n_ref;

        /* Set by resolve_free() and when a RESPONSE_DIED packet arrives; polled by the workers */
        bool dead:1;
        /* getpid() at creation time, compared again by resolve_pid_changed() */
        pid_t original_pid;

        int fds[_FD_MAX];

        pthread_t workers[WORKERS_MAX];
        /* Number of entries of workers[] that refer to successfully created threads */
        unsigned n_valid_workers;

        /* Next id / slot candidate used by alloc_query() */
        unsigned current_id, current_index;
        /* Queries indexed by id % QUERIES_MAX, see lookup_query() */
        sd_resolve_query* query_array[QUERIES_MAX];
        unsigned n_queries, n_done;

        sd_event_source *event_source;
        sd_event *event;

        /* Query whose completion is currently being dispatched; non-NULL blocks sd_resolve_process() */
        sd_resolve_query *current;

        /* Set by sd_resolve_default(); the pointed-to slot is reset to NULL in resolve_free() */
        sd_resolve **default_resolve_ptr;
        /* gettid() of the thread that created the default resolver, 0 otherwise */
        pid_t tid;

        LIST_HEAD(sd_resolve_query, queries);
};
102
/* One outstanding or completed lookup. */
struct sd_resolve_query {
        unsigned n_ref;

        sd_resolve *resolve;

        QueryType type:4;
        /* Set by complete_query() once the response has been processed */
        bool done:1;
        /* Floating queries are released by the resolver itself after completion */
        bool floating:1;
        /* Matches RHeader.id of the response; slot is id % QUERIES_MAX */
        unsigned id;

        /* Result codes copied from the response packet in handle_response() */
        int ret;
        int _errno;
        int _h_errno;
        /* Result payloads, filled in by handle_response() depending on the response type */
        struct addrinfo *addrinfo;
        char *serv, *host;
        unsigned char *answer;

        /* NOTE(review): presumably the member in use is selected by `type` — confirm in the *_done() helpers */
        union {
                sd_resolve_getaddrinfo_handler_t getaddrinfo_handler;
                sd_resolve_getnameinfo_handler_t getnameinfo_handler;
                sd_resolve_res_handler_t res_handler;
        };

        void *userdata;

        LIST_FIELDS(sd_resolve_query, queries);
};
130
/* Common header at the start of every request and response packet. */
typedef struct RHeader {
        QueryType type;
        unsigned id;
        /* Total packet size in bytes, header included; checked against the received size */
        size_t length;
} RHeader;
136
/* getaddrinfo() request; followed by node_len bytes of node and service_len bytes of service. */
typedef struct AddrInfoRequest {
        struct RHeader header;
        /* When false the worker passes NULL hints to getaddrinfo() */
        bool hints_valid;
        int ai_flags;
        int ai_family;
        int ai_socktype;
        int ai_protocol;
        /* A length of 0 means the respective string is passed as NULL */
        size_t node_len, service_len;
} AddrInfoRequest;
146
/* getaddrinfo() response: return value plus the errno/h_errno values sampled by the worker. */
typedef struct AddrInfoResponse {
        struct RHeader header;
        int ret;
        int _errno;
        int _h_errno;
        /* followed by addrinfo_serialization[] */
} AddrInfoResponse;
154
/* Flattened form of one struct addrinfo list entry, see serialize_addrinfo(). */
typedef struct AddrInfoSerialization {
        int ai_flags;
        int ai_family;
        int ai_socktype;
        int ai_protocol;
        size_t ai_addrlen;
        /* Includes the trailing NUL; 0 if there is no canonical name */
        size_t canonname_len;
        /* Followed by ai_addr and ai_canonname with variable lengths */
} AddrInfoSerialization;
164
/* getnameinfo() request; followed by sockaddr_len bytes of socket address. */
typedef struct NameInfoRequest {
        struct RHeader header;
        int flags;
        socklen_t sockaddr_len;
        /* Select whether the host and/or the service name shall be resolved */
        bool gethost:1, getserv:1;
} NameInfoRequest;
171
/* getnameinfo() response; followed by hostlen bytes of host and servlen bytes of service. */
typedef struct NameInfoResponse {
        struct RHeader header;
        /* Both lengths include the trailing NUL; 0 means the string is absent */
        size_t hostlen, servlen;
        int ret;
        int _errno;
        int _h_errno;
} NameInfoResponse;
179
/* res_query()/res_search() request; followed by dname_len bytes of domain name. */
typedef struct ResRequest {
        struct RHeader header;
        int class;
        int type;
        size_t dname_len;
} ResRequest;
186
/* res_query()/res_search() response; followed by `ret` bytes of answer when ret > 0. */
typedef struct ResResponse {
        struct RHeader header;
        int ret;
        int _errno;
        int _h_errno;
} ResResponse;
193
/* Overlay of all packet layouts; every member starts with an RHeader, so
 * rheader can be inspected before the concrete type is known. */
typedef union Packet {
        RHeader rheader;
        AddrInfoRequest addrinfo_request;
        AddrInfoResponse addrinfo_response;
        NameInfoRequest nameinfo_request;
        NameInfoResponse nameinfo_response;
        ResRequest res_request;
        ResResponse res_response;
} Packet;
203
/* Per-type completion helpers, dispatched from complete_query(). */
static int getaddrinfo_done(sd_resolve_query* q);
static int getnameinfo_done(sd_resolve_query *q);
static int res_query_done(sd_resolve_query* q);

static void resolve_query_disconnect(sd_resolve_query *q);

/* Takes an extra reference on the resolver for the enclosing scope.
 * NOTE(review): presumably _cleanup_resolve_unref_ drops it again on scope exit — confirm. */
#define RESOLVE_DONT_DESTROY(resolve) \
        _cleanup_resolve_unref_ _unused_ sd_resolve *_dont_destroy_##resolve = sd_resolve_ref(resolve)
212
213 static int send_died(int out_fd) {
214
215         RHeader rh = {
216                 .type = RESPONSE_DIED,
217                 .length = sizeof(RHeader),
218         };
219
220         assert(out_fd >= 0);
221
222         if (send(out_fd, &rh, rh.length, MSG_NOSIGNAL) < 0)
223                 return -errno;
224
225         return 0;
226 }
227
228 static void *serialize_addrinfo(void *p, const struct addrinfo *ai, size_t *length, size_t maxlength) {
229         AddrInfoSerialization s;
230         size_t cnl, l;
231
232         assert(p);
233         assert(ai);
234         assert(length);
235         assert(*length <= maxlength);
236
237         cnl = ai->ai_canonname ? strlen(ai->ai_canonname)+1 : 0;
238         l = sizeof(AddrInfoSerialization) + ai->ai_addrlen + cnl;
239
240         if (*length + l > maxlength)
241                 return NULL;
242
243         s.ai_flags = ai->ai_flags;
244         s.ai_family = ai->ai_family;
245         s.ai_socktype = ai->ai_socktype;
246         s.ai_protocol = ai->ai_protocol;
247         s.ai_addrlen = ai->ai_addrlen;
248         s.canonname_len = cnl;
249
250         memcpy((uint8_t*) p, &s, sizeof(AddrInfoSerialization));
251         memcpy((uint8_t*) p + sizeof(AddrInfoSerialization), ai->ai_addr, ai->ai_addrlen);
252
253         if (ai->ai_canonname)
254                 memcpy((char*) p + sizeof(AddrInfoSerialization) + ai->ai_addrlen, ai->ai_canonname, cnl);
255
256         *length += l;
257         return (uint8_t*) p + l;
258 }
259
260 static int send_addrinfo_reply(
261                 int out_fd,
262                 unsigned id,
263                 int ret,
264                 struct addrinfo *ai,
265                 int _errno,
266                 int _h_errno) {
267
268         AddrInfoResponse resp = {
269                 .header.type = RESPONSE_ADDRINFO,
270                 .header.id = id,
271                 .header.length = sizeof(AddrInfoResponse),
272                 .ret = ret,
273                 ._errno = _errno,
274                 ._h_errno = _h_errno,
275         };
276
277         struct msghdr mh = {};
278         struct iovec iov[2];
279         union {
280                 AddrInfoSerialization ais;
281                 uint8_t space[BUFSIZE];
282         } buffer;
283
284         assert(out_fd >= 0);
285
286         if (ret == 0 && ai) {
287                 void *p = &buffer;
288                 struct addrinfo *k;
289
290                 for (k = ai; k; k = k->ai_next) {
291                         p = serialize_addrinfo(p, k, &resp.header.length, (uint8_t*) &buffer + BUFSIZE - (uint8_t*) p);
292                         if (!p) {
293                                 freeaddrinfo(ai);
294                                 return -ENOBUFS;
295                         }
296                 }
297         }
298
299         if (ai)
300                 freeaddrinfo(ai);
301
302         iov[0] = (struct iovec) { .iov_base = &resp, .iov_len = sizeof(AddrInfoResponse) };
303         iov[1] = (struct iovec) { .iov_base = &buffer, .iov_len = resp.header.length - sizeof(AddrInfoResponse) };
304
305         mh.msg_iov = iov;
306         mh.msg_iovlen = ELEMENTSOF(iov);
307
308         if (sendmsg(out_fd, &mh, MSG_NOSIGNAL) < 0)
309                 return -errno;
310
311         return 0;
312 }
313
314 static int send_nameinfo_reply(
315                 int out_fd,
316                 unsigned id,
317                 int ret,
318                 const char *host,
319                 const char *serv,
320                 int _errno,
321                 int _h_errno) {
322
323         NameInfoResponse resp = {
324                 .header.type = RESPONSE_NAMEINFO,
325                 .header.id = id,
326                 .ret = ret,
327                 ._errno = _errno,
328                 ._h_errno = _h_errno,
329         };
330
331         struct msghdr mh = {};
332         struct iovec iov[3];
333         size_t hl, sl;
334
335         assert(out_fd >= 0);
336
337         sl = serv ? strlen(serv)+1 : 0;
338         hl = host ? strlen(host)+1 : 0;
339
340         resp.header.length = sizeof(NameInfoResponse) + hl + sl;
341         resp.hostlen = hl;
342         resp.servlen = sl;
343
344         iov[0] = (struct iovec) { .iov_base = &resp, .iov_len = sizeof(NameInfoResponse) };
345         iov[1] = (struct iovec) { .iov_base = (void*) host, .iov_len = hl };
346         iov[2] = (struct iovec) { .iov_base = (void*) serv, .iov_len = sl };
347
348         mh.msg_iov = iov;
349         mh.msg_iovlen = ELEMENTSOF(iov);
350
351         if (sendmsg(out_fd, &mh, MSG_NOSIGNAL) < 0)
352                 return -errno;
353
354         return 0;
355 }
356
357 static int send_res_reply(int out_fd, unsigned id, const unsigned char *answer, int ret, int _errno, int _h_errno) {
358
359         ResResponse resp = {
360                 .header.type = RESPONSE_RES,
361                 .header.id = id,
362                 .ret = ret,
363                 ._errno = _errno,
364                 ._h_errno = _h_errno,
365         };
366
367         struct msghdr mh = {};
368         struct iovec iov[2];
369         size_t l;
370
371         assert(out_fd >= 0);
372
373         l = ret > 0 ? (size_t) ret : 0;
374
375         resp.header.length = sizeof(ResResponse) + l;
376
377         iov[0] = (struct iovec) { .iov_base = &resp, .iov_len = sizeof(ResResponse) };
378         iov[1] = (struct iovec) { .iov_base = (void*) answer, .iov_len = l };
379
380         mh.msg_iov = iov;
381         mh.msg_iovlen = ELEMENTSOF(iov);
382
383         if (sendmsg(out_fd, &mh, MSG_NOSIGNAL) < 0)
384                 return -errno;
385
386         return 0;
387 }
388
/* Runs getaddrinfo() for one REQUEST_ADDRINFO packet and posts the reply. */
static int handle_addrinfo_request(int out_fd, const AddrInfoRequest *ai_req, size_t length) {
        struct addrinfo hints = {}, *result = NULL;
        const char *payload, *node = NULL, *service = NULL;
        int ret, saved_errno, saved_h_errno;

        assert(length >= sizeof(AddrInfoRequest));
        assert(length == sizeof(AddrInfoRequest) + ai_req->node_len + ai_req->service_len);

        hints.ai_flags = ai_req->ai_flags;
        hints.ai_family = ai_req->ai_family;
        hints.ai_socktype = ai_req->ai_socktype;
        hints.ai_protocol = ai_req->ai_protocol;

        /* The two strings follow the fixed part back to back, node first */
        payload = (const char*) ai_req + sizeof(AddrInfoRequest);
        if (ai_req->node_len)
                node = payload;
        if (ai_req->service_len)
                service = payload + ai_req->node_len;

        ret = getaddrinfo(node, service, ai_req->hints_valid ? &hints : NULL, &result);
        saved_errno = errno;
        saved_h_errno = h_errno;

        /* send_addrinfo_reply() frees result */
        return send_addrinfo_reply(out_fd, ai_req->header.id, ret, result, saved_errno, saved_h_errno);
}

/* Runs getnameinfo() for one REQUEST_NAMEINFO packet and posts the reply. */
static int handle_nameinfo_request(int out_fd, const NameInfoRequest *ni_req, size_t length) {
        char hostbuf[NI_MAXHOST], servbuf[NI_MAXSERV];
        union sockaddr_union sa;
        const char *host = NULL, *serv = NULL;
        int ret, saved_errno, saved_h_errno;

        assert(length >= sizeof(NameInfoRequest));
        assert(length == sizeof(NameInfoRequest) + ni_req->sockaddr_len);
        assert(sizeof(sa) >= ni_req->sockaddr_len);

        memcpy(&sa, (const uint8_t *) ni_req + sizeof(NameInfoRequest), ni_req->sockaddr_len);

        ret = getnameinfo(&sa.sa, ni_req->sockaddr_len,
                        ni_req->gethost ? hostbuf : NULL, ni_req->gethost ? sizeof(hostbuf) : 0,
                        ni_req->getserv ? servbuf : NULL, ni_req->getserv ? sizeof(servbuf) : 0,
                        ni_req->flags);
        saved_errno = errno;
        saved_h_errno = h_errno;

        /* Only hand back the strings that were asked for and actually resolved */
        if (ret == 0) {
                if (ni_req->gethost)
                        host = hostbuf;
                if (ni_req->getserv)
                        serv = servbuf;
        }

        return send_nameinfo_reply(out_fd, ni_req->header.id, ret, host, serv, saved_errno, saved_h_errno);
}

/* Runs res_query() or res_search(), depending on the header type, for one
 * REQUEST_RES_* packet and posts the reply. */
static int handle_res_request(int out_fd, const ResRequest *res_req, size_t length) {
        union {
                HEADER header;
                uint8_t space[BUFSIZE];
        } answer;
        const char *dname;
        int ret, saved_errno, saved_h_errno;

        assert(length >= sizeof(ResRequest));
        assert(length == sizeof(ResRequest) + res_req->dname_len);

        dname = (const char *) res_req + sizeof(ResRequest);

        if (res_req->header.type == REQUEST_RES_QUERY)
                ret = res_query(dname, res_req->class, res_req->type, (unsigned char *) &answer, BUFSIZE);
        else
                ret = res_search(dname, res_req->class, res_req->type, (unsigned char *) &answer, BUFSIZE);
        saved_errno = errno;
        saved_h_errno = h_errno;

        return send_res_reply(out_fd, res_req->header.id, (unsigned char *) &answer, ret, saved_errno, saved_h_errno);
}

/* Executes one request packet of `length` bytes received by a worker and
 * writes the matching response to out_fd. Returns the result of the send
 * helper (0 or a negative errno), or -ECONNRESET for REQUEST_TERMINATE so that
 * the caller leaves its loop. */
static int handle_request(int out_fd, const Packet *packet, size_t length) {
        const RHeader *req;

        assert(out_fd >= 0);
        assert(packet);

        req = &packet->rheader;

        assert(length >= sizeof(RHeader));
        assert(length == req->length);

        switch (req->type) {

        case REQUEST_ADDRINFO:
                return handle_addrinfo_request(out_fd, &packet->addrinfo_request, length);

        case REQUEST_NAMEINFO:
                return handle_nameinfo_request(out_fd, &packet->nameinfo_request, length);

        case REQUEST_RES_QUERY:
        case REQUEST_RES_SEARCH:
                return handle_res_request(out_fd, &packet->res_request, length);

        case REQUEST_TERMINATE:
                 /* Quit */
                 return -ECONNRESET;

        default:
                assert_not_reached("Unknown request");
        }

        return 0;
}
484
485 static void* thread_worker(void *p) {
486         sd_resolve *resolve = p;
487         sigset_t fullset;
488
489         /* No signals in this thread please */
490         assert_se(sigfillset(&fullset) == 0);
491         assert_se(pthread_sigmask(SIG_BLOCK, &fullset, NULL) == 0);
492
493         /* Assign a pretty name to this thread */
494         prctl(PR_SET_NAME, (unsigned long) "sd-resolve");
495
496         while (!resolve->dead) {
497                 union {
498                         Packet packet;
499                         uint8_t space[BUFSIZE];
500                 } buf;
501                 ssize_t length;
502
503                 length = recv(resolve->fds[REQUEST_RECV_FD], &buf, sizeof(buf), 0);
504                 if (length < 0) {
505                         if (errno == EINTR)
506                                 continue;
507
508                         break;
509                 }
510                 if (length == 0)
511                         break;
512
513                 if (resolve->dead)
514                         break;
515
516                 if (handle_request(resolve->fds[RESPONSE_SEND_FD], &buf.packet, (size_t) length) < 0)
517                         break;
518         }
519
520         send_died(resolve->fds[RESPONSE_SEND_FD]);
521
522         return NULL;
523 }
524
525 static int start_threads(sd_resolve *resolve, unsigned extra) {
526         unsigned n;
527         int r;
528
529         n = resolve->n_queries + extra - resolve->n_done;
530         n = CLAMP(n, WORKERS_MIN, WORKERS_MAX);
531
532         while (resolve->n_valid_workers < n) {
533
534                 r = pthread_create(&resolve->workers[resolve->n_valid_workers], NULL, thread_worker, resolve);
535                 if (r != 0)
536                         return -r;
537
538                 resolve->n_valid_workers ++;
539         }
540
541         return 0;
542 }
543
544 static bool resolve_pid_changed(sd_resolve *r) {
545         assert(r);
546
547         /* We don't support people creating a resolver and keeping it
548          * around after fork(). Let's complain. */
549
550         return r->original_pid != getpid();
551 }
552
553 _public_ int sd_resolve_new(sd_resolve **ret) {
554         sd_resolve *resolve = NULL;
555         int i, r;
556
557         assert_return(ret, -EINVAL);
558
559         resolve = new0(sd_resolve, 1);
560         if (!resolve)
561                 return -ENOMEM;
562
563         resolve->n_ref = 1;
564         resolve->original_pid = getpid();
565
566         for (i = 0; i < _FD_MAX; i++)
567                 resolve->fds[i] = -1;
568
569         r = socketpair(PF_UNIX, SOCK_DGRAM|SOCK_CLOEXEC, 0, resolve->fds + REQUEST_RECV_FD);
570         if (r < 0) {
571                 r = -errno;
572                 goto fail;
573         }
574
575         r = socketpair(PF_UNIX, SOCK_DGRAM|SOCK_CLOEXEC, 0, resolve->fds + RESPONSE_RECV_FD);
576         if (r < 0) {
577                 r = -errno;
578                 goto fail;
579         }
580
581         fd_inc_sndbuf(resolve->fds[REQUEST_SEND_FD], QUERIES_MAX * BUFSIZE);
582         fd_inc_rcvbuf(resolve->fds[REQUEST_RECV_FD], QUERIES_MAX * BUFSIZE);
583         fd_inc_sndbuf(resolve->fds[RESPONSE_SEND_FD], QUERIES_MAX * BUFSIZE);
584         fd_inc_rcvbuf(resolve->fds[RESPONSE_RECV_FD], QUERIES_MAX * BUFSIZE);
585
586         fd_nonblock(resolve->fds[RESPONSE_RECV_FD], true);
587
588         *ret = resolve;
589         return 0;
590
591 fail:
592         sd_resolve_unref(resolve);
593         return r;
594 }
595
596 _public_ int sd_resolve_default(sd_resolve **ret) {
597
598         static thread_local sd_resolve *default_resolve = NULL;
599         sd_resolve *e = NULL;
600         int r;
601
602         if (!ret)
603                 return !!default_resolve;
604
605         if (default_resolve) {
606                 *ret = sd_resolve_ref(default_resolve);
607                 return 0;
608         }
609
610         r = sd_resolve_new(&e);
611         if (r < 0)
612                 return r;
613
614         e->default_resolve_ptr = &default_resolve;
615         e->tid = gettid();
616         default_resolve = e;
617
618         *ret = e;
619         return 1;
620 }
621
622 _public_ int sd_resolve_get_tid(sd_resolve *resolve, pid_t *tid) {
623         assert_return(resolve, -EINVAL);
624         assert_return(tid, -EINVAL);
625         assert_return(!resolve_pid_changed(resolve), -ECHILD);
626
627         if (resolve->tid != 0) {
628                 *tid = resolve->tid;
629                 return 0;
630         }
631
632         if (resolve->event)
633                 return sd_event_get_tid(resolve->event, tid);
634
635         return -ENXIO;
636 }
637
638 static void resolve_free(sd_resolve *resolve) {
639         PROTECT_ERRNO;
640         sd_resolve_query *q;
641         unsigned i;
642
643         assert(resolve);
644
645         while ((q = resolve->queries)) {
646                 assert(q->floating);
647                 resolve_query_disconnect(q);
648                 sd_resolve_query_unref(q);
649         }
650
651         if (resolve->default_resolve_ptr)
652                 *(resolve->default_resolve_ptr) = NULL;
653
654         resolve->dead = true;
655
656         sd_resolve_detach_event(resolve);
657
658         if (resolve->fds[REQUEST_SEND_FD] >= 0) {
659
660                 RHeader req = {
661                         .type = REQUEST_TERMINATE,
662                         .length = sizeof(req)
663                 };
664
665                 /* Send one termination packet for each worker */
666                 for (i = 0; i < resolve->n_valid_workers; i++)
667                         send(resolve->fds[REQUEST_SEND_FD], &req, req.length, MSG_NOSIGNAL);
668         }
669
670         /* Now terminate them and wait until they are gone. */
671         for (i = 0; i < resolve->n_valid_workers; i++) {
672                 for (;;) {
673                         if (pthread_join(resolve->workers[i], NULL) != EINTR)
674                                 break;
675                 }
676         }
677
678         /* Close all communication channels */
679         for (i = 0; i < _FD_MAX; i++)
680                 safe_close(resolve->fds[i]);
681
682         free(resolve);
683 }
684
685 _public_ sd_resolve* sd_resolve_ref(sd_resolve *resolve) {
686         assert_return(resolve, NULL);
687
688         assert(resolve->n_ref >= 1);
689         resolve->n_ref++;
690
691         return resolve;
692 }
693
694 _public_ sd_resolve* sd_resolve_unref(sd_resolve *resolve) {
695
696         if (!resolve)
697                 return NULL;
698
699         assert(resolve->n_ref >= 1);
700         resolve->n_ref--;
701
702         if (resolve->n_ref <= 0)
703                 resolve_free(resolve);
704
705         return NULL;
706 }
707
708 _public_ int sd_resolve_get_fd(sd_resolve *resolve) {
709         assert_return(resolve, -EINVAL);
710         assert_return(!resolve_pid_changed(resolve), -ECHILD);
711
712         return resolve->fds[RESPONSE_RECV_FD];
713 }
714
715 _public_ int sd_resolve_get_events(sd_resolve *resolve) {
716         assert_return(resolve, -EINVAL);
717         assert_return(!resolve_pid_changed(resolve), -ECHILD);
718
719         return resolve->n_queries > resolve->n_done ? POLLIN : 0;
720 }
721
722 _public_ int sd_resolve_get_timeout(sd_resolve *resolve, uint64_t *usec) {
723         assert_return(resolve, -EINVAL);
724         assert_return(usec, -EINVAL);
725         assert_return(!resolve_pid_changed(resolve), -ECHILD);
726
727         *usec = (uint64_t) -1;
728         return 0;
729 }
730
731 static sd_resolve_query *lookup_query(sd_resolve *resolve, unsigned id) {
732         sd_resolve_query *q;
733
734         assert(resolve);
735
736         q = resolve->query_array[id % QUERIES_MAX];
737         if (q)
738                 if (q->id == id)
739                         return q;
740
741         return NULL;
742 }
743
744 static int complete_query(sd_resolve *resolve, sd_resolve_query *q) {
745         int r;
746
747         assert(q);
748         assert(!q->done);
749         assert(q->resolve == resolve);
750
751         q->done = true;
752         resolve->n_done ++;
753
754         resolve->current = sd_resolve_query_ref(q);
755
756         switch (q->type) {
757
758         case REQUEST_ADDRINFO:
759                 r = getaddrinfo_done(q);
760                 break;
761
762         case REQUEST_NAMEINFO:
763                 r = getnameinfo_done(q);
764                 break;
765
766         case REQUEST_RES_QUERY:
767         case REQUEST_RES_SEARCH:
768                 r = res_query_done(q);
769                 break;
770
771         default:
772                 assert_not_reached("Cannot complete unknown query type");
773         }
774
775         resolve->current = sd_resolve_query_unref(q);
776
777         if (q->floating) {
778                 resolve_query_disconnect(q);
779                 sd_resolve_query_unref(q);
780         }
781
782         return r;
783 }
784
785 static int unserialize_addrinfo(const void **p, size_t *length, struct addrinfo **ret_ai) {
786         AddrInfoSerialization s;
787         size_t l;
788         struct addrinfo *ai;
789
790         assert(p);
791         assert(*p);
792         assert(ret_ai);
793         assert(length);
794
795         if (*length < sizeof(AddrInfoSerialization))
796                 return -EBADMSG;
797
798         memcpy(&s, *p, sizeof(s));
799
800         l = sizeof(AddrInfoSerialization) + s.ai_addrlen + s.canonname_len;
801         if (*length < l)
802                 return -EBADMSG;
803
804         ai = new0(struct addrinfo, 1);
805         if (!ai)
806                 return -ENOMEM;
807
808         ai->ai_flags = s.ai_flags;
809         ai->ai_family = s.ai_family;
810         ai->ai_socktype = s.ai_socktype;
811         ai->ai_protocol = s.ai_protocol;
812         ai->ai_addrlen = s.ai_addrlen;
813
814         if (s.ai_addrlen > 0) {
815                 ai->ai_addr = memdup((const uint8_t*) *p + sizeof(AddrInfoSerialization), s.ai_addrlen);
816                 if (!ai->ai_addr) {
817                         free(ai);
818                         return -ENOMEM;
819                 }
820         }
821
822         if (s.canonname_len > 0) {
823                 ai->ai_canonname = memdup((const uint8_t*) *p + sizeof(AddrInfoSerialization) + s.ai_addrlen, s.canonname_len);
824                 if (!ai->ai_canonname) {
825                         free(ai->ai_addr);
826                         free(ai);
827                         return -ENOMEM;
828                 }
829         }
830
831         *length -= l;
832         *ret_ai = ai;
833         *p = ((const uint8_t*) *p) + l;
834
835         return 0;
836 }
837
838 static int handle_response(sd_resolve *resolve, const Packet *packet, size_t length) {
839         const RHeader *resp;
840         sd_resolve_query *q;
841         int r;
842
843         assert(resolve);
844
845         resp = &packet->rheader;
846         assert(resp);
847         assert(length >= sizeof(RHeader));
848         assert(length == resp->length);
849
850         if (resp->type == RESPONSE_DIED) {
851                 resolve->dead = true;
852                 return 0;
853         }
854
855         q = lookup_query(resolve, resp->id);
856         if (!q)
857                 return 0;
858
859         switch (resp->type) {
860
861         case RESPONSE_ADDRINFO: {
862                 const AddrInfoResponse *ai_resp = &packet->addrinfo_response;
863                 const void *p;
864                 size_t l;
865                 struct addrinfo *prev = NULL;
866
867                 assert(length >= sizeof(AddrInfoResponse));
868                 assert(q->type == REQUEST_ADDRINFO);
869
870                 q->ret = ai_resp->ret;
871                 q->_errno = ai_resp->_errno;
872                 q->_h_errno = ai_resp->_h_errno;
873
874                 l = length - sizeof(AddrInfoResponse);
875                 p = (const uint8_t*) resp + sizeof(AddrInfoResponse);
876
877                 while (l > 0 && p) {
878                         struct addrinfo *ai = NULL;
879
880                         r = unserialize_addrinfo(&p, &l, &ai);
881                         if (r < 0) {
882                                 q->ret = EAI_SYSTEM;
883                                 q->_errno = -r;
884                                 q->_h_errno = 0;
885                                 freeaddrinfo(q->addrinfo);
886                                 q->addrinfo = NULL;
887                                 break;
888                         }
889
890                         if (prev)
891                                 prev->ai_next = ai;
892                         else
893                                 q->addrinfo = ai;
894
895                         prev = ai;
896                 }
897
898                 return complete_query(resolve, q);
899         }
900
901         case RESPONSE_NAMEINFO: {
902                 const NameInfoResponse *ni_resp = &packet->nameinfo_response;
903
904                 assert(length >= sizeof(NameInfoResponse));
905                 assert(q->type == REQUEST_NAMEINFO);
906
907                 q->ret = ni_resp->ret;
908                 q->_errno = ni_resp->_errno;
909                 q->_h_errno = ni_resp->_h_errno;
910
911                 if (ni_resp->hostlen > 0) {
912                         q->host = strndup((const char*) ni_resp + sizeof(NameInfoResponse), ni_resp->hostlen-1);
913                         if (!q->host) {
914                                 q->ret = EAI_MEMORY;
915                                 q->_errno = ENOMEM;
916                                 q->_h_errno = 0;
917                         }
918                 }
919
920                 if (ni_resp->servlen > 0) {
921                         q->serv = strndup((const char*) ni_resp + sizeof(NameInfoResponse) + ni_resp->hostlen, ni_resp->servlen-1);
922                         if (!q->serv) {
923                                 q->ret = EAI_MEMORY;
924                                 q->_errno = ENOMEM;
925                                 q->_h_errno = 0;
926                         }
927                 }
928
929                 return complete_query(resolve, q);
930         }
931
932         case RESPONSE_RES: {
933                 const ResResponse *res_resp = &packet->res_response;
934
935                 assert(length >= sizeof(ResResponse));
936                 assert(q->type == REQUEST_RES_QUERY || q->type == REQUEST_RES_SEARCH);
937
938                 q->ret = res_resp->ret;
939                 q->_errno = res_resp->_errno;
940                 q->_h_errno = res_resp->_h_errno;
941
942                 if (res_resp->ret >= 0)  {
943                         q->answer = memdup((const char *)resp + sizeof(ResResponse), res_resp->ret);
944                         if (!q->answer) {
945                                 q->ret = -1;
946                                 q->_errno = ENOMEM;
947                                 q->_h_errno = 0;
948                         }
949                 }
950
951                 return complete_query(resolve, q);
952         }
953
954         default:
955                 return 0;
956         }
957 }
958
959 _public_ int sd_resolve_process(sd_resolve *resolve) {
960         RESOLVE_DONT_DESTROY(resolve);
961
962         union {
963                 Packet packet;
964                 uint8_t space[BUFSIZE];
965         } buf;
966         ssize_t l;
967         int r;
968
969         assert_return(resolve, -EINVAL);
970         assert_return(!resolve_pid_changed(resolve), -ECHILD);
971
972         /* We don't allow recursively invoking sd_resolve_process(). */
973         assert_return(!resolve->current, -EBUSY);
974
975         l = recv(resolve->fds[RESPONSE_RECV_FD], &buf, sizeof(buf), 0);
976         if (l < 0) {
977                 if (errno == EAGAIN)
978                         return 0;
979
980                 return -errno;
981         }
982         if (l == 0)
983                 return -ECONNREFUSED;
984
985         r = handle_response(resolve, &buf.packet, (size_t) l);
986         if (r < 0)
987                 return r;
988
989         return 1;
990 }
991
992 _public_ int sd_resolve_wait(sd_resolve *resolve, uint64_t timeout_usec) {
993         int r;
994
995         assert_return(resolve, -EINVAL);
996         assert_return(!resolve_pid_changed(resolve), -ECHILD);
997
998         if (resolve->n_done >= resolve->n_queries)
999                 return 0;
1000
1001         do {
1002                 r = fd_wait_for_event(resolve->fds[RESPONSE_RECV_FD], POLLIN, timeout_usec);
1003         } while (r == -EINTR);
1004
1005         if (r < 0)
1006                 return r;
1007
1008         return sd_resolve_process(resolve);
1009 }
1010
1011 static int alloc_query(sd_resolve *resolve, bool floating, sd_resolve_query **_q) {
1012         sd_resolve_query *q;
1013         int r;
1014
1015         assert(resolve);
1016         assert(_q);
1017
1018         if (resolve->n_queries >= QUERIES_MAX)
1019                 return -ENOBUFS;
1020
1021         r = start_threads(resolve, 1);
1022         if (r < 0)
1023                 return r;
1024
1025         while (resolve->query_array[resolve->current_index]) {
1026                 resolve->current_index++;
1027                 resolve->current_id++;
1028
1029                 resolve->current_index %= QUERIES_MAX;
1030         }
1031
1032         q = resolve->query_array[resolve->current_index] = new0(sd_resolve_query, 1);
1033         if (!q)
1034                 return -ENOMEM;
1035
1036         q->n_ref = 1;
1037         q->resolve = resolve;
1038         q->floating = floating;
1039         q->id = resolve->current_id;
1040
1041         if (!floating)
1042                 sd_resolve_ref(resolve);
1043
1044         LIST_PREPEND(queries, resolve->queries, q);
1045         resolve->n_queries++;
1046
1047         *_q = q;
1048         return 0;
1049 }
1050
/*
 * Allocates a query of type REQUEST_ADDRINFO and sends the matching request
 * (header, optional node string, optional service string) over the request
 * socket.
 *
 * At least one of node/service must be non-NULL and callback must be set.
 * When _q is NULL the query is allocated as floating; otherwise the new query
 * is stored in *_q on success. If hints is given, its ai_flags, ai_family,
 * ai_socktype and ai_protocol fields are copied into the request.
 *
 * Returns 0 on success or a negative errno-style code on failure.
 */
_public_ int sd_resolve_getaddrinfo(
                sd_resolve *resolve,
                sd_resolve_query **_q,
                const char *node, const char *service,
                const struct addrinfo *hints,
                sd_resolve_getaddrinfo_handler_t callback, void *userdata) {

        AddrInfoRequest req = {};
        struct msghdr mh = {};
        struct iovec iov[3];
        sd_resolve_query *q;
        int r;

        assert_return(resolve, -EINVAL);
        assert_return(node || service, -EINVAL);
        assert_return(callback, -EINVAL);
        assert_return(!resolve_pid_changed(resolve), -ECHILD);

        r = alloc_query(resolve, !_q, &q);
        if (r < 0)
                return r;

        q->type = REQUEST_ADDRINFO;
        q->getaddrinfo_handler = callback;
        q->userdata = userdata;

        /* String lengths include the trailing NUL so it is sent along. */
        req.node_len = node ? strlen(node)+1 : 0;
        req.service_len = service ? strlen(service)+1 : 0;

        req.header.id = q->id;
        req.header.type = REQUEST_ADDRINFO;
        req.header.length = sizeof(AddrInfoRequest) + req.node_len + req.service_len;

        if (hints) {
                req.hints_valid = true;
                req.ai_flags = hints->ai_flags;
                req.ai_family = hints->ai_family;
                req.ai_socktype = hints->ai_socktype;
                req.ai_protocol = hints->ai_protocol;
        }

        iov[mh.msg_iovlen++] = (struct iovec) { .iov_base = &req, .iov_len = sizeof(AddrInfoRequest) };
        if (node)
                iov[mh.msg_iovlen++] = (struct iovec) { .iov_base = (void*) node, .iov_len = req.node_len };
        if (service)
                iov[mh.msg_iovlen++] = (struct iovec) { .iov_base = (void*) service, .iov_len = req.service_len };
        mh.msg_iov = iov;

        if (sendmsg(resolve->fds[REQUEST_SEND_FD], &mh, MSG_NOSIGNAL) < 0) {
                /* Capture errno before dropping the query: the unref path
                 * calls free() and other code that may overwrite it. */
                r = -errno;
                sd_resolve_query_unref(q);
                return r;
        }

        if (_q)
                *_q = q;

        return 0;
}
1109
/*
 * Completion hook for address-info queries: restores the errno and h_errno
 * values stored in the query, then invokes the user's getaddrinfo handler
 * with the stored result and address list. Returns the handler's result.
 */
static int getaddrinfo_done(sd_resolve_query* q) {
        int handler_result;

        assert(q);
        assert(q->done);
        assert(q->getaddrinfo_handler);

        /* Make the stored error state visible to the handler. */
        errno = q->_errno;
        h_errno = q->_h_errno;

        handler_result = q->getaddrinfo_handler(q, q->ret, q->addrinfo, q->userdata);
        return handler_result;
}
1120
1121 _public_ int sd_resolve_getnameinfo(
1122                 sd_resolve *resolve,
1123                 sd_resolve_query**_q,
1124                 const struct sockaddr *sa, socklen_t salen,
1125                 int flags,
1126                 uint64_t get,
1127                 sd_resolve_getnameinfo_handler_t callback,
1128                 void *userdata) {
1129
1130         NameInfoRequest req = {};
1131         struct msghdr mh = {};
1132         struct iovec iov[2];
1133         sd_resolve_query *q;
1134         int r;
1135
1136         assert_return(resolve, -EINVAL);
1137         assert_return(sa, -EINVAL);
1138         assert_return(salen >= sizeof(struct sockaddr), -EINVAL);
1139         assert_return(salen <= sizeof(union sockaddr_union), -EINVAL);
1140         assert_return((get & ~SD_RESOLVE_GET_BOTH) == 0, -EINVAL);
1141         assert_return(callback, -EINVAL);
1142         assert_return(!resolve_pid_changed(resolve), -ECHILD);
1143
1144         r = alloc_query(resolve, !_q, &q);
1145         if (r < 0)
1146                 return r;
1147
1148         q->type = REQUEST_NAMEINFO;
1149         q->getnameinfo_handler = callback;
1150         q->userdata = userdata;
1151
1152         req.header.id = q->id;
1153         req.header.type = REQUEST_NAMEINFO;
1154         req.header.length = sizeof(NameInfoRequest) + salen;
1155
1156         req.flags = flags;
1157         req.sockaddr_len = salen;
1158         req.gethost = !!(get & SD_RESOLVE_GET_HOST);
1159         req.getserv = !!(get & SD_RESOLVE_GET_SERVICE);
1160
1161         iov[0] = (struct iovec) { .iov_base = &req, .iov_len = sizeof(NameInfoRequest) };
1162         iov[1] = (struct iovec) { .iov_base = (void*) sa, .iov_len = salen };
1163
1164         mh.msg_iov = iov;
1165         mh.msg_iovlen = 2;
1166
1167         if (sendmsg(resolve->fds[REQUEST_SEND_FD], &mh, MSG_NOSIGNAL) < 0) {
1168                 sd_resolve_query_unref(q);
1169                 return -errno;
1170         }
1171
1172         if (_q)
1173                 *_q = q;
1174
1175         return 0;
1176 }
1177
1178 static int getnameinfo_done(sd_resolve_query *q) {
1179
1180         assert(q);
1181         assert(q->done);
1182         assert(q->getnameinfo_handler);
1183
1184         errno = q->_errno;
1185         h_errno= q->_h_errno;
1186
1187         return q->getnameinfo_handler(q, q->ret, q->host, q->serv, q->userdata);
1188 }
1189
1190 static int resolve_res(
1191                 sd_resolve *resolve,
1192                 sd_resolve_query **_q,
1193                 QueryType qtype,
1194                 const char *dname,
1195                 int class, int type,
1196                 sd_resolve_res_handler_t callback, void *userdata) {
1197
1198         struct msghdr mh = {};
1199         struct iovec iov[2];
1200         ResRequest req = {};
1201         sd_resolve_query *q;
1202         int r;
1203
1204         assert_return(resolve, -EINVAL);
1205         assert_return(dname, -EINVAL);
1206         assert_return(callback, -EINVAL);
1207         assert_return(!resolve_pid_changed(resolve), -ECHILD);
1208
1209         r = alloc_query(resolve, !_q, &q);
1210         if (r < 0)
1211                 return r;
1212
1213         q->type = qtype;
1214         q->res_handler = callback;
1215         q->userdata = userdata;
1216
1217         req.dname_len = strlen(dname) + 1;
1218         req.class = class;
1219         req.type = type;
1220
1221         req.header.id = q->id;
1222         req.header.type = qtype;
1223         req.header.length = sizeof(ResRequest) + req.dname_len;
1224
1225         iov[0] = (struct iovec) { .iov_base = &req, .iov_len = sizeof(ResRequest) };
1226         iov[1] = (struct iovec) { .iov_base = (void*) dname, .iov_len = req.dname_len };
1227
1228         mh.msg_iov = iov;
1229         mh.msg_iovlen = 2;
1230
1231         if (sendmsg(resolve->fds[REQUEST_SEND_FD], &mh, MSG_NOSIGNAL) < 0) {
1232                 sd_resolve_query_unref(q);
1233                 return -errno;
1234         }
1235
1236         if (_q)
1237                 *_q = q;
1238
1239         return 0;
1240 }
1241
1242 _public_ int sd_resolve_res_query(sd_resolve *resolve, sd_resolve_query** q, const char *dname, int class, int type, sd_resolve_res_handler_t callback, void *userdata) {
1243         return resolve_res(resolve, q, REQUEST_RES_QUERY, dname, class, type, callback, userdata);
1244 }
1245
1246 _public_ int sd_resolve_res_search(sd_resolve *resolve, sd_resolve_query** q, const char *dname, int class, int type, sd_resolve_res_handler_t callback, void *userdata) {
1247         return resolve_res(resolve, q, REQUEST_RES_SEARCH, dname, class, type, callback, userdata);
1248 }
1249
1250 static int res_query_done(sd_resolve_query* q) {
1251         assert(q);
1252         assert(q->done);
1253         assert(q->res_handler);
1254
1255         errno = q->_errno;
1256         h_errno = q->_h_errno;
1257
1258         return q->res_handler(q, q->ret, q->answer, q->userdata);
1259 }
1260
1261 _public_ sd_resolve_query* sd_resolve_query_ref(sd_resolve_query *q) {
1262         assert_return(q, NULL);
1263
1264         assert(q->n_ref >= 1);
1265         q->n_ref++;
1266
1267         return q;
1268 }
1269
/*
 * Releases an addrinfo chain node by node: each node's ai_addr and
 * ai_canonname are passed to free(), followed by the node itself.
 * A NULL list is a no-op.
 */
static void resolve_freeaddrinfo(struct addrinfo *ai) {
        struct addrinfo *following;

        for (; ai; ai = following) {
                /* Remember the successor before the node is released. */
                following = ai->ai_next;

                free(ai->ai_addr);
                free(ai->ai_canonname);
                free(ai);
        }
}
1280
1281 static void resolve_query_disconnect(sd_resolve_query *q) {
1282         sd_resolve *resolve;
1283         unsigned i;
1284
1285         assert(q);
1286
1287         if (!q->resolve)
1288                 return;
1289
1290         resolve = q->resolve;
1291         assert(resolve->n_queries > 0);
1292
1293         if (q->done) {
1294                 assert(resolve->n_done > 0);
1295                 resolve->n_done--;
1296         }
1297
1298         i = q->id % QUERIES_MAX;
1299         assert(resolve->query_array[i] == q);
1300         resolve->query_array[i] = NULL;
1301         LIST_REMOVE(queries, resolve->queries, q);
1302         resolve->n_queries--;
1303
1304         q->resolve = NULL;
1305         if (!q->floating)
1306                 sd_resolve_unref(resolve);
1307 }
1308
1309 static void resolve_query_free(sd_resolve_query *q) {
1310         assert(q);
1311
1312         resolve_query_disconnect(q);
1313
1314         resolve_freeaddrinfo(q->addrinfo);
1315         free(q->host);
1316         free(q->serv);
1317         free(q->answer);
1318         free(q);
1319 }
1320
1321 _public_ sd_resolve_query* sd_resolve_query_unref(sd_resolve_query* q) {
1322         if (!q)
1323                 return NULL;
1324
1325         assert(q->n_ref >= 1);
1326         q->n_ref--;
1327
1328         if (q->n_ref <= 0)
1329                 resolve_query_free(q);
1330
1331         return NULL;
1332 }
1333
1334 _public_ int sd_resolve_query_is_done(sd_resolve_query *q) {
1335         assert_return(q, -EINVAL);
1336         assert_return(!resolve_pid_changed(q->resolve), -ECHILD);
1337
1338         return q->done;
1339 }
1340
1341 _public_ void* sd_resolve_query_set_userdata(sd_resolve_query *q, void *userdata) {
1342         void *ret;
1343
1344         assert_return(q, NULL);
1345         assert_return(!resolve_pid_changed(q->resolve), NULL);
1346
1347         ret = q->userdata;
1348         q->userdata = userdata;
1349
1350         return ret;
1351 }
1352
1353 _public_ void* sd_resolve_query_get_userdata(sd_resolve_query *q) {
1354         assert_return(q, NULL);
1355         assert_return(!resolve_pid_changed(q->resolve), NULL);
1356
1357         return q->userdata;
1358 }
1359
1360 _public_ sd_resolve *sd_resolve_query_get_resolve(sd_resolve_query *q) {
1361         assert_return(q, NULL);
1362         assert_return(!resolve_pid_changed(q->resolve), NULL);
1363
1364         return q->resolve;
1365 }
1366
1367 static int io_callback(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
1368         sd_resolve *resolve = userdata;
1369         int r;
1370
1371         assert(resolve);
1372
1373         r = sd_resolve_process(resolve);
1374         if (r < 0)
1375                 return r;
1376
1377         return 1;
1378 }
1379
1380 _public_ int sd_resolve_attach_event(sd_resolve *resolve, sd_event *event, int priority) {
1381         int r;
1382
1383         assert_return(resolve, -EINVAL);
1384         assert_return(!resolve->event, -EBUSY);
1385
1386         assert(!resolve->event_source);
1387
1388         if (event)
1389                 resolve->event = sd_event_ref(event);
1390         else {
1391                 r = sd_event_default(&resolve->event);
1392                 if (r < 0)
1393                         return r;
1394         }
1395
1396         r = sd_event_add_io(resolve->event, &resolve->event_source, resolve->fds[RESPONSE_RECV_FD], POLLIN, io_callback, resolve);
1397         if (r < 0)
1398                 goto fail;
1399
1400         r = sd_event_source_set_priority(resolve->event_source, priority);
1401         if (r < 0)
1402                 goto fail;
1403
1404         return 0;
1405
1406 fail:
1407         sd_resolve_detach_event(resolve);
1408         return r;
1409 }
1410
1411 _public_  int sd_resolve_detach_event(sd_resolve *resolve) {
1412         assert_return(resolve, -EINVAL);
1413
1414         if (!resolve->event)
1415                 return 0;
1416
1417         if (resolve->event_source) {
1418                 sd_event_source_set_enabled(resolve->event_source, SD_EVENT_OFF);
1419                 resolve->event_source = sd_event_source_unref(resolve->event_source);
1420         }
1421
1422         resolve->event = sd_event_unref(resolve->event);
1423         return 1;
1424 }
1425
1426 _public_ sd_event *sd_resolve_get_event(sd_resolve *resolve) {
1427         assert_return(resolve, NULL);
1428
1429         return resolve->event;
1430 }