chiark / gitweb /
rtnl: add call_async and call_async_cancel
[elogind.git] / src / libsystemd-rtnl / sd-rtnl.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2013 Tom Gundersen <teg@jklm.no>
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20 ***/
21
22 #include <sys/socket.h>
23 #include <poll.h>
24
25 #include "macro.h"
26 #include "util.h"
27 #include "hashmap.h"
28
29 #include "sd-rtnl.h"
30 #include "rtnl-internal.h"
31 #include "rtnl-util.h"
32
33 static int sd_rtnl_new(sd_rtnl **ret) {
34         sd_rtnl *rtnl;
35
36         assert_return(ret, -EINVAL);
37
38         rtnl = new0(sd_rtnl, 1);
39         if (!rtnl)
40                 return -ENOMEM;
41
42         rtnl->n_ref = REFCNT_INIT;
43
44         rtnl->fd = -1;
45
46         rtnl->sockaddr.nl.nl_family = AF_NETLINK;
47
48         rtnl->original_pid = getpid();
49
50         /* We guarantee that wqueue always has space for at least
51          * one entry */
52         rtnl->wqueue = new(sd_rtnl_message*, 1);
53         if (!rtnl->wqueue) {
54                 free(rtnl);
55                 return -ENOMEM;
56         }
57
58         *ret = rtnl;
59         return 0;
60 }
61
62 static bool rtnl_pid_changed(sd_rtnl *rtnl) {
63         assert(rtnl);
64
65         /* We don't support people creating an rtnl connection and
66          * keeping it around over a fork(). Let's complain. */
67
68         return rtnl->original_pid != getpid();
69 }
70
71 int sd_rtnl_open(uint32_t groups, sd_rtnl **ret) {
72         _cleanup_sd_rtnl_unref_ sd_rtnl *rtnl = NULL;
73         socklen_t addrlen;
74         int r;
75
76         r = sd_rtnl_new(&rtnl);
77         if (r < 0)
78                 return r;
79
80         rtnl->fd = socket(PF_NETLINK, SOCK_RAW|SOCK_CLOEXEC|SOCK_NONBLOCK, NETLINK_ROUTE);
81         if (rtnl->fd < 0)
82                 return -errno;
83
84         rtnl->sockaddr.nl.nl_groups = groups;
85
86         addrlen = sizeof(rtnl->sockaddr);
87
88         r = bind(rtnl->fd, &rtnl->sockaddr.sa, addrlen);
89         if (r < 0)
90                 return -errno;
91
92         r = getsockname(rtnl->fd, &rtnl->sockaddr.sa, &addrlen);
93         if (r < 0)
94                 return r;
95
96         *ret = rtnl;
97         rtnl = NULL;
98
99         return 0;
100 }
101
102 sd_rtnl *sd_rtnl_ref(sd_rtnl *rtnl) {
103         if (rtnl)
104                 assert_se(REFCNT_INC(rtnl->n_ref) >= 2);
105
106         return rtnl;
107 }
108
109 sd_rtnl *sd_rtnl_unref(sd_rtnl *rtnl) {
110
111         if (rtnl && REFCNT_DEC(rtnl->n_ref) <= 0) {
112                 unsigned i;
113
114                 for (i = 0; i < rtnl->rqueue_size; i++)
115                         sd_rtnl_message_unref(rtnl->rqueue[i]);
116                 free(rtnl->rqueue);
117
118                 for (i = 0; i < rtnl->wqueue_size; i++)
119                         sd_rtnl_message_unref(rtnl->wqueue[i]);
120                 free(rtnl->wqueue);
121
122                 hashmap_free_free(rtnl->reply_callbacks);
123                 prioq_free(rtnl->reply_callbacks_prioq);
124
125                 if (rtnl->fd >= 0)
126                         close_nointr_nofail(rtnl->fd);
127
128                 free(rtnl);
129         }
130
131         return NULL;
132 }
133
134 int sd_rtnl_send(sd_rtnl *nl,
135                  sd_rtnl_message *message,
136                  uint32_t *serial) {
137         int r;
138
139         assert_return(nl, -EINVAL);
140         assert_return(!rtnl_pid_changed(nl), -ECHILD);
141         assert_return(message, -EINVAL);
142
143         r = message_seal(nl, message);
144         if (r < 0)
145                 return r;
146
147         if (nl->wqueue_size <= 0) {
148                 /* send directly */
149                 r = socket_write_message(nl, message);
150                 if (r < 0)
151                         return r;
152                 else if (r == 0) {
153                         /* nothing was sent, so let's put it on
154                          * the queue */
155                         nl->wqueue[0] = sd_rtnl_message_ref(message);
156                         nl->wqueue_size = 1;
157                 }
158         } else {
159                 sd_rtnl_message **q;
160
161                 /* append to queue */
162                 if (nl->wqueue_size >= RTNL_WQUEUE_MAX)
163                         return -ENOBUFS;
164
165                 q = realloc(nl->wqueue, sizeof(sd_rtnl_message*) * (nl->wqueue_size + 1));
166                 if (!q)
167                         return -ENOMEM;
168
169                 nl->wqueue = q;
170                 q[nl->wqueue_size ++] = sd_rtnl_message_ref(message);
171         }
172
173         if (serial)
174                 *serial = message_get_serial(message);
175
176         return 1;
177 }
178
179 static int dispatch_rqueue(sd_rtnl *rtnl, sd_rtnl_message **message) {
180         sd_rtnl_message *z = NULL;
181         int r;
182
183         assert(rtnl);
184         assert(message);
185
186         if (rtnl->rqueue_size > 0) {
187                 /* Dispatch a queued message */
188
189                 *message = rtnl->rqueue[0];
190                 rtnl->rqueue_size --;
191                 memmove(rtnl->rqueue, rtnl->rqueue + 1, sizeof(sd_rtnl_message*) * rtnl->rqueue_size);
192
193                 return 1;
194         }
195
196         /* Try to read a new message */
197         r = socket_read_message(rtnl, &z);
198         if (r < 0)
199                 return r;
200         if (r == 0)
201                 return 0;
202
203         *message = z;
204
205         return 1;
206 }
207
208 static int dispatch_wqueue(sd_rtnl *rtnl) {
209         int r, ret = 0;
210
211         assert(rtnl);
212
213         while (rtnl->wqueue_size > 0) {
214                 r = socket_write_message(rtnl, rtnl->wqueue[0]);
215                 if (r < 0)
216                         return r;
217                 else if (r == 0)
218                         /* Didn't do anything this time */
219                         return ret;
220                 else {
221                         /* see equivalent in sd-bus.c */
222                         sd_rtnl_message_unref(rtnl->wqueue[0]);
223                         rtnl->wqueue_size --;
224                         memmove(rtnl->wqueue, rtnl->wqueue + 1, sizeof(sd_rtnl_message*) * rtnl->wqueue_size);
225
226                         ret = 1;
227                 }
228         }
229
230         return ret;
231 }
232
233 static int process_timeout(sd_rtnl *rtnl) {
234         _cleanup_sd_rtnl_message_unref_ sd_rtnl_message *m = NULL;
235         struct reply_callback *c;
236         usec_t n;
237         int r;
238
239         assert(rtnl);
240
241         c = prioq_peek(rtnl->reply_callbacks_prioq);
242         if (!c)
243                 return 0;
244
245         n = now(CLOCK_MONOTONIC);
246         if (c->timeout > n)
247                 return 0;
248
249         r = message_new_synthetic_error(-ETIMEDOUT, c->serial, &m);
250         if (r < 0)
251                 return r;
252
253         assert_se(prioq_pop(rtnl->reply_callbacks_prioq) == c);
254         hashmap_remove(rtnl->reply_callbacks, &c->serial);
255
256         r = c->callback(rtnl, m, c->userdata);
257         free(c);
258
259         return r < 0 ? r : 1;
260 }
261
262 static int process_reply(sd_rtnl *rtnl, sd_rtnl_message *m) {
263         struct reply_callback *c;
264         uint64_t serial;
265         int r;
266
267         assert(rtnl);
268         assert(m);
269
270         serial = message_get_serial(m);
271         c = hashmap_remove(rtnl->reply_callbacks, &serial);
272         if (!c)
273                 return 0;
274
275         if (c->timeout != 0)
276                 prioq_remove(rtnl->reply_callbacks_prioq, c, &c->prioq_idx);
277
278         r = c->callback(rtnl, m, c->userdata);
279         free(c);
280
281         return r;
282 }
283
284 static int process_running(sd_rtnl *rtnl, sd_rtnl_message **ret) {
285         _cleanup_sd_rtnl_message_unref_ sd_rtnl_message *m = NULL;
286         int r;
287
288         r = process_timeout(rtnl);
289         if (r != 0)
290                 goto null_message;
291
292         r = dispatch_wqueue(rtnl);
293         if (r != 0)
294                 goto null_message;
295
296         r = dispatch_rqueue(rtnl, &m);
297         if (r < 0)
298                 return r;
299         if (!m)
300                 goto null_message;
301
302         r = process_reply(rtnl, m);
303         if (r != 0)
304                 goto null_message;
305
306         if (ret) {
307                 *ret = m;
308                 m = NULL;
309
310                 return 1;
311         }
312
313         return 1;
314
315 null_message:
316         if (r >= 0 && ret)
317                 *ret = NULL;
318
319         return r;
320 }
321
322 int sd_rtnl_process(sd_rtnl *rtnl, sd_rtnl_message **ret) {
323         RTNL_DONT_DESTROY(rtnl);
324         int r;
325
326         assert_return(rtnl, -EINVAL);
327         assert_return(!rtnl_pid_changed(rtnl), -ECHILD);
328         assert_return(!rtnl->processing, -EBUSY);
329
330         rtnl->processing = true;
331         r = process_running(rtnl, ret);
332         rtnl->processing = false;
333
334         return r;
335 }
336
337 static usec_t calc_elapse(uint64_t usec) {
338         if (usec == (uint64_t) -1)
339                 return 0;
340
341         if (usec == 0)
342                 usec = RTNL_DEFAULT_TIMEOUT;
343
344         return now(CLOCK_MONOTONIC) + usec;
345 }
346
347 static int rtnl_poll(sd_rtnl *nl, uint64_t timeout_usec) {
348         struct pollfd p[1] = {};
349         struct timespec ts;
350         int r;
351
352         assert(nl);
353
354         p[0].fd = nl->fd;
355         p[0].events = POLLIN;
356
357         r = ppoll(p, 1, timeout_usec == (uint64_t) -1 ? NULL :
358                         timespec_store(&ts, timeout_usec), NULL);
359         if (r < 0)
360                 return -errno;
361
362         return r > 0 ? 1 : 0;
363 }
364
365 int sd_rtnl_wait(sd_rtnl *nl, uint64_t timeout_usec) {
366         assert_return(nl, -EINVAL);
367         assert_return(!rtnl_pid_changed(nl), -ECHILD);
368
369         if (nl->rqueue_size > 0)
370                 return 0;
371
372         return rtnl_poll(nl, timeout_usec);
373 }
374
375 static int timeout_compare(const void *a, const void *b) {
376         const struct reply_callback *x = a, *y = b;
377
378         if (x->timeout != 0 && y->timeout == 0)
379                 return -1;
380
381         if (x->timeout == 0 && y->timeout != 0)
382                 return 1;
383
384         if (x->timeout < y->timeout)
385                 return -1;
386
387         if (x->timeout > y->timeout)
388                 return 1;
389
390         return 0;
391 }
392
393 int sd_rtnl_call_async(sd_rtnl *nl,
394                        sd_rtnl_message *m,
395                        sd_rtnl_message_handler_t callback,
396                        void *userdata,
397                        uint64_t usec,
398                        uint32_t *serial) {
399         struct reply_callback *c;
400         uint32_t s;
401         int r, k;
402
403         assert_return(nl, -EINVAL);
404         assert_return(m, -EINVAL);
405         assert_return(callback, -EINVAL);
406         assert_return(!rtnl_pid_changed(nl), -ECHILD);
407
408         r = hashmap_ensure_allocated(&nl->reply_callbacks, uint64_hash_func, uint64_compare_func);
409         if (r < 0)
410                 return r;
411
412         if (usec != (uint64_t) -1) {
413                 r = prioq_ensure_allocated(&nl->reply_callbacks_prioq, timeout_compare);
414                 if (r < 0)
415                         return r;
416         }
417
418         c = new0(struct reply_callback, 1);
419         if (!c)
420                 return -ENOMEM;
421
422         c->callback = callback;
423         c->userdata = userdata;
424         c->timeout = calc_elapse(usec);
425
426         k = sd_rtnl_send(nl, m, &s);
427         if (k < 0) {
428                 free(c);
429                 return k;
430         }
431
432         c->serial = s;
433
434         r = hashmap_put(nl->reply_callbacks, &c->serial, c);
435         if (r < 0) {
436                 free(c);
437                 return r;
438         }
439
440         if (c->timeout != 0) {
441                 r = prioq_put(nl->reply_callbacks_prioq, c, &c->prioq_idx);
442                 if (r > 0) {
443                         c->timeout = 0;
444                         sd_rtnl_call_async_cancel(nl, c->serial);
445                         return r;
446                 }
447         }
448
449         if (serial)
450                 *serial = s;
451
452         return k;
453 }
454
455 int sd_rtnl_call_async_cancel(sd_rtnl *nl, uint32_t serial) {
456         struct reply_callback *c;
457         uint64_t s = serial;
458
459         assert_return(nl, -EINVAL);
460         assert_return(serial != 0, -EINVAL);
461         assert_return(!rtnl_pid_changed(nl), -ECHILD);
462
463         c = hashmap_remove(nl->reply_callbacks, &s);
464         if (!c)
465                 return 0;
466
467         if (c->timeout != 0)
468                 prioq_remove(nl->reply_callbacks_prioq, c, &c->prioq_idx);
469
470         free(c);
471         return 1;
472 }
473
474 int sd_rtnl_call(sd_rtnl *nl,
475                 sd_rtnl_message *message,
476                 uint64_t usec,
477                 sd_rtnl_message **ret) {
478         usec_t timeout;
479         uint32_t serial;
480         bool room = false;
481         int r;
482
483         assert_return(nl, -EINVAL);
484         assert_return(!rtnl_pid_changed(nl), -ECHILD);
485         assert_return(message, -EINVAL);
486
487         r = sd_rtnl_send(nl, message, &serial);
488         if (r < 0)
489                 return r;
490
491         timeout = calc_elapse(usec);
492
493         for (;;) {
494                 usec_t left;
495                 _cleanup_sd_rtnl_message_unref_ sd_rtnl_message *incoming = NULL;
496
497                 if (!room) {
498                         sd_rtnl_message **q;
499
500                         if (nl->rqueue_size >= RTNL_RQUEUE_MAX)
501                                 return -ENOBUFS;
502
503                         /* Make sure there's room for queueing this
504                          * locally, before we read the message */
505
506                         q = realloc(nl->rqueue, (nl->rqueue_size + 1) * sizeof(sd_rtnl_message*));
507                         if (!q)
508                                 return -ENOMEM;
509
510                         nl->rqueue = q;
511                         room = true;
512                 }
513
514                 r = socket_read_message(nl, &incoming);
515                 if (r < 0)
516                         return r;
517                 if (incoming) {
518                         uint32_t received_serial = message_get_serial(incoming);
519
520                         if (received_serial == serial) {
521                                 r = sd_rtnl_message_get_errno(incoming);
522                                 if (r < 0)
523                                         return r;
524
525                                 if (ret) {
526                                         *ret = incoming;
527                                         incoming = NULL;
528                                 }
529
530                                 return 1;
531                         }
532
533                         /* Room was allocated on the queue above */
534                         nl->rqueue[nl->rqueue_size ++] = incoming;
535                         incoming = NULL;
536                         room = false;
537
538                         /* Try to read more, right away */
539                         continue;
540                 }
541                 if (r != 0)
542                         continue;
543
544                 if (timeout > 0) {
545                         usec_t n;
546
547                         n = now(CLOCK_MONOTONIC);
548                         if (n >= timeout)
549                                 return -ETIMEDOUT;
550
551                         left = timeout - n;
552                 } else
553                         left = (uint64_t) -1;
554
555                 r = rtnl_poll(nl, left);
556                 if (r < 0)
557                         return r;
558
559                 r = dispatch_wqueue(nl);
560                 if (r < 0)
561                         return r;
562         }
563 }