chiark / gitweb /
bus: rework message struct to keep header with fields in same malloc() block
authorLennart Poettering <lennart@poettering.net>
Wed, 15 May 2013 00:56:45 +0000 (02:56 +0200)
committerLennart Poettering <lennart@poettering.net>
Wed, 15 May 2013 00:56:45 +0000 (02:56 +0200)
This allows us to guarantee that the first payload_vec we pass to the
kernel for each message is guaranteed to include the full header and all
its field.

src/libsystemd-bus/bus-kernel.c
src/libsystemd-bus/bus-message.c
src/libsystemd-bus/bus-message.h
src/libsystemd-bus/bus-socket.c

index 3aa4084..8146936 100644 (file)
@@ -211,7 +211,7 @@ static int bus_message_setup_kmsg(sd_bus *b, sd_bus_message *m) {
         sz = offsetof(struct kdbus_msg, items);
 
         /* Add in fixed header, fields header and payload */
-        sz += (1 + !!m->fields + m->n_body_parts) *
+        sz += (1 + m->n_body_parts) *
                 ALIGN8(offsetof(struct kdbus_item, vec) + sizeof(struct kdbus_vec));
 
         /* Add space for bloom filter */
@@ -249,11 +249,7 @@ static int bus_message_setup_kmsg(sd_bus *b, sd_bus_message *m) {
         if (well_known)
                 append_destination(&d, m->destination, dl);
 
-        append_payload_vec(&d, m->header, sizeof(*m->header));
-
-        if (m->fields)
-                append_payload_vec(&d, m->fields, ALIGN8(m->header->fields_size));
-
+        append_payload_vec(&d, m->header, BUS_MESSAGE_BODY_BEGIN(m));
         MESSAGE_FOREACH_PART(part, i, m)
                 append_payload_vec(&d, part->data, part->size);
 
@@ -398,22 +394,6 @@ static void close_kdbus_msg(sd_bus *bus, struct kdbus_msg *k) {
         }
 }
 
-static bool range_contains(
-                size_t astart, size_t asize,
-                size_t bstart, size_t bsize,
-                void *a, void **b) {
-
-        if (bstart < astart)
-                return false;
-
-        if (bstart + bsize > astart + asize)
-                return false;
-
-        *b = (uint8_t*) a + (bstart - astart);
-
-        return true;
-}
-
 static int bus_kernel_make_message(sd_bus *bus, struct kdbus_msg *k, sd_bus_message **ret) {
         sd_bus_message *m = NULL;
         struct kdbus_item *d;
@@ -439,10 +419,10 @@ static int bus_kernel_make_message(sd_bus *bus, struct kdbus_msg *k, sd_bus_mess
                 if (d->type == KDBUS_MSG_PAYLOAD_VEC) {
 
                         if (!h) {
-                                if (d->vec.size < sizeof(struct bus_header))
-                                        return -EBADMSG;
-
                                 h = UINT64_TO_PTR(d->vec.address);
+
+                                if (!bus_header_is_complete(h, d->vec.size))
+                                        return -EBADMSG;
                         }
 
                         n_payload++;
@@ -470,7 +450,7 @@ static int bus_kernel_make_message(sd_bus *bus, struct kdbus_msg *k, sd_bus_mess
         if (!h)
                 return -EBADMSG;
 
-        r = bus_header_size(h, &total);
+        r = bus_header_message_size(h, &total);
         if (r < 0)
                 return r;
 
@@ -489,11 +469,7 @@ static int bus_kernel_make_message(sd_bus *bus, struct kdbus_msg *k, sd_bus_mess
                 if (d->type == KDBUS_MSG_PAYLOAD_VEC) {
                         size_t begin_body;
 
-                        /* Fill in fields material */
-                        range_contains(idx, d->vec.size, ALIGN8(sizeof(struct bus_header)), BUS_MESSAGE_FIELDS_SIZE(m),
-                                       UINT64_TO_PTR(d->vec.address), &m->fields);
-
-                        begin_body = ALIGN8(sizeof(struct bus_header)) + ALIGN8(BUS_MESSAGE_FIELDS_SIZE(m));
+                        begin_body = BUS_MESSAGE_BODY_BEGIN(m);
 
                         if (idx + d->vec.size > begin_body) {
                                 struct bus_body_part *part;
@@ -507,10 +483,10 @@ static int bus_kernel_make_message(sd_bus *bus, struct kdbus_msg *k, sd_bus_mess
                                 }
 
                                 if (idx >= begin_body) {
-                                        part->data = (void*) d->vec.address;
+                                        part->data = UINT64_TO_PTR(d->vec.address);
                                         part->size = d->vec.size;
                                 } else {
-                                        part->data = (uint8_t*) (uintptr_t) d->vec.address + (begin_body - idx);
+                                        part->data = (uint8_t*) UINT64_TO_PTR(d->vec.address) + (begin_body - idx);
                                         part->size = d->vec.size - (begin_body - idx);
                                 }
 
@@ -551,11 +527,6 @@ static int bus_kernel_make_message(sd_bus *bus, struct kdbus_msg *k, sd_bus_mess
                         log_debug("Got unknown field from kernel %llu", d->type);
         }
 
-        if ((BUS_MESSAGE_FIELDS_SIZE(m) > 0 && !m->fields)) {
-                sd_bus_message_unref(m);
-                return -EBADMSG;
-        }
-
         r = bus_message_parse_fields(m);
         if (r < 0) {
                 sd_bus_message_unref(m);
index 3790102..c1e1c46 100644 (file)
@@ -116,9 +116,6 @@ static void message_free(sd_bus_message *m) {
         if (m->free_header)
                 free(m->header);
 
-        if (m->free_fields)
-                free(m->fields);
-
         message_reset_parts(m);
 
         if (m->free_kdbus)
@@ -151,66 +148,64 @@ static void message_free(sd_bus_message *m) {
         free(m);
 }
 
-static void* buffer_extend(void **p, uint32_t *sz, size_t align, size_t extend) {
-        size_t start, end;
-        void *k;
-
-        assert(p);
-        assert(sz);
-        assert(align > 0);
-
-        start = ALIGN_TO((size_t) *sz, align);
-        end = start + extend;
-
-        if (end == *sz)
-                return (uint8_t*) *p + start;
+static void *message_extend_fields(sd_bus_message *m, size_t align, size_t sz) {
+        void *op, *np;
+        size_t old_size, new_size, start;
 
-        if (end > (size_t) ((uint32_t) -1))
-                return NULL;
+        assert(m);
 
-        k = realloc(*p, end);
-        if (!k)
+        if (m->poisoned)
                 return NULL;
 
-        /* Zero out padding */
-        if (start > *sz)
-                memset((uint8_t*) k + *sz, 0, start - *sz);
+        old_size = sizeof(struct bus_header) + m->header->fields_size;
+        start = ALIGN_TO(old_size, align);
+        new_size = start + sz;
 
-        *p = k;
-        *sz = end;
+        if (old_size == new_size)
+                return (uint8_t*) m->header + old_size;
 
-        return (uint8_t*) k + start;
-}
+        if (new_size > (size_t) ((uint32_t) -1))
+                goto poison;
 
-static void *message_extend_fields(sd_bus_message *m, size_t align, size_t sz) {
-        void *p, *op;
-        size_t os;
+        if (m->free_header) {
+                np = realloc(m->header, ALIGN8(new_size));
+                if (!np)
+                        goto poison;
+        } else {
+                /* Initially, the header is allocated as part of of
+                 * the sd_bus_message itself, let's replace it by
+                 * dynamic data */
 
-        assert(m);
+                np = malloc(ALIGN8(new_size));
+                if (!np)
+                        goto poison;
 
-        if (m->poisoned)
-                return NULL;
+                memcpy(np, m->header, sizeof(struct bus_header));
+        }
 
-        op = m->fields;
-        os = m->header->fields_size;
+        /* Zero out padding */
+        if (start > old_size)
+                memset((uint8_t*) np + old_size, 0, start - old_size);
 
-        p = buffer_extend(&m->fields, &m->header->fields_size, align, sz);
-        if (!p) {
-                m->poisoned = true;
-                return NULL;
-        }
+        op = m->header;
+        m->header = np;
+        m->header->fields_size = new_size - sizeof(struct bus_header);
 
         /* Adjust quick access pointers */
-        m->path = adjust_pointer(m->path, op, os, m->fields);
-        m->interface = adjust_pointer(m->interface, op, os, m->fields);
-        m->member = adjust_pointer(m->member, op, os, m->fields);
-        m->destination = adjust_pointer(m->destination, op, os, m->fields);
-        m->sender = adjust_pointer(m->sender, op, os, m->fields);
-        m->error.name = adjust_pointer(m->error.name, op, os, m->fields);
+        m->path = adjust_pointer(m->path, op, old_size, m->header);
+        m->interface = adjust_pointer(m->interface, op, old_size, m->header);
+        m->member = adjust_pointer(m->member, op, old_size, m->header);
+        m->destination = adjust_pointer(m->destination, op, old_size, m->header);
+        m->sender = adjust_pointer(m->sender, op, old_size, m->header);
+        m->error.name = adjust_pointer(m->error.name, op, old_size, m->header);
 
-        m->free_fields = true;
+        m->free_header = true;
 
-        return p;
+        return (uint8_t*) np + start;
+
+poison:
+        m->poisoned = true;
+        return NULL;
 }
 
 static int message_append_field_string(
@@ -390,8 +385,6 @@ int bus_message_from_malloc(
                 goto fail;
         }
 
-        m->fields = (uint8_t*) buffer + sizeof(struct bus_header);
-
         m->n_body_parts = 1;
         m->body.data = (uint8_t*) buffer + sizeof(struct bus_header) + ALIGN8(BUS_MESSAGE_FIELDS_SIZE(m));
         m->body.size = length - sizeof(struct bus_header) - ALIGN8(BUS_MESSAGE_FIELDS_SIZE(m));
@@ -3132,7 +3125,7 @@ static int message_peek_fields(
         assert(rindex);
         assert(align > 0);
 
-        return buffer_peek(m->fields, BUS_MESSAGE_FIELDS_SIZE(m), rindex, align, nbytes, ret);
+        return buffer_peek(BUS_MESSAGE_FIELDS(m), BUS_MESSAGE_FIELDS_SIZE(m), rindex, align, nbytes, ret);
 }
 
 static int message_peek_field_uint32(
@@ -3584,21 +3577,13 @@ int bus_message_seal(sd_bus_message *m, uint64_t serial) {
                         return r;
         }
 
+        /* Add padding at the end, since we know the body
+         * needs to start at an 8 byte alignment. */
+
         l = BUS_MESSAGE_FIELDS_SIZE(m);
         a = ALIGN8(l) - l;
-
-        if (a > 0) {
-                /* Add padding at the end, since we know the body
-                 * needs to start at an 8 byte alignment. */
-                void *p;
-
-                p = message_extend_fields(m, 1, a);
-                if (!p)
-                        return -ENOMEM;
-
-                memset(p, 0, a);
-                m->header->fields_size -= a;
-        }
+        if (a > 0)
+                memset((uint8_t*) BUS_MESSAGE_FIELDS(m) + l, 0, a);
 
         MESSAGE_FOREACH_PART(part, i, m)
                 if (part->memfd >= 0 && !part->sealed) {
@@ -3899,15 +3884,7 @@ int bus_message_get_blob(sd_bus_message *m, void **buffer, size_t *sz) {
         if (!p)
                 return -ENOMEM;
 
-        e = mempcpy(p, m->header, sizeof(*m->header));
-
-        if (m->fields) {
-                e = mempcpy(e, m->fields, m->header->fields_size);
-
-                if (m->header->fields_size % 8 != 0)
-                        e = mempset(e, 0, 8 - (m->header->fields_size % 8));
-        }
-
+        e = mempcpy(p, m->header, BUS_MESSAGE_BODY_BEGIN(m));
         MESSAGE_FOREACH_PART(part, i, m)
                 e = mempcpy(e, part->data, part->size);
 
@@ -3981,7 +3958,22 @@ const char* bus_message_get_arg(sd_bus_message *m, unsigned i) {
         return t;
 }
 
-int bus_header_size(struct bus_header *h, size_t *sum) {
+bool bus_header_is_complete(struct bus_header *h, size_t size) {
+        size_t full;
+
+        assert(h);
+        assert(size);
+
+        if (size < sizeof(struct bus_header))
+                return false;
+
+        full = sizeof(struct bus_header) +
+                (h->endian == SD_BUS_NATIVE_ENDIAN ? h->fields_size : bswap_32(h->fields_size));
+
+        return size >= full;
+}
+
+int bus_header_message_size(struct bus_header *h, size_t *sum) {
         size_t fs, bs;
 
         assert(h);
index 01a1e01..2517514 100644 (file)
@@ -89,14 +89,12 @@ struct sd_bus_message {
         bool uid_valid:1;
         bool gid_valid:1;
         bool free_header:1;
-        bool free_fields:1;
         bool free_kdbus:1;
         bool free_fds:1;
         bool release_kdbus:1;
         bool poisoned:1;
 
         struct bus_header *header;
-        void *fields;
         struct bus_body_part body;
         struct bus_body_part *body_end;
         unsigned n_body_parts;
@@ -114,7 +112,7 @@ struct sd_bus_message {
         unsigned n_containers;
 
         struct iovec *iovec;
-        struct iovec iovec_fixed[3];
+        struct iovec iovec_fixed[2];
         unsigned n_iovec;
 
         struct kdbus_msg *kdbus;
@@ -178,6 +176,16 @@ static inline uint32_t BUS_MESSAGE_SIZE(sd_bus_message *m) {
                 BUS_MESSAGE_BODY_SIZE(m);
 }
 
+static inline uint32_t BUS_MESSAGE_BODY_BEGIN(sd_bus_message *m) {
+        return
+                sizeof(struct bus_header) +
+                ALIGN8(BUS_MESSAGE_FIELDS_SIZE(m));
+}
+
+static inline void* BUS_MESSAGE_FIELDS(sd_bus_message *m) {
+        return (uint8_t*) m->header + sizeof(struct bus_header);
+}
+
 static inline void bus_message_unrefp(sd_bus_message **m) {
         sd_bus_message_unref(*m);
 }
@@ -214,7 +222,8 @@ int bus_message_append_ap(sd_bus_message *m, const char *types, va_list ap);
 
 int bus_message_parse_fields(sd_bus_message *m);
 
-int bus_header_size(struct bus_header *h, size_t *sum);
+bool bus_header_is_complete(struct bus_header *h, size_t size);
+int bus_header_message_size(struct bus_header *h, size_t *sum);
 
 struct bus_body_part *message_append_part(sd_bus_message *m);
 
index f43b7da..4635da4 100644 (file)
@@ -83,7 +83,7 @@ static int bus_message_setup_iovec(sd_bus_message *m) {
 
         assert(!m->iovec);
 
-        n = 1 + !!m->fields + m->n_body_parts;
+        n = 1 + m->n_body_parts;
         if (n < ELEMENTSOF(m->iovec_fixed))
                 m->iovec = m->iovec_fixed;
         else {
@@ -92,16 +92,10 @@ static int bus_message_setup_iovec(sd_bus_message *m) {
                         return -ENOMEM;
         }
 
-        r = append_iovec(m, m->header, sizeof(*m->header));
+        r = append_iovec(m, m->header, BUS_MESSAGE_BODY_BEGIN(m));
         if (r < 0)
                 return r;
 
-        if (m->fields) {
-                r = append_iovec(m, m->fields, ALIGN8(m->header->fields_size));
-                if (r < 0)
-                        return r;
-        }
-
         MESSAGE_FOREACH_PART(part, i, m)  {
                 r = append_iovec(m, part->data, part->size);
                 if (r < 0)