chiark / gitweb /
sd-daemon: simplify sd_pid_notify_with_fds
[elogind.git] / src / libsystemd / sd-daemon / sd-daemon.c
index 028c2a7a5b4cceb881207ee23fa2bff511196933..82ac72c72a0f8d48b583a1272722a237c296a6ae 100644 (file)
   along with systemd; If not, see <http://www.gnu.org/licenses/>.
 ***/
 
   along with systemd; If not, see <http://www.gnu.org/licenses/>.
 ***/
 
-#include <sys/types.h>
 #include <sys/stat.h>
 #include <sys/socket.h>
 #include <sys/un.h>
 #include <sys/stat.h>
 #include <sys/socket.h>
 #include <sys/un.h>
-#include <fcntl.h>
 #include <netinet/in.h>
 #include <stdlib.h>
 #include <errno.h>
 #include <netinet/in.h>
 #include <stdlib.h>
 #include <errno.h>
@@ -352,16 +350,10 @@ _public_ int sd_pid_notify_with_fds(pid_t pid, int unset_environment, const char
                 .msg_iovlen = 1,
                 .msg_name = &sockaddr,
         };
                 .msg_iovlen = 1,
                 .msg_name = &sockaddr,
         };
-        union {
-                struct cmsghdr cmsghdr;
-                uint8_t buf[CMSG_SPACE(sizeof(struct ucred)) +
-                            CMSG_SPACE(sizeof(int) * n_fds)];
-        } control;
         _cleanup_close_ int fd = -1;
         struct cmsghdr *cmsg = NULL;
         const char *e;
         _cleanup_close_ int fd = -1;
         struct cmsghdr *cmsg = NULL;
         const char *e;
-        size_t controllen_without_ucred = 0;
-        bool try_without_ucred = false;
+        bool have_pid;
         int r;
 
         if (!state) {
         int r;
 
         if (!state) {
@@ -400,40 +392,37 @@ _public_ int sd_pid_notify_with_fds(pid_t pid, int unset_environment, const char
         if (msghdr.msg_namelen > sizeof(struct sockaddr_un))
                 msghdr.msg_namelen = sizeof(struct sockaddr_un);
 
         if (msghdr.msg_namelen > sizeof(struct sockaddr_un))
                 msghdr.msg_namelen = sizeof(struct sockaddr_un);
 
-        if (n_fds > 0) {
-                msghdr.msg_control = &control;
-                msghdr.msg_controllen = CMSG_LEN(sizeof(int) * n_fds);
+        have_pid = pid != 0 && pid != getpid();
 
 
-                cmsg = CMSG_FIRSTHDR(&msghdr);
-                cmsg->cmsg_level = SOL_SOCKET;
-                cmsg->cmsg_type = SCM_RIGHTS;
-                cmsg->cmsg_len = CMSG_LEN(sizeof(int) * n_fds);
-
-                memcpy(CMSG_DATA(cmsg), fds, sizeof(int) * n_fds);
-        }
+        if (n_fds > 0 || have_pid) {
+                msghdr.msg_controllen = CMSG_SPACE(sizeof(int) * n_fds) +
+                                        CMSG_SPACE(sizeof(struct ucred) * have_pid);
+                msghdr.msg_control = alloca(msghdr.msg_controllen);
 
 
-        if (pid != 0 && pid != getpid()) {
-                struct ucred *ucred;
+                cmsg = CMSG_FIRSTHDR(&msghdr);
+                if (n_fds > 0) {
+                        cmsg->cmsg_level = SOL_SOCKET;
+                        cmsg->cmsg_type = SCM_RIGHTS;
+                        cmsg->cmsg_len = CMSG_LEN(sizeof(int) * n_fds);
 
 
-                try_without_ucred = true;
-                controllen_without_ucred = msghdr.msg_controllen;
+                        memcpy(CMSG_DATA(cmsg), fds, sizeof(int) * n_fds);
 
 
-                msghdr.msg_control = &control;
-                msghdr.msg_controllen += CMSG_LEN(sizeof(struct ucred));
+                        if (have_pid)
+                                assert_se(cmsg = CMSG_NXTHDR(&msghdr, cmsg));
+                }
 
 
-                if (cmsg)
-                        cmsg = CMSG_NXTHDR(&msghdr, cmsg);
-                else
-                        cmsg = CMSG_FIRSTHDR(&msghdr);
+                if (have_pid) {
+                        struct ucred *ucred;
 
 
-                cmsg->cmsg_level = SOL_SOCKET;
-                cmsg->cmsg_type = SCM_CREDENTIALS;
-                cmsg->cmsg_len = CMSG_LEN(sizeof(struct ucred));
+                        cmsg->cmsg_level = SOL_SOCKET;
+                        cmsg->cmsg_type = SCM_CREDENTIALS;
+                        cmsg->cmsg_len = CMSG_LEN(sizeof(struct ucred));
 
 
-                ucred = (struct ucred*) CMSG_DATA(cmsg);
-                ucred->pid = pid;
-                ucred->uid = getuid();
-                ucred->gid = getgid();
+                        ucred = (struct ucred*) CMSG_DATA(cmsg);
+                        ucred->pid = pid;
+                        ucred->uid = getuid();
+                        ucred->gid = getgid();
+                }
         }
 
         /* First try with fake ucred data, as requested */
         }
 
         /* First try with fake ucred data, as requested */
@@ -443,10 +432,10 @@ _public_ int sd_pid_notify_with_fds(pid_t pid, int unset_environment, const char
         }
 
         /* If that failed, try with our own ucred instead */
         }
 
         /* If that failed, try with our own ucred instead */
-        if (try_without_ucred) {
-                if (controllen_without_ucred <= 0)
+        if (have_pid) {
+                msghdr.msg_controllen -= CMSG_SPACE(sizeof(struct ucred));
+                if (msghdr.msg_controllen == 0)
                         msghdr.msg_control = NULL;
                         msghdr.msg_control = NULL;
-                msghdr.msg_controllen = controllen_without_ucred;
 
                 if (sendmsg(fd, &msghdr, MSG_NOSIGNAL) >= 0) {
                         r = 1;
 
                 if (sendmsg(fd, &msghdr, MSG_NOSIGNAL) >= 0) {
                         r = 1;