chiark / gitweb /
sd-bus: remove ucred parameter from bus_message_from_header() since we don't use...
[elogind.git] / src / libelogind / sd-bus / bus-socket.c
index 873aede65e07e57432c0bebaf4c1d7fa3eec9bf0..9f3756f0c2628b1ea961a0484d06e59afc21f015 100644 (file)
 #include <unistd.h>
 #include <poll.h>
 
+#include "sd-daemon.h"
 #include "util.h"
 #include "macro.h"
 #include "missing.h"
 #include "utf8.h"
-#include "sd-daemon.h"
+#include "formats-util.h"
+#include "signal-util.h"
 
 #include "sd-bus.h"
 #include "bus-socket.h"
@@ -176,7 +178,7 @@ static int bus_socket_auth_verify_client(sd_bus *b) {
         /* We expect two response lines: "OK" and possibly
          * "AGREE_UNIX_FD" */
 
-        e = memmem(b->rbuffer, b->rbuffer_size, "\r\n", 2);
+        e = memmem_safe(b->rbuffer, b->rbuffer_size, "\r\n", 2);
         if (!e)
                 return 0;
 
@@ -491,16 +493,14 @@ static int bus_socket_auth_verify(sd_bus *b) {
 
 static int bus_socket_read_auth(sd_bus *b) {
         struct msghdr mh;
-        struct iovec iov;
+        struct iovec iov = {};
         size_t n;
         ssize_t k;
         int r;
         void *p;
         union {
                 struct cmsghdr cmsghdr;
-                uint8_t buf[CMSG_SPACE(sizeof(int) * BUS_FDS_MAX) +
-                            CMSG_SPACE(sizeof(struct ucred)) +
-                            CMSG_SPACE(NAME_MAX)]; /*selinux label */
+                uint8_t buf[CMSG_SPACE(sizeof(int) * BUS_FDS_MAX)];
         } control;
         struct cmsghdr *cmsg;
         bool handle_cmsg = false;
@@ -526,7 +526,6 @@ static int bus_socket_read_auth(sd_bus *b) {
 
         b->rbuffer = p;
 
-        zero(iov);
         iov.iov_base = (uint8_t*) b->rbuffer + b->rbuffer_size;
         iov.iov_len = n - b->rbuffer_size;
 
@@ -553,8 +552,8 @@ static int bus_socket_read_auth(sd_bus *b) {
 
         b->rbuffer_size += k;
 
-        if (handle_cmsg) {
-                for (cmsg = CMSG_FIRSTHDR(&mh); cmsg; cmsg = CMSG_NXTHDR(&mh, cmsg)) {
+        if (handle_cmsg)
+                for (cmsg = CMSG_FIRSTHDR(&mh); cmsg; cmsg = CMSG_NXTHDR(&mh, cmsg))
                         if (cmsg->cmsg_level == SOL_SOCKET &&
                             cmsg->cmsg_type == SCM_RIGHTS) {
                                 int j;
@@ -565,31 +564,9 @@ static int bus_socket_read_auth(sd_bus *b) {
                                 j = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
                                 close_many((int*) CMSG_DATA(cmsg), j);
                                 return -EIO;
-
-                        } else if (cmsg->cmsg_level == SOL_SOCKET &&
-                                   cmsg->cmsg_type == SCM_CREDENTIALS &&
-                                   cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred))) {
-
-                                /* Ignore bogus data, which we might
-                                 * get on socketpair() sockets */
-                                if (((struct ucred*) CMSG_DATA(cmsg))->pid != 0) {
-                                        memcpy(&b->ucred, CMSG_DATA(cmsg), sizeof(struct ucred));
-                                        b->ucred_valid = true;
-                                }
-
-                        } else if (cmsg->cmsg_level == SOL_SOCKET &&
-                                   cmsg->cmsg_type == SCM_SECURITY) {
-
-                                size_t l;
-
-                                l = cmsg->cmsg_len - CMSG_LEN(0);
-                                if (l > 0) {
-                                        memcpy(&b->label, CMSG_DATA(cmsg), l);
-                                        b->label[l] = 0;
-                                }
-                        }
-                }
-        }
+                        } else
+                                log_debug("Got unexpected auxiliary data with level=%d and type=%d",
+                                          cmsg->cmsg_level, cmsg->cmsg_type);
 
         r = bus_socket_auth_verify(b);
         if (r != 0)
@@ -599,18 +576,8 @@ static int bus_socket_read_auth(sd_bus *b) {
 }
 
 void bus_socket_setup(sd_bus *b) {
-        int enable;
-
         assert(b);
 
-        /* Enable SO_PASSCRED + SO_PASSEC. We try this on any
-         * socket, just in case. */
-        enable = !b->bus_client;
-        (void) setsockopt(b->input_fd, SOL_SOCKET, SO_PASSCRED, &enable, sizeof(enable));
-
-        enable = !b->bus_client && (b->attach_flags & KDBUS_ATTACH_SECLABEL);
-        (void) setsockopt(b->input_fd, SOL_SOCKET, SO_PASSSEC, &enable, sizeof(enable));
-
         /* Increase the buffers to 8 MB */
         fd_inc_rcvbuf(b->input_fd, SNDBUF_SIZE);
         fd_inc_sndbuf(b->output_fd, SNDBUF_SIZE);
@@ -621,10 +588,17 @@ void bus_socket_setup(sd_bus *b) {
 }
 
 static void bus_get_peercred(sd_bus *b) {
+        int r;
+
         assert(b);
 
         /* Get the peer for socketpair() sockets */
         b->ucred_valid = getpeercred(b->input_fd, &b->ucred) >= 0;
+
+        /* Get the SELinux context of the peer */
+        r = getpeersec(b->input_fd, &b->label);
+        if (r < 0 && r != -EOPNOTSUPP)
+                log_debug_errno(r, "Failed to determine peer security context: %m");
 }
 
 static int bus_socket_start_auth_client(sd_bus *b) {
@@ -737,7 +711,8 @@ int bus_socket_exec(sd_bus *b) {
         if (pid == 0) {
                 /* Child */
 
-                reset_all_signal_handlers();
+                (void) reset_all_signal_handlers();
+                (void) reset_signal_mask();
 
                 close_all_fds(s+1, 1);
 
@@ -807,23 +782,21 @@ int bus_socket_write_message(sd_bus *bus, sd_bus_message *m, size_t *idx) {
         if (bus->prefer_writev)
                 k = writev(bus->output_fd, iov, m->n_iovec);
         else {
-                struct msghdr mh;
-                zero(mh);
+                struct msghdr mh = {
+                        .msg_iov = iov,
+                        .msg_iovlen = m->n_iovec,
+                };
 
                 if (m->n_fds > 0) {
                         struct cmsghdr *control;
-                        control = alloca(CMSG_SPACE(sizeof(int) * m->n_fds));
 
-                        mh.msg_control = control;
+                        mh.msg_control = control = alloca(CMSG_SPACE(sizeof(int) * m->n_fds));
+                        mh.msg_controllen = control->cmsg_len = CMSG_LEN(sizeof(int) * m->n_fds);
                         control->cmsg_level = SOL_SOCKET;
                         control->cmsg_type = SCM_RIGHTS;
-                        mh.msg_controllen = control->cmsg_len = CMSG_LEN(sizeof(int) * m->n_fds);
                         memcpy(CMSG_DATA(control), m->fds, sizeof(int) * m->n_fds);
                 }
 
-                mh.msg_iov = iov;
-                mh.msg_iovlen = m->n_iovec;
-
                 k = sendmsg(bus->output_fd, &mh, MSG_DONTWAIT|MSG_NOSIGNAL);
                 if (k < 0 && errno == ENOTSOCK) {
                         bus->prefer_writev = true;
@@ -914,8 +887,7 @@ static int bus_socket_make_message(sd_bus *bus, size_t size) {
         r = bus_message_from_malloc(bus,
                                     bus->rbuffer, size,
                                     bus->fds, bus->n_fds,
-                                    !bus->bus_client && bus->ucred_valid ? &bus->ucred : NULL,
-                                    !bus->bus_client && bus->label[0] ? bus->label : NULL,
+                                    NULL,
                                     &t);
         if (r < 0) {
                 free(b);
@@ -935,16 +907,14 @@ static int bus_socket_make_message(sd_bus *bus, size_t size) {
 
 int bus_socket_read_message(sd_bus *bus) {
         struct msghdr mh;
-        struct iovec iov;
+        struct iovec iov = {};
         ssize_t k;
         size_t need;
         int r;
         void *b;
         union {
                 struct cmsghdr cmsghdr;
-                uint8_t buf[CMSG_SPACE(sizeof(int) * BUS_FDS_MAX) +
-                            CMSG_SPACE(sizeof(struct ucred)) +
-                            CMSG_SPACE(NAME_MAX)]; /*selinux label */
+                uint8_t buf[CMSG_SPACE(sizeof(int) * BUS_FDS_MAX)];
         } control;
         struct cmsghdr *cmsg;
         bool handle_cmsg = false;
@@ -965,7 +935,6 @@ int bus_socket_read_message(sd_bus *bus) {
 
         bus->rbuffer = b;
 
-        zero(iov);
         iov.iov_base = (uint8_t*) bus->rbuffer + bus->rbuffer_size;
         iov.iov_len = need - bus->rbuffer_size;
 
@@ -992,8 +961,8 @@ int bus_socket_read_message(sd_bus *bus) {
 
         bus->rbuffer_size += k;
 
-        if (handle_cmsg) {
-                for (cmsg = CMSG_FIRSTHDR(&mh); cmsg; cmsg = CMSG_NXTHDR(&mh, cmsg)) {
+        if (handle_cmsg)
+                for (cmsg = CMSG_FIRSTHDR(&mh); cmsg; cmsg = CMSG_NXTHDR(&mh, cmsg))
                         if (cmsg->cmsg_level == SOL_SOCKET &&
                             cmsg->cmsg_type == SCM_RIGHTS) {
                                 int n, *f;
@@ -1018,29 +987,9 @@ int bus_socket_read_message(sd_bus *bus) {
                                 memcpy(f + bus->n_fds, CMSG_DATA(cmsg), n * sizeof(int));
                                 bus->fds = f;
                                 bus->n_fds += n;
-                        } else if (cmsg->cmsg_level == SOL_SOCKET &&
-                                   cmsg->cmsg_type == SCM_CREDENTIALS &&
-                                   cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred))) {
-
-                                /* Ignore bogus data, which we might
-                                 * get on socketpair() sockets */
-                                if (((struct ucred*) CMSG_DATA(cmsg))->pid != 0) {
-                                        memcpy(&bus->ucred, CMSG_DATA(cmsg), sizeof(struct ucred));
-                                        bus->ucred_valid = true;
-                                }
-
-                        } else if (cmsg->cmsg_level == SOL_SOCKET &&
-                                   cmsg->cmsg_type == SCM_SECURITY) {
-
-                                size_t l;
-                                l = cmsg->cmsg_len - CMSG_LEN(0);
-                                if (l > 0) {
-                                        memcpy(&bus->label, CMSG_DATA(cmsg), l);
-                                        bus->label[l] = 0;
-                                }
-                        }
-                }
-        }
+                        } else
+                                log_debug("Got unexpected auxiliary data with level=%d and type=%d",
+                                          cmsg->cmsg_level, cmsg->cmsg_type);
 
         r = bus_socket_read_message_need(bus, &need);
         if (r < 0)