chiark / gitweb /
pam_elogind compiling
[elogind.git] / src / libsystemd / sd-rtnl / sd-rtnl.c
index 816018a6c4c92ae3b4a85bbb6b89be62f0f59e87..40dea1252fcbf5327d97ccb5c8160094f99b3991 100644 (file)
@@ -22,6 +22,7 @@
 #include <sys/socket.h>
 #include <poll.h>
 
+#include "missing.h"
 #include "macro.h"
 #include "util.h"
 #include "hashmap.h"
@@ -54,6 +55,42 @@ static int sd_rtnl_new(sd_rtnl **ret) {
         if (!GREEDY_REALLOC(rtnl->wqueue, rtnl->wqueue_allocated, 1))
                 return -ENOMEM;
 
+        /* We guarantee that the read buffer has at least space for
+         * a message header */
+        if (!greedy_realloc((void**)&rtnl->rbuffer, &rtnl->rbuffer_allocated,
+                            sizeof(struct nlmsghdr), sizeof(uint8_t)))
+                return -ENOMEM;
+
+        /* Change notification responses have sequence 0, so we must
+         * start our request sequence numbers at 1, or we may confuse our
+         * responses with notifications from the kernel */
+        rtnl->serial = 1;
+
+        *ret = rtnl;
+        rtnl = NULL;
+
+        return 0;
+}
+
+int sd_rtnl_new_from_netlink(sd_rtnl **ret, int fd) {
+        _cleanup_rtnl_unref_ sd_rtnl *rtnl = NULL;
+        socklen_t addrlen;
+        int r;
+
+        assert_return(ret, -EINVAL);
+
+        r = sd_rtnl_new(&rtnl);
+        if (r < 0)
+                return r;
+
+        addrlen = sizeof(rtnl->sockaddr);
+
+        r = getsockname(fd, &rtnl->sockaddr.sa, &addrlen);
+        if (r < 0)
+                return -errno;
+
+        rtnl->fd = fd;
+
         *ret = rtnl;
         rtnl = NULL;
 
@@ -69,32 +106,60 @@ static bool rtnl_pid_changed(sd_rtnl *rtnl) {
         return rtnl->original_pid != getpid();
 }
 
-int sd_rtnl_open(sd_rtnl **ret, uint32_t groups) {
+static int rtnl_compute_groups_ap(uint32_t *_groups, unsigned n_groups, va_list ap) {
+        uint32_t groups = 0;
+        unsigned i;
+
+        for (i = 0; i < n_groups; i++) {
+                unsigned group;
+
+                group = va_arg(ap, unsigned);
+                assert_return(group < 32, -EINVAL);
+
+                groups |= group ? (1 << (group - 1)) : 0;
+        }
+
+        *_groups = groups;
+
+        return 0;
+}
+
+static int rtnl_open_fd_ap(sd_rtnl **ret, int fd, unsigned n_groups, va_list ap) {
         _cleanup_rtnl_unref_ sd_rtnl *rtnl = NULL;
         socklen_t addrlen;
-        int r;
+        int r, one = 1;
 
         assert_return(ret, -EINVAL);
+        assert_return(fd >= 0, -EINVAL);
 
         r = sd_rtnl_new(&rtnl);
         if (r < 0)
                 return r;
 
-        rtnl->fd = socket(PF_NETLINK, SOCK_RAW|SOCK_CLOEXEC|SOCK_NONBLOCK, NETLINK_ROUTE);
-        if (rtnl->fd < 0)
+        r = setsockopt(fd, SOL_SOCKET, SO_PASSCRED, &one, sizeof(one));
+        if (r < 0)
+                return -errno;
+
+        r = setsockopt(fd, SOL_NETLINK, NETLINK_PKTINFO, &one, sizeof(one));
+        if (r < 0)
                 return -errno;
 
-        rtnl->sockaddr.nl.nl_groups = groups;
+        r = rtnl_compute_groups_ap(&rtnl->sockaddr.nl.nl_groups, n_groups, ap);
+        if (r < 0)
+                return r;
 
         addrlen = sizeof(rtnl->sockaddr);
 
-        r = bind(rtnl->fd, &rtnl->sockaddr.sa, addrlen);
-        if (r < 0)
+        r = bind(fd, &rtnl->sockaddr.sa, addrlen);
+        /* ignore EINVAL to allow opening an already bound socket */
+        if (r < 0 && errno != EINVAL)
                 return -errno;
 
-        r = getsockname(rtnl->fd, &rtnl->sockaddr.sa, &addrlen);
+        r = getsockname(fd, &rtnl->sockaddr.sa, &addrlen);
         if (r < 0)
-                return r;
+                return -errno;
+
+        rtnl->fd = fd;
 
         *ret = rtnl;
         rtnl = NULL;
@@ -102,6 +167,41 @@ int sd_rtnl_open(sd_rtnl **ret, uint32_t groups) {
         return 0;
 }
 
+int sd_rtnl_open_fd(sd_rtnl **ret, int fd, unsigned n_groups, ...) {
+        va_list ap;
+        int r;
+
+        va_start(ap, n_groups);
+        r = rtnl_open_fd_ap(ret, fd, n_groups, ap);
+        va_end(ap);
+
+        return r;
+}
+
+int sd_rtnl_open(sd_rtnl **ret, unsigned n_groups, ...) {
+        va_list ap;
+        int fd, r;
+
+        fd = socket(PF_NETLINK, SOCK_RAW|SOCK_CLOEXEC|SOCK_NONBLOCK, NETLINK_ROUTE);
+        if (fd < 0)
+                return -errno;
+
+        va_start(ap, n_groups);
+        r = rtnl_open_fd_ap(ret, fd, n_groups, ap);
+        va_end(ap);
+
+        if (r < 0) {
+                safe_close(fd);
+                return r;
+        }
+
+        return 0;
+}
+
+int sd_rtnl_inc_rcvbuf(const sd_rtnl *const rtnl, const int size) {
+        return fd_inc_rcvbuf(rtnl->fd, size);
+}
+
 sd_rtnl *sd_rtnl_ref(sd_rtnl *rtnl) {
         assert_return(rtnl, NULL);
         assert_return(!rtnl_pid_changed(rtnl), NULL);
@@ -118,7 +218,7 @@ sd_rtnl *sd_rtnl_unref(sd_rtnl *rtnl) {
 
         assert_return(!rtnl_pid_changed(rtnl), NULL);
 
-        if (REFCNT_DEC(rtnl->n_ref) <= 0) {
+        if (REFCNT_DEC(rtnl->n_ref) == 0) {
                 struct match_callback *f;
                 unsigned i;
 
@@ -126,10 +226,16 @@ sd_rtnl *sd_rtnl_unref(sd_rtnl *rtnl) {
                         sd_rtnl_message_unref(rtnl->rqueue[i]);
                 free(rtnl->rqueue);
 
+                for (i = 0; i < rtnl->rqueue_partial_size; i++)
+                        sd_rtnl_message_unref(rtnl->rqueue_partial[i]);
+                free(rtnl->rqueue_partial);
+
                 for (i = 0; i < rtnl->wqueue_size; i++)
                         sd_rtnl_message_unref(rtnl->wqueue[i]);
                 free(rtnl->wqueue);
 
+                free(rtnl->rbuffer);
+
                 hashmap_free_free(rtnl->reply_callbacks);
                 prioq_free(rtnl->reply_callbacks_prioq);
 
@@ -156,7 +262,9 @@ static void rtnl_seal_message(sd_rtnl *rtnl, sd_rtnl_message *m) {
         assert(m);
         assert(m->hdr);
 
-        m->hdr->nlmsg_seq = rtnl->serial++;
+        /* don't use seq == 0, as that is used for broadcasts, so we
+           would get confused by replies to such messages */
+        m->hdr->nlmsg_seq = rtnl->serial++ ? : rtnl->serial++;
 
         rtnl_message_seal(m);
 
@@ -188,8 +296,10 @@ int sd_rtnl_send(sd_rtnl *nl,
                 }
         } else {
                 /* append to queue */
-                if (nl->wqueue_size >= RTNL_WQUEUE_MAX)
+                if (nl->wqueue_size >= RTNL_WQUEUE_MAX) {
+                        log_debug("rtnl: exhausted the write queue size (%d)", RTNL_WQUEUE_MAX);
                         return -ENOBUFS;
+                }
 
                 if (!GREEDY_REALLOC(nl->wqueue, nl->wqueue_allocated, nl->wqueue_size + 1))
                         return -ENOMEM;
@@ -206,8 +316,10 @@ int sd_rtnl_send(sd_rtnl *nl,
 int rtnl_rqueue_make_room(sd_rtnl *rtnl) {
         assert(rtnl);
 
-        if (rtnl->rqueue_size >= RTNL_RQUEUE_MAX)
+        if (rtnl->rqueue_size >= RTNL_RQUEUE_MAX) {
+                log_debug("rtnl: exhausted the read queue size (%d)", RTNL_RQUEUE_MAX);
                 return -ENOBUFS;
+        }
 
         if (!GREEDY_REALLOC(rtnl->rqueue, rtnl->rqueue_allocated, rtnl->rqueue_size + 1))
                 return -ENOMEM;
@@ -215,6 +327,21 @@ int rtnl_rqueue_make_room(sd_rtnl *rtnl) {
         return 0;
 }
 
+int rtnl_rqueue_partial_make_room(sd_rtnl *rtnl) {
+        assert(rtnl);
+
+        if (rtnl->rqueue_partial_size >= RTNL_RQUEUE_MAX) {
+                log_debug("rtnl: exhausted the partial read queue size (%d)", RTNL_RQUEUE_MAX);
+                return -ENOBUFS;
+        }
+
+        if (!GREEDY_REALLOC(rtnl->rqueue_partial, rtnl->rqueue_partial_allocated,
+                            rtnl->rqueue_partial_size + 1))
+                return -ENOMEM;
+
+        return 0;
+}
+
 static int dispatch_rqueue(sd_rtnl *rtnl, sd_rtnl_message **message) {
         int r;
 
@@ -285,22 +412,23 @@ static int process_timeout(sd_rtnl *rtnl) {
         hashmap_remove(rtnl->reply_callbacks, &c->serial);
 
         r = c->callback(rtnl, m, c->userdata);
+        if (r < 0)
+                log_debug_errno(r, "sd-rtnl: timedout callback failed: %m");
+
         free(c);
 
-        return r < 0 ? r : 1;
+        return 1;
 }
 
 static int process_reply(sd_rtnl *rtnl, sd_rtnl_message *m) {
-        struct reply_callback *c;
+        _cleanup_free_ struct reply_callback *c = NULL;
         uint64_t serial;
+        uint16_t type;
         int r;
 
         assert(rtnl);
         assert(m);
 
-        if (sd_rtnl_message_is_broadcast(m))
-                return 0;
-
         serial = rtnl_message_get_serial(m);
         c = hashmap_remove(rtnl->reply_callbacks, &serial);
         if (!c)
@@ -309,10 +437,18 @@ static int process_reply(sd_rtnl *rtnl, sd_rtnl_message *m) {
         if (c->timeout != 0)
                 prioq_remove(rtnl->reply_callbacks_prioq, c, &c->prioq_idx);
 
+        r = sd_rtnl_message_get_type(m, &type);
+        if (r < 0)
+                return 0;
+
+        if (type == NLMSG_DONE)
+                m = NULL;
+
         r = c->callback(rtnl, m, c->userdata);
-        free(c);
+        if (r < 0)
+                log_debug_errno(r, "sd-rtnl: callback failed: %m");
 
-        return r;
+        return 1;
 }
 
 static int process_match(sd_rtnl *rtnl, sd_rtnl_message *m) {
@@ -330,12 +466,16 @@ static int process_match(sd_rtnl *rtnl, sd_rtnl_message *m) {
         LIST_FOREACH(match_callbacks, c, rtnl->match_callbacks) {
                 if (type == c->type) {
                         r = c->callback(rtnl, m, c->userdata);
-                        if (r != 0)
-                                return r;
+                        if (r != 0) {
+                                if (r < 0)
+                                        log_debug_errno(r, "sd-rtnl: match callback failed: %m");
+
+                                break;
+                        }
                 }
         }
 
-        return 0;
+        return 1;
 }
 
 static int process_running(sd_rtnl *rtnl, sd_rtnl_message **ret) {
@@ -358,13 +498,15 @@ static int process_running(sd_rtnl *rtnl, sd_rtnl_message **ret) {
         if (!m)
                 goto null_message;
 
-        r = process_reply(rtnl, m);
-        if (r != 0)
-                goto null_message;
-
-        r = process_match(rtnl, m);
-        if (r != 0)
-                goto null_message;
+        if (sd_rtnl_message_is_broadcast(m)) {
+                r = process_match(rtnl, m);
+                if (r != 0)
+                        goto null_message;
+        } else {
+                r = process_reply(rtnl, m);
+                if (r != 0)
+                        goto null_message;
+        }
 
         if (ret) {
                 *ret = m;
@@ -410,7 +552,7 @@ static usec_t calc_elapse(uint64_t usec) {
 static int rtnl_poll(sd_rtnl *rtnl, bool need_more, uint64_t timeout_usec) {
         struct pollfd p[1] = {};
         struct timespec ts;
-        usec_t m = (usec_t) -1;
+        usec_t m = USEC_INFINITY;
         int r, e;
 
         assert(rtnl);
@@ -422,7 +564,7 @@ static int rtnl_poll(sd_rtnl *rtnl, bool need_more, uint64_t timeout_usec) {
         if (need_more)
                 /* Caller wants more data, and doesn't care about
                  * what's been read or any other timeouts. */
-                return e |= POLLIN;
+                e |= POLLIN;
         else {
                 usec_t until;
                 /* Caller wants to process if there is something to
@@ -494,7 +636,7 @@ int sd_rtnl_call_async(sd_rtnl *nl,
         assert_return(callback, -EINVAL);
         assert_return(!rtnl_pid_changed(nl), -ECHILD);
 
-        r = hashmap_ensure_allocated(&nl->reply_callbacks, uint64_hash_func, uint64_compare_func);
+        r = hashmap_ensure_allocated(&nl->reply_callbacks, &uint64_hash_ops);
         if (r < 0)
                 return r;
 
@@ -566,7 +708,6 @@ int sd_rtnl_call(sd_rtnl *rtnl,
                 sd_rtnl_message **ret) {
         usec_t timeout;
         uint32_t serial;
-        unsigned i = 0;
         int r;
 
         assert_return(rtnl, -EINVAL);
@@ -581,43 +722,51 @@ int sd_rtnl_call(sd_rtnl *rtnl,
 
         for (;;) {
                 usec_t left;
+                unsigned i;
 
-                while (i < rtnl->rqueue_size) {
-                        sd_rtnl_message *incoming;
+                for (i = 0; i < rtnl->rqueue_size; i++) {
                         uint32_t received_serial;
 
-                        incoming = rtnl->rqueue[i];
-                        received_serial = rtnl_message_get_serial(incoming);
+                        received_serial = rtnl_message_get_serial(rtnl->rqueue[i]);
 
                         if (received_serial == serial) {
+                                _cleanup_rtnl_message_unref_ sd_rtnl_message *incoming = NULL;
+                                uint16_t type;
+
+                                incoming = rtnl->rqueue[i];
+
                                 /* found a match, remove from rqueue and return it */
                                 memmove(rtnl->rqueue + i,rtnl->rqueue + i + 1,
                                         sizeof(sd_rtnl_message*) * (rtnl->rqueue_size - i - 1));
                                 rtnl->rqueue_size--;
 
                                 r = sd_rtnl_message_get_errno(incoming);
-                                if (r < 0) {
-                                        sd_rtnl_message_unref(incoming);
+                                if (r < 0)
                                         return r;
+
+                                r = sd_rtnl_message_get_type(incoming, &type);
+                                if (r < 0)
+                                        return r;
+
+                                if (type == NLMSG_DONE) {
+                                        *ret = NULL;
+                                        return 0;
                                 }
 
                                 if (ret) {
                                         *ret = incoming;
-                                } else
-                                        sd_rtnl_message_unref(incoming);
+                                        incoming = NULL;
+                                }
 
                                 return 1;
                         }
-
-                        /* Try to read more, right away */
-                        i ++;
                 }
 
                 r = socket_read_message(rtnl);
                 if (r < 0)
                         return r;
                 if (r > 0)
-                        /* receieved message, so try to process straight away */
+                        /* received message, so try to process straight away */
                         continue;
 
                 if (timeout > 0) {
@@ -634,6 +783,8 @@ int sd_rtnl_call(sd_rtnl *rtnl,
                 r = rtnl_poll(rtnl, true, left);
                 if (r < 0)
                         return r;
+                else if (r == 0)
+                        return -ETIMEDOUT;
 
                 r = dispatch_wqueue(rtnl);
                 if (r < 0)
@@ -796,6 +947,10 @@ int sd_rtnl_attach_event(sd_rtnl *rtnl, sd_event *event, int priority) {
         if (r < 0)
                 goto fail;
 
+        r = sd_event_source_set_description(rtnl->io_event_source, "rtnl-receive-message");
+        if (r < 0)
+                goto fail;
+
         r = sd_event_source_set_prepare(rtnl->io_event_source, prepare_callback);
         if (r < 0)
                 goto fail;
@@ -808,10 +963,18 @@ int sd_rtnl_attach_event(sd_rtnl *rtnl, sd_event *event, int priority) {
         if (r < 0)
                 goto fail;
 
+        r = sd_event_source_set_description(rtnl->time_event_source, "rtnl-timer");
+        if (r < 0)
+                goto fail;
+
         r = sd_event_add_exit(rtnl->event, &rtnl->exit_event_source, exit_callback, rtnl);
         if (r < 0)
                 goto fail;
 
+        r = sd_event_source_set_description(rtnl->exit_event_source, "rtnl-exit");
+        if (r < 0)
+                goto fail;
+
         return 0;
 
 fail:
@@ -849,7 +1012,7 @@ int sd_rtnl_add_match(sd_rtnl *rtnl,
         assert_return(!rtnl_pid_changed(rtnl), -ECHILD);
         assert_return(rtnl_message_type_is_link(type) ||
                       rtnl_message_type_is_addr(type) ||
-                      rtnl_message_type_is_route(type), -ENOTSUP);
+                      rtnl_message_type_is_route(type), -EOPNOTSUPP);
 
         c = new0(struct match_callback, 1);
         if (!c)