chiark / gitweb /
bus: calculate iovec for messages only when we need it
[elogind.git] / src / libsystemd-bus / bus-socket.c
index ce6af49b266a3c6fb9ef0d827e6a531d7e697010..5e285c9e528ed669c3904f5a91041261391bd92d 100644 (file)
@@ -31,6 +31,7 @@
 #include "missing.h"
 #include "strv.h"
 #include "utf8.h"
+#include "sd-daemon.h"
 
 #include "sd-bus.h"
 #include "bus-socket.h"
@@ -57,6 +58,39 @@ static void iovec_advance(struct iovec iov[], unsigned *idx, size_t size) {
         }
 }
 
+static void append_iovec(sd_bus_message *m, const void *p, size_t sz) {
+        assert(m);
+        assert(p);
+        assert(sz > 0);
+
+        m->iovec[m->n_iovec].iov_base = (void*) p;
+        m->iovec[m->n_iovec].iov_len = sz;
+        m->n_iovec++;
+}
+
+static void bus_message_setup_iovec(sd_bus_message *m) {
+        assert(m);
+        assert(m->sealed);
+
+        if (m->n_iovec > 0)
+                return;
+
+        append_iovec(m, m->header, sizeof(*m->header));
+
+        if (m->fields) {
+                append_iovec(m, m->fields, m->header->fields_size);
+
+                if (m->header->fields_size % 8 != 0) {
+                        static const uint8_t padding[7] = {};
+
+                        append_iovec(m, padding, 8 - (m->header->fields_size % 8));
+                }
+        }
+
+        if (m->body)
+                append_iovec(m, m->body, m->header->body_size);
+}
+
 bool bus_socket_auth_needs_write(sd_bus *b) {
 
         unsigned i;
@@ -234,7 +268,7 @@ static int verify_external_token(sd_bus *b, const char *p, size_t l) {
          * the owner of this bus wanted authentication he should have
          * checked SO_PEERCRED before even creating the bus object. */
 
-        if (!b->ucred_valid)
+        if (!b->anonymous_auth && !b->ucred_valid)
                 return 0;
 
         if (l <= 0)
@@ -257,7 +291,9 @@ static int verify_external_token(sd_bus *b, const char *p, size_t l) {
         if (r < 0)
                 return 0;
 
-        if (u != b->ucred.uid)
+        /* We ignore the passed value if anonymous authentication is
+         * on anyway. */
+        if (!b->anonymous_auth && u != b->ucred.uid)
                 return 0;
 
         return 1;
@@ -310,13 +346,16 @@ static int bus_socket_auth_verify_server(sd_bus *b) {
 
         assert(b);
 
-        if (b->rbuffer_size < 3)
+        if (b->rbuffer_size < 1)
                 return 0;
 
         /* First char must be a NUL byte */
         if (*(char*) b->rbuffer != 0)
                 return -EIO;
 
+        if (b->rbuffer_size < 3)
+                return 0;
+
         /* Begin with the first line */
         if (b->auth_rbegin <= 0)
                 b->auth_rbegin = 1;
@@ -449,7 +488,7 @@ static int bus_socket_read_auth(sd_bus *b) {
         if (r != 0)
                 return r;
 
-        n = MAX(256, b->rbuffer_size * 2);
+        n = MAX(256u, b->rbuffer_size * 2);
 
         if (n > BUS_AUTH_SIZE_MAX)
                 n = BUS_AUTH_SIZE_MAX;
@@ -507,16 +546,23 @@ static int bus_socket_read_auth(sd_bus *b) {
                                    cmsg->cmsg_type == SCM_CREDENTIALS &&
                                    cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred))) {
 
-                                memcpy(&b->ucred, CMSG_DATA(cmsg), sizeof(struct ucred));
-                                b->ucred_valid = true;
+                                /* Ignore bogus data, which we might
+                                 * get on socketpair() sockets */
+                                if (((struct ucred*) CMSG_DATA(cmsg))->pid != 0) {
+                                        memcpy(&b->ucred, CMSG_DATA(cmsg), sizeof(struct ucred));
+                                        b->ucred_valid = true;
+                                }
 
                         } else if (cmsg->cmsg_level == SOL_SOCKET &&
                                    cmsg->cmsg_type == SCM_SECURITY) {
 
                                 size_t l;
+
                                 l = cmsg->cmsg_len - CMSG_LEN(0);
-                                memcpy(&b->label, CMSG_DATA(cmsg), l);
-                                b->label[l] = 0;
+                                if (l > 0) {
+                                        memcpy(&b->label, CMSG_DATA(cmsg), l);
+                                        b->label[l] = 0;
+                                }
                         }
                 }
         }
@@ -530,6 +576,7 @@ static int bus_socket_read_auth(sd_bus *b) {
 
 static int bus_socket_setup(sd_bus *b) {
         int enable;
+        socklen_t l;
 
         assert(b);
 
@@ -543,6 +590,11 @@ static int bus_socket_setup(sd_bus *b) {
         fd_inc_rcvbuf(b->input_fd, 1024*1024);
         fd_inc_sndbuf(b->output_fd, 1024*1024);
 
+        /* Get the peer for socketpair() sockets */
+        l = sizeof(b->ucred);
+        if (getsockopt(b->input_fd, SOL_SOCKET, SO_PEERCRED, &b->ucred, &l) >= 0 && l >= sizeof(b->ucred))
+                b->ucred_valid = b->ucred.pid > 0;
+
         return 0;
 }
 
@@ -589,25 +641,17 @@ static int bus_socket_start_auth_client(sd_bus *b) {
 }
 
 static int bus_socket_start_auth(sd_bus *b) {
-        int domain = 0, r;
-        socklen_t sl;
-
         assert(b);
 
         b->state = BUS_AUTHENTICATING;
         b->auth_timeout = now(CLOCK_MONOTONIC) + BUS_DEFAULT_TIMEOUT;
 
-        sl = sizeof(domain);
-        r = getsockopt(b->input_fd, SOL_SOCKET, SO_DOMAIN, &domain, &sl);
-        if (r < 0 || domain != AF_UNIX)
+        if (sd_is_socket(b->input_fd, AF_UNIX, 0, 0) <= 0)
                 b->negotiate_fds = false;
 
-        if (b->output_fd != b->input_fd) {
-                r = getsockopt(b->output_fd, SOL_SOCKET, SO_DOMAIN, &domain, &sl);
-                if (r < 0 || domain != AF_UNIX)
+        if (b->output_fd != b->input_fd)
+                if (sd_is_socket(b->output_fd, AF_UNIX, 0, 0) <= 0)
                         b->negotiate_fds = false;
-        }
-
 
         if (b->is_server)
                 return bus_socket_read_auth(b);
@@ -718,9 +762,11 @@ int bus_socket_write_message(sd_bus *bus, sd_bus_message *m, size_t *idx) {
         assert(idx);
         assert(bus->state == BUS_RUNNING || bus->state == BUS_HELLO);
 
-        if (*idx >= m->size)
+        if (*idx >= bus_message_size(m))
                 return 0;
 
+        bus_message_setup_iovec(m);
+
         n = m->n_iovec * sizeof(struct iovec);
         iov = alloca(n);
         memcpy(iov, m->iovec, n);
@@ -866,7 +912,7 @@ int bus_socket_read_message(sd_bus *bus, sd_bus_message **m) {
                             CMSG_SPACE(NAME_MAX)]; /*selinux label */
         } control;
         struct cmsghdr *cmsg;
-        bool handle_cmsg;
+        bool handle_cmsg = false;
 
         assert(bus);
         assert(m);
@@ -942,16 +988,22 @@ int bus_socket_read_message(sd_bus *bus, sd_bus_message **m) {
                                    cmsg->cmsg_type == SCM_CREDENTIALS &&
                                    cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred))) {
 
-                                memcpy(&bus->ucred, CMSG_DATA(cmsg), sizeof(struct ucred));
-                                bus->ucred_valid = true;
+                                /* Ignore bogus data, which we might
+                                 * get on socketpair() sockets */
+                                if (((struct ucred*) CMSG_DATA(cmsg))->pid != 0) {
+                                        memcpy(&bus->ucred, CMSG_DATA(cmsg), sizeof(struct ucred));
+                                        bus->ucred_valid = true;
+                                }
 
                         } else if (cmsg->cmsg_level == SOL_SOCKET &&
                                    cmsg->cmsg_type == SCM_SECURITY) {
 
                                 size_t l;
                                 l = cmsg->cmsg_len - CMSG_LEN(0);
-                                memcpy(&bus->label, CMSG_DATA(cmsg), l);
-                                bus->label[l] = 0;
+                                if (l > 0) {
+                                        memcpy(&bus->label, CMSG_DATA(cmsg), l);
+                                        bus->label[l] = 0;
+                                }
                         }
                 }
         }
@@ -969,16 +1021,14 @@ int bus_socket_read_message(sd_bus *bus, sd_bus_message **m) {
 int bus_socket_process_opening(sd_bus *b) {
         int error = 0;
         socklen_t slen = sizeof(error);
-        struct pollfd p;
+        struct pollfd p = {
+                .fd = b->output_fd,
+                .events = POLLOUT,
+        };
         int r;
 
-        assert(b);
         assert(b->state == BUS_OPENING);
 
-        zero(p);
-        p.fd = b->output_fd;
-        p.events = POLLOUT;
-
         r = poll(&p, 1, 0);
         if (r < 0)
                 return -errno;