chiark / gitweb /
sd-bus: add support for matches against arrays of strings in messages
authorLennart Poettering <lennart@poettering.net>
Fri, 28 Nov 2014 18:16:37 +0000 (19:16 +0100)
committerLennart Poettering <lennart@poettering.net>
Fri, 28 Nov 2014 19:29:44 +0000 (20:29 +0100)
src/libsystemd/sd-bus/bus-match.c
src/libsystemd/sd-bus/bus-message.c
src/libsystemd/sd-bus/bus-message.h
src/libsystemd/sd-bus/test-bus-match.c

index 0e92a85ef756fdaf177e8abdcf9c61394c5531dc..3a31aa0ebfc9cc6a79470314332b795c62ae7190 100644 (file)
@@ -134,6 +134,7 @@ static bool value_node_test(
                 enum bus_match_node_type parent_type,
                 uint8_t value_u8,
                 const char *value_str,
+                char **value_strv,
                 sd_bus_message *m) {
 
         assert(node);
@@ -179,17 +180,46 @@ static bool value_node_test(
         case BUS_MATCH_INTERFACE:
         case BUS_MATCH_MEMBER:
         case BUS_MATCH_PATH:
-        case BUS_MATCH_ARG ... BUS_MATCH_ARG_LAST:
-                return streq_ptr(node->value.str, value_str);
+        case BUS_MATCH_ARG ... BUS_MATCH_ARG_LAST: {
+                char **i;
 
-        case BUS_MATCH_ARG_NAMESPACE ... BUS_MATCH_ARG_NAMESPACE_LAST:
-                return namespace_simple_pattern(node->value.str, value_str);
+                if (value_str)
+                        return streq_ptr(node->value.str, value_str);
+
+                STRV_FOREACH(i, value_strv)
+                        if (streq_ptr(node->value.str, *i))
+                                return true;
+
+                return false;
+        }
+
+        case BUS_MATCH_ARG_NAMESPACE ... BUS_MATCH_ARG_NAMESPACE_LAST: {
+                char **i;
+
+                if (value_str)
+                        return namespace_simple_pattern(node->value.str, value_str);
+
+                STRV_FOREACH(i, value_strv)
+                        if (namespace_simple_pattern(node->value.str, *i))
+                                return true;
+                return false;
+        }
 
         case BUS_MATCH_PATH_NAMESPACE:
                 return path_simple_pattern(node->value.str, value_str);
 
-        case BUS_MATCH_ARG_PATH ... BUS_MATCH_ARG_PATH_LAST:
-                return path_complex_pattern(node->value.str, value_str);
+        case BUS_MATCH_ARG_PATH ... BUS_MATCH_ARG_PATH_LAST: {
+                char **i;
+
+                if (value_str)
+                        return path_complex_pattern(node->value.str, value_str);
+
+                STRV_FOREACH(i, value_strv)
+                        if (path_complex_pattern(node->value.str, *i))
+                                return true;
+
+                return false;
+        }
 
         default:
                 assert_not_reached("Invalid node type");
@@ -235,7 +265,7 @@ int bus_match_run(
                 struct bus_match_node *node,
                 sd_bus_message *m) {
 
-
+        _cleanup_strv_free_ char **test_strv = NULL;
         const char *test_str = NULL;
         uint8_t test_u8 = 0;
         int r;
@@ -343,15 +373,15 @@ int bus_match_run(
                 break;
 
         case BUS_MATCH_ARG ... BUS_MATCH_ARG_LAST:
-                test_str = bus_message_get_arg(m, node->type - BUS_MATCH_ARG);
+                (void) bus_message_get_arg(m, node->type - BUS_MATCH_ARG, &test_str, &test_strv);
                 break;
 
         case BUS_MATCH_ARG_PATH ... BUS_MATCH_ARG_PATH_LAST:
-                test_str = bus_message_get_arg(m, node->type - BUS_MATCH_ARG_PATH);
+                (void) bus_message_get_arg(m, node->type - BUS_MATCH_ARG_PATH, &test_str, &test_strv);
                 break;
 
         case BUS_MATCH_ARG_NAMESPACE ... BUS_MATCH_ARG_NAMESPACE_LAST:
-                test_str = bus_message_get_arg(m, node->type - BUS_MATCH_ARG_NAMESPACE);
+                (void) bus_message_get_arg(m, node->type - BUS_MATCH_ARG_NAMESPACE, &test_str, &test_strv);
                 break;
 
         default:
@@ -365,7 +395,20 @@ int bus_match_run(
 
                 if (test_str)
                         found = hashmap_get(node->compare.children, test_str);
-                else if (node->type == BUS_MATCH_MESSAGE_TYPE)
+                else if (test_strv) {
+                        char **i;
+
+                        STRV_FOREACH(i, test_strv) {
+                                found = hashmap_get(node->compare.children, *i);
+                                if (found) {
+                                        r = bus_match_run(bus, found, m);
+                                        if (r != 0)
+                                                return r;
+                                }
+                        }
+
+                        found = NULL;
+                } else if (node->type == BUS_MATCH_MESSAGE_TYPE)
                         found = hashmap_get(node->compare.children, UINT_TO_PTR(test_u8));
                 else
                         found = NULL;
@@ -381,7 +424,7 @@ int bus_match_run(
                 /* No hash table, so let's iterate manually... */
 
                 for (c = node->child; c; c = c->next) {
-                        if (!value_node_test(c, node->type, test_u8, test_str, m))
+                        if (!value_node_test(c, node->type, test_u8, test_str, test_strv, m))
                                 continue;
 
                         r = bus_match_run(bus, c, m);
index 9fdf0d7e8117d842b6d98add0e893bf7bd676f76..05015a4157cd80044b5dd6d3244e5e2513773248 100644 (file)
@@ -5333,35 +5333,57 @@ _public_ int sd_bus_message_read_strv(sd_bus_message *m, char ***l) {
         return 1;
 }
 
-const char* bus_message_get_arg(sd_bus_message *m, unsigned i) {
-        int r;
-        const char *t = NULL;
+int bus_message_get_arg(sd_bus_message *m, unsigned i, const char **str, char ***strv) {
+        const char *contents;
         unsigned j;
+        char type;
+        int r;
 
         assert(m);
+        assert(str);
+        assert(strv);
 
         r = sd_bus_message_rewind(m, true);
         if (r < 0)
-                return NULL;
+                return r;
 
-        for (j = 0; j <= i; j++) {
-                char type;
+        for (j = 0;; j++) {
+                r = sd_bus_message_peek_type(m, &type, &contents);
+                if (r < 0)
+                        return r;
+                if (r == 0)
+                        return -ENXIO;
+
+                /* Don't match against arguments after the first one we don't understand */
+                if (!IN_SET(type, SD_BUS_TYPE_STRING, SD_BUS_TYPE_OBJECT_PATH, SD_BUS_TYPE_SIGNATURE) &&
+                    !(type == SD_BUS_TYPE_ARRAY && STR_IN_SET(contents, "s", "o", "g")))
+                        return -ENXIO;
+
+                if (j >= i)
+                        break;
 
-                r = sd_bus_message_peek_type(m, &type, NULL);
+                r = sd_bus_message_skip(m, NULL);
                 if (r < 0)
-                        return NULL;
+                        return r;
+        }
 
-                if (type != SD_BUS_TYPE_STRING &&
-                    type != SD_BUS_TYPE_OBJECT_PATH &&
-                    type != SD_BUS_TYPE_SIGNATURE)
-                        return NULL;
+        if (type == SD_BUS_TYPE_ARRAY) {
 
-                r = sd_bus_message_read_basic(m, type, &t);
+                r = sd_bus_message_read_strv(m, strv);
                 if (r < 0)
-                        return NULL;
+                        return r;
+
+                *str = NULL;
+
+        } else {
+                r = sd_bus_message_read_basic(m, type, str);
+                if (r < 0)
+                        return r;
+
+                *strv = NULL;
         }
 
-        return t;
+        return 0;
 }
 
 bool bus_header_is_complete(struct bus_header *h, size_t size) {
index 8aa71fa1d8d6cae66e1c65c598c202f53585ccdc..e54543a047b78b3b7dcb5027eeba25a69fef1f68 100644 (file)
@@ -223,7 +223,7 @@ int bus_message_from_malloc(
                 const char *label,
                 sd_bus_message **ret);
 
-const char* bus_message_get_arg(sd_bus_message *m, unsigned i);
+int bus_message_get_arg(sd_bus_message *m, unsigned i, const char **str, char ***strv);
 
 int bus_message_append_ap(sd_bus_message *m, const char *types, va_list ap);
 
index 6c5d35b3d326f312ee03f35b96793e24b1a246a2..713311703832a7e2fbf534fba0a39feba12a3250 100644 (file)
@@ -33,8 +33,9 @@
 static bool mask[32];
 
 static int filter(sd_bus *b, sd_bus_message *m, void *userdata, sd_bus_error *ret_error) {
-        log_info("Ran %i", PTR_TO_INT(userdata));
-        mask[PTR_TO_INT(userdata)] = true;
+        log_info("Ran %u", PTR_TO_UINT(userdata));
+        assert(PTR_TO_UINT(userdata) < ELEMENTSOF(mask));
+        mask[PTR_TO_UINT(userdata)] = true;
         return 0;
 }
 
@@ -85,9 +86,9 @@ int main(int argc, char *argv[]) {
         };
 
         _cleanup_bus_message_unref_ sd_bus_message *m = NULL;
-        _cleanup_bus_unref_ sd_bus *bus = NULL;
+        _cleanup_bus_close_unref_ sd_bus *bus = NULL;
         enum bus_match_node_type i;
-        sd_bus_slot slots[15];
+        sd_bus_slot slots[19];
         int r;
 
         r = sd_bus_open_system(&bus);
@@ -108,16 +109,20 @@ int main(int argc, char *argv[]) {
         assert_se(match_add(slots, &root, "arg1='two'", 12) >= 0);
         assert_se(match_add(slots, &root, "member='waldo',arg2path='/prefix/'", 13) >= 0);
         assert_se(match_add(slots, &root, "member=waldo,path='/foo/bar',arg3namespace='prefix'", 14) >= 0);
+        assert_se(match_add(slots, &root, "arg4='pi'", 15) >= 0);
+        assert_se(match_add(slots, &root, "arg4='pa'", 16) >= 0);
+        assert_se(match_add(slots, &root, "arg4='po'", 17) >= 0);
+        assert_se(match_add(slots, &root, "arg4='pu'", 18) >= 0);
 
         bus_match_dump(&root, 0);
 
         assert_se(sd_bus_message_new_signal(bus, &m, "/foo/bar", "bar.x", "waldo") >= 0);
-        assert_se(sd_bus_message_append(m, "ssss", "one", "two", "/prefix/three", "prefix.four") >= 0);
+        assert_se(sd_bus_message_append(m, "ssssas", "one", "two", "/prefix/three", "prefix.four", 3, "pi", "pa", "po") >= 0);
         assert_se(bus_message_seal(m, 1, 0) >= 0);
 
         zero(mask);
         assert_se(bus_match_run(NULL, &root, m) == 0);
-        assert_se(mask_contains((unsigned[]) { 9, 8, 7, 5, 10, 12, 13, 14 }, 8));
+        assert_se(mask_contains((unsigned[]) { 9, 8, 7, 5, 10, 12, 13, 14, 15, 16, 17 }, 11));
 
         assert_se(bus_match_remove(&root, &slots[8].match_callback) >= 0);
         assert_se(bus_match_remove(&root, &slots[13].match_callback) >= 0);
@@ -126,7 +131,7 @@ int main(int argc, char *argv[]) {
 
         zero(mask);
         assert_se(bus_match_run(NULL, &root, m) == 0);
-        assert_se(mask_contains((unsigned[]) { 9, 5, 10, 12, 14, 7 }, 6));
+        assert_se(mask_contains((unsigned[]) { 9, 5, 10, 12, 14, 7, 15, 16, 17 }, 9));
 
         for (i = 0; i < _BUS_MATCH_NODE_TYPE_MAX; i++) {
                 char buf[32];