chiark / gitweb /
rtnl: message - read group membership of incoming messages
[elogind.git] / src / libsystemd / sd-rtnl / sd-rtnl.c
index 60426576422cc9ed25319ca774415b7d1a03e6d3..b91d08012ad78aaf52a621fcf59691be876e2a7c 100644 (file)
@@ -22,6 +22,7 @@
 #include <sys/socket.h>
 #include <poll.h>
 
+#include "missing.h"
 #include "macro.h"
 #include "util.h"
 #include "hashmap.h"
@@ -31,7 +32,7 @@
 #include "rtnl-util.h"
 
 static int sd_rtnl_new(sd_rtnl **ret) {
-        sd_rtnl *rtnl;
+        _cleanup_rtnl_unref_ sd_rtnl *rtnl = NULL;
 
         assert_return(ret, -EINVAL);
 
@@ -51,13 +52,18 @@ static int sd_rtnl_new(sd_rtnl **ret) {
 
         /* We guarantee that wqueue always has space for at least
          * one entry */
-        rtnl->wqueue = new(sd_rtnl_message*, 1);
-        if (!rtnl->wqueue) {
-                free(rtnl);
+        if (!GREEDY_REALLOC(rtnl->wqueue, rtnl->wqueue_allocated, 1))
+                return -ENOMEM;
+
+        /* We guarantee that the read buffer has at least space for
+         * a message header */
+        if (!greedy_realloc((void**)&rtnl->rbuffer, &rtnl->rbuffer_allocated,
+                            sizeof(struct nlmsghdr), sizeof(uint8_t)))
                 return -ENOMEM;
-        }
 
         *ret = rtnl;
+        rtnl = NULL;
+
         return 0;
 }
 
@@ -70,10 +76,29 @@ static bool rtnl_pid_changed(sd_rtnl *rtnl) {
         return rtnl->original_pid != getpid();
 }
 
-int sd_rtnl_open(sd_rtnl **ret, uint32_t groups) {
+static int rtnl_compute_groups_ap(uint32_t *_groups, unsigned n_groups, va_list ap) {
+        uint32_t groups = 0;
+        unsigned i;
+
+        for (i = 0; i < n_groups; i++) {
+                unsigned group;
+
+                group = va_arg(ap, unsigned);
+                assert_return(group < 32, -EINVAL);
+
+                groups |= group ? (1 << (group - 1)) : 0;
+        }
+
+        *_groups = groups;
+
+        return 0;
+}
+
+int sd_rtnl_open(sd_rtnl **ret, unsigned n_groups, ...) {
         _cleanup_rtnl_unref_ sd_rtnl *rtnl = NULL;
+        va_list ap;
         socklen_t addrlen;
-        int r;
+        int r, one = 1;
 
         assert_return(ret, -EINVAL);
 
@@ -85,7 +110,19 @@ int sd_rtnl_open(sd_rtnl **ret, uint32_t groups) {
         if (rtnl->fd < 0)
                 return -errno;
 
-        rtnl->sockaddr.nl.nl_groups = groups;
+        r = setsockopt(rtnl->fd, SOL_SOCKET, SO_PASSCRED, &one, sizeof(one));
+        if (r < 0)
+                return -errno;
+
+        r = setsockopt(rtnl->fd, SOL_NETLINK, NETLINK_PKTINFO, &one, sizeof(one));
+        if (r < 0)
+                return -errno;
+
+        va_start(ap, n_groups);
+        r = rtnl_compute_groups_ap(&rtnl->sockaddr.nl.nl_groups, n_groups, ap);
+        va_end(ap);
+        if (r < 0)
+                return r;
 
         addrlen = sizeof(rtnl->sockaddr);
 
@@ -104,6 +141,9 @@ int sd_rtnl_open(sd_rtnl **ret, uint32_t groups) {
 }
 
 sd_rtnl *sd_rtnl_ref(sd_rtnl *rtnl) {
+        assert_return(rtnl, NULL);
+        assert_return(!rtnl_pid_changed(rtnl), NULL);
+
         if (rtnl)
                 assert_se(REFCNT_INC(rtnl->n_ref) >= 2);
 
@@ -111,83 +151,60 @@ sd_rtnl *sd_rtnl_ref(sd_rtnl *rtnl) {
 }
 
 sd_rtnl *sd_rtnl_unref(sd_rtnl *rtnl) {
-        unsigned long refs;
-
         if (!rtnl)
                 return NULL;
 
-        /*
-         * If our ref-cnt is exactly the number of internally queued messages
-         * plus the ref-cnt to be dropped, then we know there's no external
-         * reference to us. Hence, we look through all queued messages and if
-         * they also have no external references, we're about to drop the last
-         * ref. Flush the queues so the REFCNT_DEC() below will drop to 0.
-         * We must be careful not to introduce inter-message references or this
-         * logic will fall apart..
-         */
+        assert_return(!rtnl_pid_changed(rtnl), NULL);
 
-        refs = rtnl->rqueue_size + rtnl->wqueue_size + 1;
-
-        if (REFCNT_GET(rtnl->n_ref) <= refs) {
+        if (REFCNT_DEC(rtnl->n_ref) <= 0) {
                 struct match_callback *f;
-                bool q = true;
                 unsigned i;
 
-                for (i = 0; i < rtnl->rqueue_size; i++) {
-                        if (REFCNT_GET(rtnl->rqueue[i]->n_ref) > 1) {
-                                q = false;
-                                break;
-                        } else if (rtnl->rqueue[i]->rtnl != rtnl)
-                                --refs;
-                }
+                for (i = 0; i < rtnl->rqueue_size; i++)
+                        sd_rtnl_message_unref(rtnl->rqueue[i]);
+                free(rtnl->rqueue);
 
-                if (q) {
-                        for (i = 0; i < rtnl->wqueue_size; i++) {
-                                if (REFCNT_GET(rtnl->wqueue[i]->n_ref) > 1) {
-                                        q = false;
-                                        break;
-                                } else if (rtnl->wqueue[i]->rtnl != rtnl)
-                                        --refs;
-                        }
-                }
+                for (i = 0; i < rtnl->rqueue_partial_size; i++)
+                        sd_rtnl_message_unref(rtnl->rqueue_partial[i]);
+                free(rtnl->rqueue_partial);
 
-                if (q && REFCNT_GET(rtnl->n_ref) == refs) {
-                        /* Drop our own ref early to avoid recursion from:
-                         *   sd_rtnl_message_unref()
-                         *     sd_rtnl_unref()
-                         * These must enter sd_rtnl_unref() with a ref-cnt
-                         * smaller than us. */
-                        REFCNT_DEC(rtnl->n_ref);
+                for (i = 0; i < rtnl->wqueue_size; i++)
+                        sd_rtnl_message_unref(rtnl->wqueue[i]);
+                free(rtnl->wqueue);
 
-                        for (i = 0; i < rtnl->rqueue_size; i++)
-                                sd_rtnl_message_unref(rtnl->rqueue[i]);
-                        free(rtnl->rqueue);
+                free(rtnl->rbuffer);
 
-                        for (i = 0; i < rtnl->wqueue_size; i++)
-                                sd_rtnl_message_unref(rtnl->wqueue[i]);
-                        free(rtnl->wqueue);
+                hashmap_free_free(rtnl->reply_callbacks);
+                prioq_free(rtnl->reply_callbacks_prioq);
 
-                        assert_se(REFCNT_GET(rtnl->n_ref) == 0);
+                sd_event_source_unref(rtnl->io_event_source);
+                sd_event_source_unref(rtnl->time_event_source);
+                sd_event_source_unref(rtnl->exit_event_source);
+                sd_event_unref(rtnl->event);
 
-                        hashmap_free_free(rtnl->reply_callbacks);
-                        prioq_free(rtnl->reply_callbacks_prioq);
+                while ((f = rtnl->match_callbacks)) {
+                        LIST_REMOVE(match_callbacks, rtnl->match_callbacks, f);
+                        free(f);
+                }
 
-                        while ((f = rtnl->match_callbacks)) {
-                                LIST_REMOVE(match_callbacks, rtnl->match_callbacks, f);
-                                free(f);
-                        }
+                safe_close(rtnl->fd);
+                free(rtnl);
+        }
 
-                        safe_close(rtnl->fd);
-                        free(rtnl);
+        return NULL;
+}
 
-                        return NULL;
-                }
-        }
+static void rtnl_seal_message(sd_rtnl *rtnl, sd_rtnl_message *m) {
+        assert(rtnl);
+        assert(!rtnl_pid_changed(rtnl));
+        assert(m);
+        assert(m->hdr);
 
-        assert_se(REFCNT_GET(rtnl->n_ref) > 0);
-        REFCNT_DEC(rtnl->n_ref);
+        m->hdr->nlmsg_seq = rtnl->serial++;
 
-        return NULL;
+        rtnl_message_seal(m);
+
+        return;
 }
 
 int sd_rtnl_send(sd_rtnl *nl,
@@ -198,10 +215,9 @@ int sd_rtnl_send(sd_rtnl *nl,
         assert_return(nl, -EINVAL);
         assert_return(!rtnl_pid_changed(nl), -ECHILD);
         assert_return(message, -EINVAL);
+        assert_return(!message->sealed, -EPERM);
 
-        r = rtnl_message_seal(nl, message);
-        if (r < 0)
-                return r;
+        rtnl_seal_message(nl, message);
 
         if (nl->wqueue_size <= 0) {
                 /* send directly */
@@ -215,18 +231,16 @@ int sd_rtnl_send(sd_rtnl *nl,
                         nl->wqueue_size = 1;
                 }
         } else {
-                sd_rtnl_message **q;
-
                 /* append to queue */
-                if (nl->wqueue_size >= RTNL_WQUEUE_MAX)
+                if (nl->wqueue_size >= RTNL_WQUEUE_MAX) {
+                        log_debug("rtnl: exhausted the write queue size (%d)", RTNL_WQUEUE_MAX);
                         return -ENOBUFS;
+                }
 
-                q = realloc(nl->wqueue, sizeof(sd_rtnl_message*) * (nl->wqueue_size + 1));
-                if (!q)
+                if (!GREEDY_REALLOC(nl->wqueue, nl->wqueue_allocated, nl->wqueue_size + 1))
                         return -ENOMEM;
 
-                nl->wqueue = q;
-                q[nl->wqueue_size ++] = sd_rtnl_message_ref(message);
+                nl->wqueue[nl->wqueue_size ++] = sd_rtnl_message_ref(message);
         }
 
         if (serial)
@@ -235,31 +249,52 @@ int sd_rtnl_send(sd_rtnl *nl,
         return 1;
 }
 
+int rtnl_rqueue_make_room(sd_rtnl *rtnl) {
+        assert(rtnl);
+
+        if (rtnl->rqueue_size >= RTNL_RQUEUE_MAX) {
+                log_debug("rtnl: exhausted the read queue size (%d)", RTNL_RQUEUE_MAX);
+                return -ENOBUFS;
+        }
+
+        if (!GREEDY_REALLOC(rtnl->rqueue, rtnl->rqueue_allocated, rtnl->rqueue_size + 1))
+                return -ENOMEM;
+
+        return 0;
+}
+
+int rtnl_rqueue_partial_make_room(sd_rtnl *rtnl) {
+        assert(rtnl);
+
+        if (rtnl->rqueue_partial_size >= RTNL_RQUEUE_MAX) {
+                log_debug("rtnl: exhausted the partial read queue size (%d)", RTNL_RQUEUE_MAX);
+                return -ENOBUFS;
+        }
+
+        if (!GREEDY_REALLOC(rtnl->rqueue_partial, rtnl->rqueue_partial_allocated,
+                            rtnl->rqueue_partial_size + 1))
+                return -ENOMEM;
+
+        return 0;
+}
+
 static int dispatch_rqueue(sd_rtnl *rtnl, sd_rtnl_message **message) {
-        sd_rtnl_message *z = NULL;
         int r;
 
         assert(rtnl);
         assert(message);
 
-        if (rtnl->rqueue_size > 0) {
-                /* Dispatch a queued message */
-
-                *message = rtnl->rqueue[0];
-                rtnl->rqueue_size --;
-                memmove(rtnl->rqueue, rtnl->rqueue + 1, sizeof(sd_rtnl_message*) * rtnl->rqueue_size);
-
-                return 1;
+        if (rtnl->rqueue_size <= 0) {
+                /* Try to read a new message */
+                r = socket_read_message(rtnl);
+                if (r <= 0)
+                        return r;
         }
 
-        /* Try to read a new message */
-        r = socket_read_message(rtnl, &z);
-        if (r < 0)
-                return r;
-        if (r == 0)
-                return 0;
-
-        *message = z;
+        /* Dispatch a queued message */
+        *message = rtnl->rqueue[0];
+        rtnl->rqueue_size --;
+        memmove(rtnl->rqueue, rtnl->rqueue + 1, sizeof(sd_rtnl_message*) * rtnl->rqueue_size);
 
         return 1;
 }
@@ -588,20 +623,20 @@ int sd_rtnl_call_async_cancel(sd_rtnl *nl, uint32_t serial) {
         return 1;
 }
 
-int sd_rtnl_call(sd_rtnl *nl,
+int sd_rtnl_call(sd_rtnl *rtnl,
                 sd_rtnl_message *message,
                 uint64_t usec,
                 sd_rtnl_message **ret) {
         usec_t timeout;
         uint32_t serial;
-        bool room = false;
+        unsigned i = 0;
         int r;
 
-        assert_return(nl, -EINVAL);
-        assert_return(!rtnl_pid_changed(nl), -ECHILD);
+        assert_return(rtnl, -EINVAL);
+        assert_return(!rtnl_pid_changed(rtnl), -ECHILD);
         assert_return(message, -EINVAL);
 
-        r = sd_rtnl_send(nl, message, &serial);
+        r = sd_rtnl_send(rtnl, message, &serial);
         if (r < 0)
                 return r;
 
@@ -609,53 +644,43 @@ int sd_rtnl_call(sd_rtnl *nl,
 
         for (;;) {
                 usec_t left;
-                _cleanup_rtnl_message_unref_ sd_rtnl_message *incoming = NULL;
-
-                if (!room) {
-                        sd_rtnl_message **q;
-
-                        if (nl->rqueue_size >= RTNL_RQUEUE_MAX)
-                                return -ENOBUFS;
 
-                        /* Make sure there's room for queueing this
-                         * locally, before we read the message */
+                while (i < rtnl->rqueue_size) {
+                        sd_rtnl_message *incoming;
+                        uint32_t received_serial;
 
-                        q = realloc(nl->rqueue, (nl->rqueue_size + 1) * sizeof(sd_rtnl_message*));
-                        if (!q)
-                                return -ENOMEM;
-
-                        nl->rqueue = q;
-                        room = true;
-                }
-
-                r = socket_read_message(nl, &incoming);
-                if (r < 0)
-                        return r;
-                if (incoming) {
-                        uint32_t received_serial = rtnl_message_get_serial(incoming);
+                        incoming = rtnl->rqueue[i];
+                        received_serial = rtnl_message_get_serial(incoming);
 
                         if (received_serial == serial) {
+                                /* found a match, remove from rqueue and return it */
+                                memmove(rtnl->rqueue + i,rtnl->rqueue + i + 1,
+                                        sizeof(sd_rtnl_message*) * (rtnl->rqueue_size - i - 1));
+                                rtnl->rqueue_size--;
+
                                 r = sd_rtnl_message_get_errno(incoming);
-                                if (r < 0)
+                                if (r < 0) {
+                                        sd_rtnl_message_unref(incoming);
                                         return r;
+                                }
 
                                 if (ret) {
                                         *ret = incoming;
-                                        incoming = NULL;
-                                }
+                                } else
+                                        sd_rtnl_message_unref(incoming);
 
                                 return 1;
                         }
 
-                        /* Room was allocated on the queue above */
-                        nl->rqueue[nl->rqueue_size ++] = incoming;
-                        incoming = NULL;
-                        room = false;
-
                         /* Try to read more, right away */
-                        continue;
+                        i ++;
                 }
-                if (r != 0)
+
+                r = socket_read_message(rtnl);
+                if (r < 0)
+                        return r;
+                if (r > 0)
+                        /* receieved message, so try to process straight away */
                         continue;
 
                 if (timeout > 0) {
@@ -669,11 +694,11 @@ int sd_rtnl_call(sd_rtnl *nl,
                 } else
                         left = (uint64_t) -1;
 
-                r = rtnl_poll(nl, true, left);
+                r = rtnl_poll(rtnl, true, left);
                 if (r < 0)
                         return r;
 
-                r = dispatch_wqueue(nl);
+                r = dispatch_wqueue(rtnl);
                 if (r < 0)
                         return r;
         }
@@ -838,7 +863,7 @@ int sd_rtnl_attach_event(sd_rtnl *rtnl, sd_event *event, int priority) {
         if (r < 0)
                 goto fail;
 
-        r = sd_event_add_monotonic(rtnl->event, &rtnl->time_event_source, 0, 0, time_callback, rtnl);
+        r = sd_event_add_time(rtnl->event, &rtnl->time_event_source, CLOCK_MONOTONIC, 0, 0, time_callback, rtnl);
         if (r < 0)
                 goto fail;