chiark / gitweb /
sd-rtnl: fix self-reference leaks
[elogind.git] / src / libsystemd / sd-rtnl / sd-rtnl.c
index 05889656f95d9fb2633e36b5eec361c2a5f9d44c..60426576422cc9ed25319ca774415b7d1a03e6d3 100644 (file)
@@ -70,7 +70,7 @@ static bool rtnl_pid_changed(sd_rtnl *rtnl) {
         return rtnl->original_pid != getpid();
 }
 
-int sd_rtnl_open(uint32_t groups, sd_rtnl **ret) {
+int sd_rtnl_open(sd_rtnl **ret, uint32_t groups) {
         _cleanup_rtnl_unref_ sd_rtnl *rtnl = NULL;
         socklen_t addrlen;
         int r;
@@ -111,33 +111,82 @@ sd_rtnl *sd_rtnl_ref(sd_rtnl *rtnl) {
 }
 
 sd_rtnl *sd_rtnl_unref(sd_rtnl *rtnl) {
+        unsigned long refs;
 
-        if (rtnl && REFCNT_DEC(rtnl->n_ref) <= 0) {
+        if (!rtnl)
+                return NULL;
+
+        /*
+         * If our ref-cnt is exactly the number of internally queued messages
+         * plus the ref-cnt to be dropped, then we know there's no external
+         * reference to us. Hence, we look through all queued messages and if
+         * they also have no external references, we're about to drop the last
+         * ref. Flush the queues so the REFCNT_DEC() below will drop to 0.
+         * We must be careful not to introduce inter-message references or this
+         * logic will fall apart..
+         */
+
+        refs = rtnl->rqueue_size + rtnl->wqueue_size + 1;
+
+        if (REFCNT_GET(rtnl->n_ref) <= refs) {
                 struct match_callback *f;
+                bool q = true;
                 unsigned i;
 
-                for (i = 0; i < rtnl->rqueue_size; i++)
-                        sd_rtnl_message_unref(rtnl->rqueue[i]);
-                free(rtnl->rqueue);
+                for (i = 0; i < rtnl->rqueue_size; i++) {
+                        if (REFCNT_GET(rtnl->rqueue[i]->n_ref) > 1) {
+                                q = false;
+                                break;
+                        } else if (rtnl->rqueue[i]->rtnl != rtnl)
+                                --refs;
+                }
 
-                for (i = 0; i < rtnl->wqueue_size; i++)
-                        sd_rtnl_message_unref(rtnl->wqueue[i]);
-                free(rtnl->wqueue);
+                if (q) {
+                        for (i = 0; i < rtnl->wqueue_size; i++) {
+                                if (REFCNT_GET(rtnl->wqueue[i]->n_ref) > 1) {
+                                        q = false;
+                                        break;
+                                } else if (rtnl->wqueue[i]->rtnl != rtnl)
+                                        --refs;
+                        }
+                }
 
-                hashmap_free_free(rtnl->reply_callbacks);
-                prioq_free(rtnl->reply_callbacks_prioq);
+                if (q && REFCNT_GET(rtnl->n_ref) == refs) {
+                        /* Drop our own ref early to avoid recursion from:
+                         *   sd_rtnl_message_unref()
+                         *     sd_rtnl_unref()
+                         * These must enter sd_rtnl_unref() with a ref-cnt
+                         * smaller than us. */
+                        REFCNT_DEC(rtnl->n_ref);
 
-                while ((f = rtnl->match_callbacks)) {
-                        LIST_REMOVE(match_callbacks, rtnl->match_callbacks, f);
-                        free(f);
-                }
+                        for (i = 0; i < rtnl->rqueue_size; i++)
+                                sd_rtnl_message_unref(rtnl->rqueue[i]);
+                        free(rtnl->rqueue);
 
-                if (rtnl->fd >= 0)
-                        close_nointr_nofail(rtnl->fd);
+                        for (i = 0; i < rtnl->wqueue_size; i++)
+                                sd_rtnl_message_unref(rtnl->wqueue[i]);
+                        free(rtnl->wqueue);
 
-                free(rtnl);
+                        assert_se(REFCNT_GET(rtnl->n_ref) == 0);
+
+                        hashmap_free_free(rtnl->reply_callbacks);
+                        prioq_free(rtnl->reply_callbacks_prioq);
+
+                        while ((f = rtnl->match_callbacks)) {
+                                LIST_REMOVE(match_callbacks, rtnl->match_callbacks, f);
+                                free(f);
+                        }
+
+                        safe_close(rtnl->fd);
+                        free(rtnl);
+
+                        return NULL;
+                }
         }
 
+        assert_se(REFCNT_GET(rtnl->n_ref) > 0);
+        REFCNT_DEC(rtnl->n_ref);
+
         return NULL;
 }
 
@@ -277,6 +326,9 @@ static int process_reply(sd_rtnl *rtnl, sd_rtnl_message *m) {
         assert(rtnl);
         assert(m);
 
+        if (sd_rtnl_message_is_broadcast(m))
+                return 0;
+
         serial = rtnl_message_get_serial(m);
         c = hashmap_remove(rtnl->reply_callbacks, &serial);
         if (!c)
@@ -774,7 +826,7 @@ int sd_rtnl_attach_event(sd_rtnl *rtnl, sd_event *event, int priority) {
                         return r;
         }
 
-        r = sd_event_add_io(rtnl->event, rtnl->fd, 0, io_callback, rtnl, &rtnl->io_event_source);
+        r = sd_event_add_io(rtnl->event, &rtnl->io_event_source, rtnl->fd, 0, io_callback, rtnl);
         if (r < 0)
                 goto fail;
 
@@ -786,7 +838,7 @@ int sd_rtnl_attach_event(sd_rtnl *rtnl, sd_event *event, int priority) {
         if (r < 0)
                 goto fail;
 
-        r = sd_event_add_monotonic(rtnl->event, 0, 0, time_callback, rtnl, &rtnl->time_event_source);
+        r = sd_event_add_monotonic(rtnl->event, &rtnl->time_event_source, 0, 0, time_callback, rtnl);
         if (r < 0)
                 goto fail;
 
@@ -794,7 +846,7 @@ int sd_rtnl_attach_event(sd_rtnl *rtnl, sd_event *event, int priority) {
         if (r < 0)
                 goto fail;
 
-        r = sd_event_add_exit(rtnl->event, exit_callback, rtnl, &rtnl->exit_event_source);
+        r = sd_event_add_exit(rtnl->event, &rtnl->exit_event_source, exit_callback, rtnl);
         if (r < 0)
                 goto fail;