chiark / gitweb /
sd-bus: add a "recursive" mode to sd_bus_track
authorLennart Poettering <lennart@poettering.net>
Mon, 15 Aug 2016 12:58:09 +0000 (14:58 +0200)
committerSven Eden <yamakuzure@gmx.net>
Wed, 5 Jul 2017 06:50:51 +0000 (08:50 +0200)
This adds an optional "recursive" counting mode to sd_bus_track. If enabled
adding the same name multiple times to an sd_bus_track object is counted
individually, so that it also has to be removed the same number of times before
it is gone again from the tracking object.

This functionality is useful for implementing local ref counted objects that
peers make take references on.

src/libelogind/sd-bus/bus-track.c

index b2e562358bfd461c5a8831ef3ab5ba46ab8acc92..d2c26746ed5ca87a496e28346e0d4dc016b4cad5 100644 (file)
 #include "bus-track.h"
 #include "bus-util.h"
 
+struct track_item {
+        unsigned n_ref;
+        char *name;
+        sd_bus_slot *slot;
+};
+
 struct sd_bus_track {
         unsigned n_ref;
         sd_bus *bus;
@@ -32,8 +38,9 @@ struct sd_bus_track {
         Hashmap *names;
         LIST_FIELDS(sd_bus_track, queue);
         Iterator iterator;
-        bool in_queue;
-        bool modified;
+        bool in_queue:1;
+        bool modified:1;
+        bool recursive:1;
 };
 
 #define MATCH_PREFIX                                        \
@@ -56,6 +63,20 @@ struct sd_bus_track {
                 _x;                                                     \
         })
 
+static struct track_item* track_item_free(struct track_item *i) {
+
+        if (!i)
+                return NULL;
+
+        sd_bus_slot_unref(i->slot);
+        free(i->name);
+        free(i);
+
+        return NULL;
+}
+
+DEFINE_TRIVIAL_CLEANUP_FUNC(struct track_item*, track_item_free);
+
 static void bus_track_add_to_queue(sd_bus_track *track) {
         assert(track);
 
@@ -79,6 +100,25 @@ static void bus_track_remove_from_queue(sd_bus_track *track) {
         track->in_queue = false;
 }
 
+static int bus_track_remove_name_fully(sd_bus_track *track, const char *name) {
+        struct track_item *i;
+
+        assert(track);
+        assert(name);
+
+        i = hashmap_remove(track->names, name);
+        if (!i)
+                return 0;
+
+        track_item_free(i);
+
+        if (hashmap_isempty(track->names))
+                bus_track_add_to_queue(track);
+
+        track->modified = true;
+        return 1;
+}
+
 _public_ int sd_bus_track_new(
                 sd_bus *bus,
                 sd_bus_track **track,
@@ -121,7 +161,7 @@ _public_ sd_bus_track* sd_bus_track_ref(sd_bus_track *track) {
 }
 
 _public_ sd_bus_track* sd_bus_track_unref(sd_bus_track *track) {
-        const char *n;
+        struct track_item *i;
 
         if (!track)
                 return NULL;
@@ -133,8 +173,8 @@ _public_ sd_bus_track* sd_bus_track_unref(sd_bus_track *track) {
                 return NULL;
         }
 
-        while ((n = hashmap_first_key(track->names)))
-                sd_bus_track_remove_name(track, n);
+        while ((i = hashmap_steal_first(track->names)))
+                track_item_free(i);
 
         bus_track_remove_from_queue(track);
         hashmap_free(track->names);
@@ -156,49 +196,64 @@ static int on_name_owner_changed(sd_bus_message *message, void *userdata, sd_bus
         if (r < 0)
                 return 0;
 
-        sd_bus_track_remove_name(track, name);
+        bus_track_remove_name_fully(track, name);
         return 0;
 }
 
 _public_ int sd_bus_track_add_name(sd_bus_track *track, const char *name) {
-        _cleanup_(sd_bus_slot_unrefp) sd_bus_slot *slot = NULL;
-        _cleanup_free_ char *n = NULL;
+        _cleanup_(track_item_freep) struct track_item *n = NULL;
+        struct track_item *i;
         const char *match;
         int r;
 
         assert_return(track, -EINVAL);
         assert_return(service_name_is_valid(name), -EINVAL);
 
+        i = hashmap_get(track->names, name);
+        if (i) {
+                if (track->recursive) {
+                        unsigned k = track->n_ref + 1;
+
+                        if (k < track->n_ref) /* Check for overflow */
+                                return -EOVERFLOW;
+
+                        track->n_ref = k;
+                }
+
+                bus_track_remove_from_queue(track);
+                return 0;
+        }
+
         r = hashmap_ensure_allocated(&track->names, &string_hash_ops);
         if (r < 0)
                 return r;
 
-        n = strdup(name);
+        n = new0(struct track_item, 1);
         if (!n)
                 return -ENOMEM;
+        n->name = strdup(name);
+        if (!n->name)
+                return -ENOMEM;
 
         /* First, subscribe to this name */
-        match = MATCH_FOR_NAME(n);
-        r = sd_bus_add_match(track->bus, &slot, match, on_name_owner_changed, track);
+        match = MATCH_FOR_NAME(name);
+        r = sd_bus_add_match(track->bus, &n->slot, match, on_name_owner_changed, track);
         if (r < 0)
                 return r;
 
-        r = hashmap_put(track->names, n, slot);
-        if (r == -EEXIST)
-                return 0;
+        r = hashmap_put(track->names, n->name, n);
         if (r < 0)
                 return r;
 
-        /* Second, check if it is currently existing, or maybe
-         * doesn't, or maybe disappeared already. */
-        r = sd_bus_get_name_creds(track->bus, n, 0, NULL);
+        /* Second, check if it is currently existing, or maybe doesn't, or maybe disappeared already. */
+        r = sd_bus_get_name_creds(track->bus, name, 0, NULL);
         if (r < 0) {
-                hashmap_remove(track->names, n);
+                hashmap_remove(track->names, name);
                 return r;
         }
 
+        n->n_ref = 1;
         n = NULL;
-        slot = NULL;
 
         bus_track_remove_from_queue(track);
         track->modified = true;
@@ -207,37 +262,48 @@ _public_ int sd_bus_track_add_name(sd_bus_track *track, const char *name) {
 }
 
 _public_ int sd_bus_track_remove_name(sd_bus_track *track, const char *name) {
-        _cleanup_(sd_bus_slot_unrefp) sd_bus_slot *slot = NULL;
-        _cleanup_free_ char *n = NULL;
+        struct track_item *i;
 
         assert_return(name, -EINVAL);
 
-        if (!track)
+        if (!track) /* Treat a NULL track object as an empty track object */
                 return 0;
 
-        slot = hashmap_remove2(track->names, (char*) name, (void**) &n);
-        if (!slot)
-                return 0;
+        if (!track->recursive)
+                return bus_track_remove_name_fully(track, name);
 
-        if (hashmap_isempty(track->names))
-                bus_track_add_to_queue(track);
+        i = hashmap_get(track->names, name);
+        if (!i)
+                return -EUNATCH;
+        if (i->n_ref <= 0)
+                return -EUNATCH;
 
-        track->modified = true;
+        i->n_ref--;
+
+        if (i->n_ref <= 0)
+                return bus_track_remove_name_fully(track, name);
 
         return 1;
 }
 
 _public_ unsigned sd_bus_track_count(sd_bus_track *track) {
-        if (!track)
+
+        if (!track) /* Let's consider a NULL object equivalent to an empty object */
                 return 0;
 
+        /* This signature really should have returned an int, so that we can propagate errors. But well, ... Also, note
+         * that this returns the number of names being watched, and multiple references to the same name are not
+         * counted. */
+
         return hashmap_size(track->names);
 }
 
 _public_ const char* sd_bus_track_contains(sd_bus_track *track, const char *name) {
-        assert_return(track, NULL);
         assert_return(name, NULL);
 
+        if (!track) /* Let's consider a NULL object equivalent to an empty object */
+                return NULL;
+
         return hashmap_get(track->names, (void*) name) ? name : NULL;
 }
 
@@ -273,6 +339,9 @@ _public_ int sd_bus_track_add_sender(sd_bus_track *track, sd_bus_message *m) {
         assert_return(track, -EINVAL);
         assert_return(m, -EINVAL);
 
+        if (sd_bus_message_get_bus(m) != track->bus)
+                return -EINVAL;
+
         sender = sd_bus_message_get_sender(m);
         if (!sender)
                 return -EINVAL;
@@ -283,9 +352,14 @@ _public_ int sd_bus_track_add_sender(sd_bus_track *track, sd_bus_message *m) {
 _public_ int sd_bus_track_remove_sender(sd_bus_track *track, sd_bus_message *m) {
         const char *sender;
 
-        assert_return(track, -EINVAL);
         assert_return(m, -EINVAL);
 
+        if (!track) /* Treat a NULL track object as an empty track object */
+                return 0;
+
+        if (sd_bus_message_get_bus(m) != track->bus)
+                return -EINVAL;
+
         sender = sd_bus_message_get_sender(m);
         if (!sender)
                 return -EINVAL;
@@ -337,3 +411,55 @@ _public_ void *sd_bus_track_set_userdata(sd_bus_track *track, void *userdata) {
         return ret;
 }
 #endif // 0
+
+_public_ int sd_bus_track_set_recursive(sd_bus_track *track, int b) {
+        assert_return(track, -EINVAL);
+
+        if (track->recursive == !!b)
+                return 0;
+
+        if (!hashmap_isempty(track->names))
+                return -EBUSY;
+
+        track->recursive = b;
+        return 0;
+}
+
+_public_ int sd_bus_track_get_recursive(sd_bus_track *track) {
+        assert_return(track, -EINVAL);
+
+        return track->recursive;
+}
+
+_public_ int sd_bus_track_count_sender(sd_bus_track *track, sd_bus_message *m) {
+        const char *sender;
+
+        assert_return(m, -EINVAL);
+
+        if (!track) /* Let's consider a NULL object equivalent to an empty object */
+                return 0;
+
+        if (sd_bus_message_get_bus(m) != track->bus)
+                return -EINVAL;
+
+        sender = sd_bus_message_get_sender(m);
+        if (!sender)
+                return -EINVAL;
+
+        return sd_bus_track_count_name(track, sender);
+}
+
+_public_ int sd_bus_track_count_name(sd_bus_track *track, const char *name) {
+        struct track_item *i;
+
+        assert_return(service_name_is_valid(name), -EINVAL);
+
+        if (!track) /* Let's consider a NULL object equivalent to an empty object */
+                return 0;
+
+        i = hashmap_get(track->names, name);
+        if (!i)
+                return 0;
+
+        return i->n_ref;
+}