chiark / gitweb /
core: add new logic for services to store file descriptors in PID 1
[elogind.git] / src / core / manager.c
index 519b37438244c0a2bafe363d67c6f4054baf89b7..c18312a369c2a3aa9b93e11f175f4f2056f1e19a 100644 (file)
@@ -84,6 +84,9 @@
 #define JOBS_IN_PROGRESS_PERIOD_USEC (USEC_PER_SEC / 3)
 #define JOBS_IN_PROGRESS_PERIOD_DIVISOR 3
 
+#define NOTIFY_FD_MAX 768
+#define NOTIFY_BUFFER_MAX PIPE_BUF
+
 static int manager_dispatch_notify_fd(sd_event_source *source, int fd, uint32_t revents, void *userdata);
 static int manager_dispatch_signal_fd(sd_event_source *source, int fd, uint32_t revents, void *userdata);
 static int manager_dispatch_time_change_fd(sd_event_source *source, int fd, uint32_t revents, void *userdata);
@@ -1449,7 +1452,7 @@ static unsigned manager_dispatch_dbus_queue(Manager *m) {
         return n;
 }
 
-static void manager_invoke_notify_message(Manager *m, Unit *u, pid_t pid, char *buf, size_t n) {
+static void manager_invoke_notify_message(Manager *m, Unit *u, pid_t pid, char *buf, size_t n, FDSet *fds) {
         _cleanup_strv_free_ char **tags = NULL;
 
         assert(m);
@@ -1466,12 +1469,13 @@ static void manager_invoke_notify_message(Manager *m, Unit *u, pid_t pid, char *
         log_unit_debug(u->id, "Got notification message for unit %s", u->id);
 
         if (UNIT_VTABLE(u)->notify_message)
-                UNIT_VTABLE(u)->notify_message(u, pid, tags);
+                UNIT_VTABLE(u)->notify_message(u, pid, tags, fds);
 }
 
 static int manager_dispatch_notify_fd(sd_event_source *source, int fd, uint32_t revents, void *userdata) {
         Manager *m = userdata;
         ssize_t n;
+        int r;
 
         assert(m);
         assert(m->notify_fd == fd);
@@ -1482,73 +1486,101 @@ static int manager_dispatch_notify_fd(sd_event_source *source, int fd, uint32_t
         }
 
         for (;;) {
-                char buf[4096];
+                _cleanup_fdset_free_ FDSet *fds = NULL;
+                char buf[NOTIFY_BUFFER_MAX+1];
                 struct iovec iovec = {
                         .iov_base = buf,
                         .iov_len = sizeof(buf)-1,
                 };
-                bool found = false;
-
                 union {
                         struct cmsghdr cmsghdr;
-                        uint8_t buf[CMSG_SPACE(sizeof(struct ucred))];
+                        uint8_t buf[CMSG_SPACE(sizeof(struct ucred)) +
+                                    CMSG_SPACE(sizeof(int) * NOTIFY_FD_MAX)];
                 } control = {};
-
                 struct msghdr msghdr = {
                         .msg_iov = &iovec,
                         .msg_iovlen = 1,
                         .msg_control = &control,
                         .msg_controllen = sizeof(control),
                 };
-                struct ucred *ucred;
+                struct cmsghdr *cmsg;
+                struct ucred *ucred = NULL;
+                bool found = false;
                 Unit *u1, *u2, *u3;
+                int *fd_array = NULL;
+                unsigned n_fds = 0;
 
-                n = recvmsg(m->notify_fd, &msghdr, MSG_DONTWAIT);
-                if (n <= 0) {
-                        if (n == 0)
-                                return -EIO;
-
+                n = recvmsg(m->notify_fd, &msghdr, MSG_DONTWAIT|MSG_CMSG_CLOEXEC);
+                if (n < 0) {
                         if (errno == EAGAIN || errno == EINTR)
                                 break;
 
                         return -errno;
                 }
+                if (n == 0)
+                        return -ECONNRESET;
+
+                for (cmsg = CMSG_FIRSTHDR(&msghdr); cmsg; cmsg = CMSG_NXTHDR(&msghdr, cmsg)) {
+                        if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
+
+                                fd_array = (int*) CMSG_DATA(cmsg);
+                                n_fds = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
+
+                        } else if (cmsg->cmsg_level == SOL_SOCKET &&
+                                   cmsg->cmsg_type == SCM_CREDENTIALS &&
+                                   cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred))) {
 
-                if (msghdr.msg_controllen < CMSG_LEN(sizeof(struct ucred)) ||
-                    control.cmsghdr.cmsg_level != SOL_SOCKET ||
-                    control.cmsghdr.cmsg_type != SCM_CREDENTIALS ||
-                    control.cmsghdr.cmsg_len != CMSG_LEN(sizeof(struct ucred))) {
-                        log_warning("Received notify message without credentials. Ignoring.");
+                                ucred = (struct ucred*) CMSG_DATA(cmsg);
+                        }
+                }
+
+                if (n_fds > 0) {
+                        assert(fd_array);
+
+                        r = fdset_new_array(&fds, fd_array, n_fds);
+                        if (r < 0) {
+                                close_many(fd_array, n_fds);
+                                return log_oom();
+                        }
+                }
+
+                if (!ucred || ucred->pid <= 0) {
+                        log_warning("Received notify message without valid credentials. Ignoring.");
                         continue;
                 }
 
-                ucred = (struct ucred*) CMSG_DATA(&control.cmsghdr);
+                if ((size_t) n >= sizeof(buf)) {
+                        log_warning("Received notify message exceeded maximum size. Ignoring.");
+                        continue;
+                }
 
-                assert((size_t) n < sizeof(buf));
                 buf[n] = 0;
 
                 /* Notify every unit that might be interested, but try
                  * to avoid notifying the same one multiple times. */
                 u1 = manager_get_unit_by_pid(m, ucred->pid);
                 if (u1) {
-                        manager_invoke_notify_message(m, u1, ucred->pid, buf, n);
+                        manager_invoke_notify_message(m, u1, ucred->pid, buf, n, fds);
                         found = true;
                 }
 
                 u2 = hashmap_get(m->watch_pids1, LONG_TO_PTR(ucred->pid));
                 if (u2 && u2 != u1) {
-                        manager_invoke_notify_message(m, u2, ucred->pid, buf, n);
+                        manager_invoke_notify_message(m, u2, ucred->pid, buf, n, fds);
                         found = true;
                 }
 
                 u3 = hashmap_get(m->watch_pids2, LONG_TO_PTR(ucred->pid));
                 if (u3 && u3 != u2 && u3 != u1) {
-                        manager_invoke_notify_message(m, u3, ucred->pid, buf, n);
+                        manager_invoke_notify_message(m, u3, ucred->pid, buf, n, fds);
                         found = true;
                 }
 
                 if (!found)
                         log_warning("Cannot find unit for notify message of PID "PID_FMT".", ucred->pid);
+
+                if (fdset_size(fds) > 0)
+                        log_warning("Got auxiliary fds with notification message, closing all.");
         }
 
         return 0;