chiark / gitweb /
bus: put together messages with memfd payload correctly
[elogind.git] / src / libsystemd-bus / bus-message.c
index 747b44ac942944711c28879670a8ac30baf5fc9e..b5a311530b623114d38d728d5b0c6a6fda16fcc1 100644 (file)
@@ -2184,7 +2184,7 @@ int sd_bus_message_append_array_memfd(sd_bus_message *m, char type, sd_memfd *me
         return sd_bus_message_close_container(m);
 }
 
-static int body_part_map_for_read(struct bus_body_part *part) {
+int bus_body_part_map(struct bus_body_part *part) {
         void *p;
         size_t psz;
 
@@ -2210,9 +2210,36 @@ static int body_part_map_for_read(struct bus_body_part *part) {
 
         part->mapped = psz;
         part->data = p;
+        part->munmap_this = true;
+
         return 0;
 }
 
+void bus_body_part_unmap(struct bus_body_part *part) {
+
+        assert_se(part);
+
+        if (part->memfd < 0)
+                return;
+
+        if (!part->sealed)
+                return;
+
+        if (!part->data)
+                return;
+
+        if (!part->munmap_this)
+                return;
+
+        assert_se(munmap(part->data, part->mapped) == 0);
+
+        part->data = NULL;
+        part->mapped = 0;
+        part->munmap_this = false;
+
+        return;
+}
+
 static int buffer_peek(const void *p, uint32_t sz, size_t *rindex, size_t align, size_t nbytes, void **r) {
         size_t k, start, end;
 
@@ -2271,7 +2298,7 @@ static struct bus_body_part* find_part(sd_bus_message *m, size_t index, size_t s
 
                 if (index + sz <= begin + part->size) {
 
-                        r = body_part_map_for_read(part);
+                        r = bus_body_part_map(part);
                         if (r < 0)
                                 return NULL;
 
@@ -3709,8 +3736,10 @@ int bus_message_seal(sd_bus_message *m, uint64_t serial) {
 
         MESSAGE_FOREACH_PART(part, i, m)
                 if (part->memfd >= 0 && !part->sealed) {
-                        ioctl(part->memfd, KDBUS_CMD_MEMFD_SEAL_SET, 1);
-                        part->sealed = true;
+                        bus_body_part_unmap(part);
+
+                        if (ioctl(part->memfd, KDBUS_CMD_MEMFD_SEAL_SET, 1) >= 0)
+                                part->sealed = true;
                 }
 
         m->header->serial = serial;