chiark / gitweb /
ask-password: properly NULL terminate table
[elogind.git] / src / tty-ask-password-agent.c
index 9c4d076b31ccbd3f74899297e88aafbe4c8e7882..8a6e3d331ede5025e01040917f3c3b8f14ddbf8c 100644 (file)
 #include <sys/inotify.h>
 #include <unistd.h>
 #include <getopt.h>
+#include <sys/signalfd.h>
 
 #include "util.h"
 #include "conf-parser.h"
 #include "utmp-wtmp.h"
+#include "socket-util.h"
+#include "ask-password-api.h"
 
 static enum {
         ACTION_LIST,
@@ -41,7 +44,165 @@ static enum {
         ACTION_WALL
 } arg_action = ACTION_QUERY;
 
-static int parse_password(const char *filename) {
+static bool arg_plymouth = false;
+
+static int ask_password_plymouth(const char *message, usec_t until, const char *flag_file, char **_passphrase) {
+        int fd = -1, notify = -1;
+        union sockaddr_union sa;
+        char *packet = NULL;
+        ssize_t k;
+        int r, n;
+        struct pollfd pollfd[2];
+        char buffer[LINE_MAX];
+        size_t p = 0;
+        enum {
+                POLL_SOCKET,
+                POLL_INOTIFY
+        };
+
+        if (flag_file) {
+                if ((notify = inotify_init1(IN_CLOEXEC|IN_NONBLOCK)) < 0) {
+                        r = -errno;
+                        goto finish;
+                }
+
+                if (inotify_add_watch(notify, flag_file, IN_ATTRIB /* for the link count */) < 0) {
+                        r = -errno;
+                        goto finish;
+                }
+        }
+
+        if ((fd = socket(AF_UNIX, SOCK_STREAM|SOCK_CLOEXEC|SOCK_NONBLOCK, 0)) < 0) {
+                r = -errno;
+                goto finish;
+        }
+
+        zero(sa);
+        sa.sa.sa_family = AF_UNIX;
+        strncpy(sa.un.sun_path+1, "/ply-boot-protocol", sizeof(sa.un.sun_path)-1);
+
+        if (connect(fd, &sa.sa, sizeof(sa.un)) < 0) {
+                r = -errno;
+                goto finish;
+        }
+
+        if (asprintf(&packet, "*\002%c%s%n", (int) (strlen(message) + 1), message, &n) < 0) {
+                r = -ENOMEM;
+                goto finish;
+        }
+
+        if ((k = loop_write(fd, packet, n+1, true)) != n+1) {
+                r = k < 0 ? (int) k : -EIO;
+                goto finish;
+        }
+
+        zero(pollfd);
+        pollfd[POLL_SOCKET].fd = fd;
+        pollfd[POLL_SOCKET].events = POLLIN;
+        pollfd[POLL_INOTIFY].fd = notify;
+        pollfd[POLL_INOTIFY].events = POLLIN;
+
+        for (;;) {
+                int sleep_for = -1, j;
+
+                if (until > 0) {
+                        usec_t y;
+
+                        y = now(CLOCK_MONOTONIC);
+
+                        if (y > until) {
+                                r = -ETIMEDOUT;
+                                goto finish;
+                        }
+
+                        sleep_for = (int) ((until - y) / USEC_PER_MSEC);
+                }
+
+                if (flag_file)
+                        if (access(flag_file, F_OK) < 0) {
+                                r = -errno;
+                                goto finish;
+                        }
+
+                if ((j = poll(pollfd, notify > 0 ? 2 : 1, sleep_for)) < 0) {
+
+                        if (errno == EINTR)
+                                continue;
+
+                        r = -errno;
+                        goto finish;
+                } else if (j == 0) {
+                        r = -ETIMEDOUT;
+                        goto finish;
+                }
+
+                if (notify > 0 && pollfd[POLL_INOTIFY].revents != 0)
+                        flush_fd(notify);
+
+                if (pollfd[POLL_SOCKET].revents == 0)
+                        continue;
+
+                if ((k = read(fd, buffer + p, sizeof(buffer) - p)) <= 0) {
+                        r = k < 0 ? -errno : -EIO;
+                        goto finish;
+                }
+
+                p += k;
+
+                if (p < 1)
+                        continue;
+
+                if (buffer[0] == 5) {
+                        /* No password, because UI not shown */
+                        r = -ENOENT;
+                        goto finish;
+
+                } else if (buffer[0] == 2) {
+                        uint32_t size;
+                        char *s;
+
+                        /* One answer */
+                        if (p < 5)
+                                continue;
+
+                        memcpy(&size, buffer+1, sizeof(size));
+                        if (size+5 > sizeof(buffer)) {
+                                r = -EIO;
+                                goto finish;
+                        }
+
+                        if (p-5 < size)
+                                continue;
+
+                        if (!(s = strndup(buffer + 5, size))) {
+                                r = -ENOMEM;
+                                goto finish;
+                        }
+
+                        *_passphrase = s;
+                        break;
+                } else {
+                        /* Unknown packet */
+                        r = -EIO;
+                        goto finish;
+                }
+        }
+
+        r = 0;
+
+finish:
+        if (notify >= 0)
+                close_nointr_nofail(notify);
+
+        if (fd >= 0)
+                close_nointr_nofail(fd);
+
+        free(packet);
+
+        return r;
+}
+
+static int parse_password(const char *filename, char **wall) {
         char *socket_name = NULL, *message = NULL, *packet = NULL;
         uint64_t not_after = 0;
         unsigned pid = 0;
@@ -52,6 +213,7 @@ static int parse_password(const char *filename) {
                 { "NotAfter", config_parse_uint64,   &not_after,   "Ask" },
                 { "Message",  config_parse_string,   &message,     "Ask" },
                 { "PID",      config_parse_unsigned, &pid,         "Ask" },
+                { NULL, NULL, NULL, NULL }
         };
 
         FILE *f;
@@ -89,11 +251,13 @@ static int parse_password(const char *filename) {
         if (arg_action == ACTION_LIST)
                 printf("'%s' (PID %u)\n", message, pid);
         else if (arg_action == ACTION_WALL) {
-                char *wall;
+                char *_wall;
 
-                if (asprintf(&wall,
-                             "Password entry required for \'%s\' (PID %u).\r\n"
+                if (asprintf(&_wall,
+                             "%s%sPassword entry required for \'%s\' (PID %u).\r\n"
                              "Please enter password with the systemd-tty-password-agent tool!",
+                             *wall ? *wall : "",
+                             *wall ? "\r\n\r\n" : "",
                              message,
                              pid) < 0) {
                         log_error("Out of memory");
@@ -101,8 +265,8 @@ static int parse_password(const char *filename) {
                         goto finish;
                 }
 
-                r = utmp_wall(wall);
-                free(wall);
+                free(*wall);
+                *wall = _wall;
         } else {
                 union {
                         struct sockaddr sa;
@@ -122,8 +286,13 @@ static int parse_password(const char *filename) {
                         goto finish;
                 }
 
-                if ((r = ask_password_tty(message, not_after, filename, &password)) < 0) {
-                        log_error("Failed to query passwords: %s", strerror(-r));
+                if (arg_plymouth)
+                        r = ask_password_plymouth(message, not_after, filename, &password);
+                else
+                        r = ask_password_tty(message, not_after, filename, &password);
+
+                if (r < 0) {
+                        log_error("Failed to query password: %s", strerror(-r));
                         goto finish;
                 }
 
@@ -182,6 +351,7 @@ static int show_passwords(void) {
         while ((de = readdir(d))) {
                 char *p;
                 int q;
+                char *wall;
 
                 if (de->d_type != DT_REG)
                         continue;
@@ -198,10 +368,16 @@ static int show_passwords(void) {
                         goto finish;
                 }
 
-                if ((q = parse_password(p)) < 0)
+                wall = NULL;
+                if ((q = parse_password(p, &wall)) < 0)
                         r = q;
 
                 free(p);
+
+                if (wall) {
+                        utmp_wall(wall);
+                        free(wall);
+                }
         }
 
 finish:
@@ -212,8 +388,15 @@ finish:
 }
 
 static int watch_passwords(void) {
-        int notify;
-        struct pollfd pollfd;
+        enum {
+                FD_INOTIFY,
+                FD_SIGNAL,
+                _FD_MAX
+        };
+
+        int notify = -1, signal_fd = -1;
+        struct pollfd pollfd[_FD_MAX];
+        sigset_t mask;
         int r;
 
         mkdir_p("/dev/.systemd/ask-password", 0755);
@@ -228,15 +411,27 @@ static int watch_passwords(void) {
                 goto finish;
         }
 
+        assert_se(sigemptyset(&mask) == 0);
+        sigset_add_many(&mask, SIGINT, SIGTERM, -1);
+        assert_se(sigprocmask(SIG_SETMASK, &mask, NULL) == 0);
+
+        if ((signal_fd = signalfd(-1, &mask, SFD_NONBLOCK|SFD_CLOEXEC)) < 0) {
+                log_error("signalfd(): %m");
+                r = -errno;
+                goto finish;
+        }
+
         zero(pollfd);
-        pollfd.fd = notify;
-        pollfd.events = POLLIN;
+        pollfd[FD_INOTIFY].fd = notify;
+        pollfd[FD_INOTIFY].events = POLLIN;
+        pollfd[FD_SIGNAL].fd = signal_fd;
+        pollfd[FD_SIGNAL].events = POLLIN;
 
         for (;;) {
                 if ((r = show_passwords()) < 0)
                         break;
 
-                if (poll(&pollfd, 1, -1) < 0) {
+                if (poll(pollfd, _FD_MAX, -1) < 0) {
 
                         if (errno == EINTR)
                                 continue;
@@ -245,8 +440,11 @@ static int watch_passwords(void) {
                         goto finish;
                 }
 
-                if (pollfd.revents != 0)
+                if (pollfd[FD_INOTIFY].revents != 0)
                         flush_fd(notify);
+
+                if (pollfd[FD_SIGNAL].revents != 0)
+                        break;
         }
 
         r = 0;
@@ -255,6 +453,9 @@ finish:
         if (notify >= 0)
                 close_nointr_nofail(notify);
 
+        if (signal_fd >= 0)
+                close_nointr_nofail(signal_fd);
+
         return r;
 }
 
@@ -262,11 +463,12 @@ static int help(void) {
 
         printf("%s [OPTIONS...]\n\n"
                "Process system password requests.\n\n"
-               "  -h --help   Show this help\n"
-               "     --list   Show pending password requests\n"
-               "     --query  Process pending password requests\n"
-               "     --watch  Continously process password requests\n"
-               "     --wall   Continously forward password requests to wall\n",
+               "  -h --help     Show this help\n"
+               "     --list     Show pending password requests\n"
+               "     --query    Process pending password requests\n"
+               "     --watch    Continously process password requests\n"
+               "     --wall     Continously forward password requests to wall\n"
+               "     --plymouth Ask question with Plymouth instead of on TTY\n",
                program_invocation_short_name);
 
         return 0;
@@ -279,15 +481,17 @@ static int parse_argv(int argc, char *argv[]) {
                 ARG_QUERY,
                 ARG_WATCH,
                 ARG_WALL,
+                ARG_PLYMOUTH
         };
 
         static const struct option options[] = {
-                { "help",  no_argument, NULL, 'h'       },
-                { "list",  no_argument, NULL, ARG_LIST  },
-                { "query", no_argument, NULL, ARG_QUERY },
-                { "watch", no_argument, NULL, ARG_WATCH },
-                { "wall",  no_argument, NULL, ARG_WALL  },
-                { NULL,    0,           NULL, 0         }
+                { "help",     no_argument, NULL, 'h'          },
+                { "list",     no_argument, NULL, ARG_LIST     },
+                { "query",    no_argument, NULL, ARG_QUERY    },
+                { "watch",    no_argument, NULL, ARG_WATCH    },
+                { "wall",     no_argument, NULL, ARG_WALL     },
+                { "plymouth", no_argument, NULL, ARG_PLYMOUTH },
+                { NULL,    0,           NULL, 0               }
         };
 
         int c;
@@ -319,6 +523,10 @@ static int parse_argv(int argc, char *argv[]) {
                         arg_action = ACTION_WALL;
                         break;
 
+                case ARG_PLYMOUTH:
+                        arg_plymouth = true;
+                        break;
+
                 case '?':
                         return -EINVAL;