chiark / gitweb /
bus: fall back to readv/writev if recvmsg/sendmsg don't work
authorLennart Poettering <lennart@poettering.net>
Sun, 31 Mar 2013 18:19:18 +0000 (20:19 +0200)
committerLennart Poettering <lennart@poettering.net>
Mon, 1 Apr 2013 01:29:29 +0000 (03:29 +0200)
src/libsystemd-bus/bus-internal.h
src/libsystemd-bus/bus-socket.c
src/systemd/sd-bus.h

index 4997936..feafed0 100644 (file)
@@ -83,6 +83,8 @@ struct sd_bus {
         bool ucred_valid:1;
         bool is_server:1;
         bool anonymous_auth:1;
+        bool prefer_readv:1;
+        bool prefer_writev:1;
 
         void *rbuffer;
         size_t rbuffer_size;
index 9d08674..ce6af49 100644 (file)
@@ -75,7 +75,6 @@ bool bus_socket_auth_needs_write(sd_bus *b) {
 }
 
 static int bus_socket_write_auth(sd_bus *b) {
-        struct msghdr mh;
         ssize_t k;
 
         assert(b);
@@ -84,16 +83,26 @@ static int bus_socket_write_auth(sd_bus *b) {
         if (!bus_socket_auth_needs_write(b))
                 return 0;
 
-        zero(mh);
-        mh.msg_iov = b->auth_iovec + b->auth_index;
-        mh.msg_iovlen = ELEMENTSOF(b->auth_iovec) - b->auth_index;
+        if (b->prefer_writev)
+                k = writev(b->output_fd, b->auth_iovec + b->auth_index, ELEMENTSOF(b->auth_iovec) - b->auth_index);
+        else {
+                struct msghdr mh;
+                zero(mh);
+
+                mh.msg_iov = b->auth_iovec + b->auth_index;
+                mh.msg_iovlen = ELEMENTSOF(b->auth_iovec) - b->auth_index;
+
+                k = sendmsg(b->output_fd, &mh, MSG_DONTWAIT|MSG_NOSIGNAL);
+                if (k < 0 && errno == ENOTSOCK) {
+                        b->prefer_writev = true;
+                        k = writev(b->output_fd, b->auth_iovec + b->auth_index, ELEMENTSOF(b->auth_iovec) - b->auth_index);
+                }
+        }
 
-        k = sendmsg(b->output_fd, &mh, MSG_DONTWAIT|MSG_NOSIGNAL);
         if (k < 0)
                 return errno == EAGAIN ? 0 : -errno;
 
         iovec_advance(b->auth_iovec, &b->auth_index, (size_t) k);
-
         return 1;
 }
 
@@ -431,6 +440,7 @@ static int bus_socket_read_auth(sd_bus *b) {
                             CMSG_SPACE(NAME_MAX)]; /*selinux label */
         } control;
         struct cmsghdr *cmsg;
+        bool handle_cmsg = false;
 
         assert(b);
         assert(b->state == BUS_AUTHENTICATING);
@@ -457,13 +467,22 @@ static int bus_socket_read_auth(sd_bus *b) {
         iov.iov_base = (uint8_t*) b->rbuffer + b->rbuffer_size;
         iov.iov_len = n - b->rbuffer_size;
 
-        zero(mh);
-        mh.msg_iov = &iov;
-        mh.msg_iovlen = 1;
-        mh.msg_control = &control;
-        mh.msg_controllen = sizeof(control);
-
-        k = recvmsg(b->input_fd, &mh, MSG_DONTWAIT|MSG_NOSIGNAL|MSG_CMSG_CLOEXEC);
+        if (b->prefer_readv)
+                k = readv(b->input_fd, &iov, 1);
+        else {
+                zero(mh);
+                mh.msg_iov = &iov;
+                mh.msg_iovlen = 1;
+                mh.msg_control = &control;
+                mh.msg_controllen = sizeof(control);
+
+                k = recvmsg(b->input_fd, &mh, MSG_DONTWAIT|MSG_NOSIGNAL|MSG_CMSG_CLOEXEC);
+                if (k < 0 && errno == ENOTSOCK) {
+                        b->prefer_readv = true;
+                        k = readv(b->input_fd, &iov, 1);
+                } else
+                        handle_cmsg = true;
+        }
         if (k < 0)
                 return errno == EAGAIN ? 0 : -errno;
         if (k == 0)
@@ -471,32 +490,34 @@ static int bus_socket_read_auth(sd_bus *b) {
 
         b->rbuffer_size += k;
 
-        for (cmsg = CMSG_FIRSTHDR(&mh); cmsg; cmsg = CMSG_NXTHDR(&mh, cmsg)) {
-                if (cmsg->cmsg_level == SOL_SOCKET &&
-                    cmsg->cmsg_type == SCM_RIGHTS) {
-                        int j;
-
-                        /* Whut? We received fds during the auth
-                         * protocol? Somebody is playing games with
-                         * us. Close them all, and fail */
-                        j = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
-                        close_many((int*) CMSG_DATA(cmsg), j);
-                        return -EIO;
+        if (handle_cmsg) {
+                for (cmsg = CMSG_FIRSTHDR(&mh); cmsg; cmsg = CMSG_NXTHDR(&mh, cmsg)) {
+                        if (cmsg->cmsg_level == SOL_SOCKET &&
+                            cmsg->cmsg_type == SCM_RIGHTS) {
+                                int j;
+
+                                /* Whut? We received fds during the auth
+                                 * protocol? Somebody is playing games with
+                                 * us. Close them all, and fail */
+                                j = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
+                                close_many((int*) CMSG_DATA(cmsg), j);
+                                return -EIO;
 
-                } else if (cmsg->cmsg_level == SOL_SOCKET &&
-                    cmsg->cmsg_type == SCM_CREDENTIALS &&
-                    cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred))) {
+                        } else if (cmsg->cmsg_level == SOL_SOCKET &&
+                                   cmsg->cmsg_type == SCM_CREDENTIALS &&
+                                   cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred))) {
 
-                        memcpy(&b->ucred, CMSG_DATA(cmsg), sizeof(struct ucred));
-                        b->ucred_valid = true;
+                                memcpy(&b->ucred, CMSG_DATA(cmsg), sizeof(struct ucred));
+                                b->ucred_valid = true;
 
-                } else if (cmsg->cmsg_level == SOL_SOCKET &&
-                         cmsg->cmsg_type == SCM_SECURITY) {
+                        } else if (cmsg->cmsg_level == SOL_SOCKET &&
+                                   cmsg->cmsg_type == SCM_SECURITY) {
 
-                        size_t l;
-                        l = cmsg->cmsg_len - CMSG_LEN(0);
-                        memcpy(&b->label, CMSG_DATA(cmsg), l);
-                        b->label[l] = 0;
+                                size_t l;
+                                l = cmsg->cmsg_len - CMSG_LEN(0);
+                                memcpy(&b->label, CMSG_DATA(cmsg), l);
+                                b->label[l] = 0;
+                        }
                 }
         }
 
@@ -687,7 +708,6 @@ int bus_socket_take_fd(sd_bus *b) {
 }
 
 int bus_socket_write_message(sd_bus *bus, sd_bus_message *m, size_t *idx) {
-        struct msghdr mh;
         struct iovec *iov;
         ssize_t k;
         size_t n;
@@ -700,18 +720,6 @@ int bus_socket_write_message(sd_bus *bus, sd_bus_message *m, size_t *idx) {
 
         if (*idx >= m->size)
                 return 0;
-        zero(mh);
-
-        if (m->n_fds > 0) {
-                struct cmsghdr *control;
-                control = alloca(CMSG_SPACE(sizeof(int) * m->n_fds));
-
-                mh.msg_control = control;
-                control->cmsg_level = SOL_SOCKET;
-                control->cmsg_type = SCM_RIGHTS;
-                mh.msg_controllen = control->cmsg_len = CMSG_LEN(sizeof(int) * m->n_fds);
-                memcpy(CMSG_DATA(control), m->fds, sizeof(int) * m->n_fds);
-        }
 
         n = m->n_iovec * sizeof(struct iovec);
         iov = alloca(n);
@@ -720,10 +728,33 @@ int bus_socket_write_message(sd_bus *bus, sd_bus_message *m, size_t *idx) {
         j = 0;
         iovec_advance(iov, &j, *idx);
 
-        mh.msg_iov = iov;
-        mh.msg_iovlen = m->n_iovec;
+        if (bus->prefer_writev)
+                k = writev(bus->output_fd, iov, m->n_iovec);
+        else {
+                struct msghdr mh;
+                zero(mh);
+
+                if (m->n_fds > 0) {
+                        struct cmsghdr *control;
+                        control = alloca(CMSG_SPACE(sizeof(int) * m->n_fds));
+
+                        mh.msg_control = control;
+                        control->cmsg_level = SOL_SOCKET;
+                        control->cmsg_type = SCM_RIGHTS;
+                        mh.msg_controllen = control->cmsg_len = CMSG_LEN(sizeof(int) * m->n_fds);
+                        memcpy(CMSG_DATA(control), m->fds, sizeof(int) * m->n_fds);
+                }
+
+                mh.msg_iov = iov;
+                mh.msg_iovlen = m->n_iovec;
+
+                k = sendmsg(bus->output_fd, &mh, MSG_DONTWAIT|MSG_NOSIGNAL);
+                if (k < 0 && errno == ENOTSOCK) {
+                        bus->prefer_writev = true;
+                        k = writev(bus->output_fd, iov, m->n_iovec);
+                }
+        }
 
-        k = sendmsg(bus->output_fd, &mh, MSG_DONTWAIT|MSG_NOSIGNAL);
         if (k < 0)
                 return errno == EAGAIN ? 0 : -errno;
 
@@ -835,6 +866,7 @@ int bus_socket_read_message(sd_bus *bus, sd_bus_message **m) {
                             CMSG_SPACE(NAME_MAX)]; /*selinux label */
         } control;
         struct cmsghdr *cmsg;
+        bool handle_cmsg;
 
         assert(bus);
         assert(m);
@@ -857,13 +889,22 @@ int bus_socket_read_message(sd_bus *bus, sd_bus_message **m) {
         iov.iov_base = (uint8_t*) bus->rbuffer + bus->rbuffer_size;
         iov.iov_len = need - bus->rbuffer_size;
 
-        zero(mh);
-        mh.msg_iov = &iov;
-        mh.msg_iovlen = 1;
-        mh.msg_control = &control;
-        mh.msg_controllen = sizeof(control);
-
-        k = recvmsg(bus->input_fd, &mh, MSG_DONTWAIT|MSG_NOSIGNAL|MSG_CMSG_CLOEXEC);
+        if (bus->prefer_readv)
+                k = readv(bus->input_fd, &iov, 1);
+        else {
+                zero(mh);
+                mh.msg_iov = &iov;
+                mh.msg_iovlen = 1;
+                mh.msg_control = &control;
+                mh.msg_controllen = sizeof(control);
+
+                k = recvmsg(bus->input_fd, &mh, MSG_DONTWAIT|MSG_NOSIGNAL|MSG_CMSG_CLOEXEC);
+                if (k < 0 && errno == ENOTSOCK) {
+                        bus->prefer_readv = true;
+                        k = readv(bus->input_fd, &iov, 1);
+                } else
+                        handle_cmsg = true;
+        }
         if (k < 0)
                 return errno == EAGAIN ? 0 : -errno;
         if (k == 0)
@@ -871,45 +912,47 @@ int bus_socket_read_message(sd_bus *bus, sd_bus_message **m) {
 
         bus->rbuffer_size += k;
 
-        for (cmsg = CMSG_FIRSTHDR(&mh); cmsg; cmsg = CMSG_NXTHDR(&mh, cmsg)) {
-                if (cmsg->cmsg_level == SOL_SOCKET &&
-                    cmsg->cmsg_type == SCM_RIGHTS) {
-                        int n, *f;
-
-                        n = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
-
-                        if (!bus->can_fds) {
-                                /* Whut? We received fds but this
-                                 * isn't actually enabled? Close them,
-                                 * and fail */
-
-                                close_many((int*) CMSG_DATA(cmsg), n);
-                                return -EIO;
-                        }
-
-                        f = realloc(bus->fds, sizeof(int) + (bus->n_fds + n));
-                        if (!f) {
-                                close_many((int*) CMSG_DATA(cmsg), n);
-                                return -ENOMEM;
+        if (handle_cmsg) {
+                for (cmsg = CMSG_FIRSTHDR(&mh); cmsg; cmsg = CMSG_NXTHDR(&mh, cmsg)) {
+                        if (cmsg->cmsg_level == SOL_SOCKET &&
+                            cmsg->cmsg_type == SCM_RIGHTS) {
+                                int n, *f;
+
+                                n = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
+
+                                if (!bus->can_fds) {
+                                        /* Whut? We received fds but this
+                                         * isn't actually enabled? Close them,
+                                         * and fail */
+
+                                        close_many((int*) CMSG_DATA(cmsg), n);
+                                        return -EIO;
+                                }
+
+                                f = realloc(bus->fds, sizeof(int) + (bus->n_fds + n));
+                                if (!f) {
+                                        close_many((int*) CMSG_DATA(cmsg), n);
+                                        return -ENOMEM;
+                                }
+
+                                memcpy(f + bus->n_fds, CMSG_DATA(cmsg), n * sizeof(int));
+                                bus->fds = f;
+                                bus->n_fds += n;
+                        } else if (cmsg->cmsg_level == SOL_SOCKET &&
+                                   cmsg->cmsg_type == SCM_CREDENTIALS &&
+                                   cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred))) {
+
+                                memcpy(&bus->ucred, CMSG_DATA(cmsg), sizeof(struct ucred));
+                                bus->ucred_valid = true;
+
+                        } else if (cmsg->cmsg_level == SOL_SOCKET &&
+                                   cmsg->cmsg_type == SCM_SECURITY) {
+
+                                size_t l;
+                                l = cmsg->cmsg_len - CMSG_LEN(0);
+                                memcpy(&bus->label, CMSG_DATA(cmsg), l);
+                                bus->label[l] = 0;
                         }
-
-                        memcpy(f + bus->n_fds, CMSG_DATA(cmsg), n * sizeof(int));
-                        bus->fds = f;
-                        bus->n_fds += n;
-                } else if (cmsg->cmsg_level == SOL_SOCKET &&
-                    cmsg->cmsg_type == SCM_CREDENTIALS &&
-                    cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred))) {
-
-                        memcpy(&bus->ucred, CMSG_DATA(cmsg), sizeof(struct ucred));
-                        bus->ucred_valid = true;
-
-                } else if (cmsg->cmsg_level == SOL_SOCKET &&
-                         cmsg->cmsg_type == SCM_SECURITY) {
-
-                        size_t l;
-                        l = cmsg->cmsg_len - CMSG_LEN(0);
-                        memcpy(&bus->label, CMSG_DATA(cmsg), l);
-                        bus->label[l] = 0;
                 }
         }
 
index f792bfa..057931d 100644 (file)
@@ -57,7 +57,7 @@ int sd_bus_open_user(sd_bus **ret);
 
 int sd_bus_new(sd_bus **ret);
 int sd_bus_set_address(sd_bus *bus, const char *address);
-int sd_bus_set_fd(sd_bus *bus, int fd);
+int sd_bus_set_fd(sd_bus *bus, int input_fd, int output_fd);
 int sd_bus_set_exec(sd_bus *bus, const char *path, char *const argv[]);
 int sd_bus_set_bus_client(sd_bus *bus, int b);
 int sd_bus_set_server(sd_bus *bus, int b, sd_id128_t server_id);