chiark / gitweb /
ask-password: supported plymouth cached passwords
[elogind.git] / src / tty-ask-password-agent.c
index 655bfb9..14b0148 100644 (file)
@@ -37,6 +37,7 @@
 #include "utmp-wtmp.h"
 #include "socket-util.h"
 #include "ask-password-api.h"
+#include "strv.h"
 
 static enum {
         ACTION_LIST,
@@ -48,7 +49,13 @@ static enum {
 static bool arg_plymouth = false;
 static bool arg_console = false;
 
-static int ask_password_plymouth(const char *message, usec_t until, const char *flag_file, char **_passphrase) {
+static int ask_password_plymouth(
+                const char *message,
+                usec_t until,
+                const char *flag_file,
+                bool accept_cached,
+                char ***_passphrases) {
+
         int fd = -1, notify = -1;
         union sockaddr_union sa;
         char *packet = NULL;
@@ -62,6 +69,8 @@ static int ask_password_plymouth(const char *message, usec_t until, const char *
                 POLL_INOTIFY
         };
 
+        assert(_passphrases);
+
         if (flag_file) {
                 if ((notify = inotify_init1(IN_CLOEXEC|IN_NONBLOCK)) < 0) {
                         r = -errno;
@@ -88,7 +97,13 @@ static int ask_password_plymouth(const char *message, usec_t until, const char *
                 goto finish;
         }
 
-        if (asprintf(&packet, "*\002%c%s%n", (int) (strlen(message) + 1), message, &n) < 0) {
+        if (accept_cached) {
+                packet = strdup("c");
+                n = 1;
+        } else
+                asprintf(&packet, "*\002%c%s%n", (int) (strlen(message) + 1), message, &n);
+
+        if (!packet) {
                 r = -ENOMEM;
                 goto finish;
         }
@@ -155,15 +170,38 @@ static int ask_password_plymouth(const char *message, usec_t until, const char *
                         continue;
 
                 if (buffer[0] == 5) {
+
+                        if (accept_cached) {
+                                /* Hmm, first try with cached
+                                 * passwords failed, so let's retry
+                                 * with a normal password request */
+                                free(packet);
+                                packet = NULL;
+
+                                if (asprintf(&packet, "*\002%c%s%n", (int) (strlen(message) + 1), message, &n) < 0) {
+                                        r = -ENOMEM;
+                                        goto finish;
+                                }
+
+                                if ((k = loop_write(fd, packet, n+1, true)) != n+1) {
+                                        r = k < 0 ? (int) k : -EIO;
+                                        goto finish;
+                                }
+
+                                accept_cached = false;
+                                p = 0;
+                                continue;
+                        }
+
                         /* No password, because UI not shown */
                         r = -ENOENT;
                         goto finish;
 
-                } else if (buffer[0] == 2) {
+                } else if (buffer[0] == 2 || buffer[0] == 9) {
                         uint32_t size;
-                        char *s;
+                        char **l;
 
-                        /* One answer */
+                        /* One ore more answers */
                         if (p < 5)
                                 continue;
 
@@ -176,13 +214,14 @@ static int ask_password_plymouth(const char *message, usec_t until, const char *
                         if (p-5 < size)
                                 continue;
 
-                        if (!(s = strndup(buffer + 5, size))) {
+                        if (!(l = strv_parse_nulstr(buffer + 5, size))) {
                                 r = -ENOMEM;
                                 goto finish;
                         }
 
-                        *_passphrase = s;
+                        *_passphrases = l;
                         break;
+
                 } else {
                         /* Unknown packet */
                         r = -EIO;
@@ -209,12 +248,14 @@ static int parse_password(const char *filename, char **wall) {
         uint64_t not_after = 0;
         unsigned pid = 0;
         int socket_fd = -1;
+        bool accept_cached = false;
 
         const ConfigItem items[] = {
-                { "Socket",   config_parse_string,   &socket_name, "Ask" },
-                { "NotAfter", config_parse_uint64,   &not_after,   "Ask" },
-                { "Message",  config_parse_string,   &message,     "Ask" },
-                { "PID",      config_parse_unsigned, &pid,         "Ask" },
+                { "Socket",       config_parse_string,   &socket_name,   "Ask" },
+                { "NotAfter",     config_parse_uint64,   &not_after,     "Ask" },
+                { "Message",      config_parse_string,   &message,       "Ask" },
+                { "PID",          config_parse_unsigned, &pid,           "Ask" },
+                { "AcceptCached", config_parse_bool,     &accept_cached, "Ask" },
                 { NULL, NULL, NULL, NULL }
         };
 
@@ -274,7 +315,7 @@ static int parse_password(const char *filename, char **wall) {
                         struct sockaddr sa;
                         struct sockaddr_un un;
                 } sa;
-                char *password;
+                size_t packet_length;
 
                 assert(arg_action == ACTION_QUERY ||
                        arg_action == ACTION_WATCH);
@@ -288,10 +329,32 @@ static int parse_password(const char *filename, char **wall) {
                         goto finish;
                 }
 
-                if (arg_plymouth)
-                        r = ask_password_plymouth(message, not_after, filename, &password);
-                else {
+                if (arg_plymouth) {
+                        char **passwords;
+
+                        if ((r = ask_password_plymouth(message, not_after, filename, accept_cached, &passwords)) >= 0) {
+                                char **p;
+
+                                packet_length = 1;
+                                STRV_FOREACH(p, passwords)
+                                        packet_length += strlen(*p) + 1;
+
+                                if (!(packet = new(char, packet_length)))
+                                        r = -ENOMEM;
+                                else {
+                                        char *d;
+
+                                        packet[0] = '+';
+                                        d = packet+1;
+
+                                        STRV_FOREACH(p, passwords)
+                                                d = stpcpy(d, *p) + 1;
+                                }
+                        }
+
+                } else {
                         int tty_fd = -1;
+                        char *password;
 
                         if (arg_console)
                                 if ((tty_fd = acquire_terminal("/dev/console", false, false, false)) < 0) {
@@ -305,6 +368,11 @@ static int parse_password(const char *filename, char **wall) {
                                 close_nointr_nofail(tty_fd);
                                 release_terminal();
                         }
+
+                        asprintf(&packet, "+%s", password);
+                        free(password);
+
+                        packet_length = strlen(packet);
                 }
 
                 if (r < 0) {
@@ -312,9 +380,6 @@ static int parse_password(const char *filename, char **wall) {
                         goto finish;
                 }
 
-                asprintf(&packet, "+%s", password);
-                free(password);
-
                 if (!packet) {
                         log_error("Out of memory");
                         r = -ENOMEM;
@@ -331,7 +396,7 @@ static int parse_password(const char *filename, char **wall) {
                 sa.un.sun_family = AF_UNIX;
                 strncpy(sa.un.sun_path, socket_name, sizeof(sa.un.sun_path));
 
-                if (sendto(socket_fd, packet, strlen(packet), MSG_NOSIGNAL, &sa.sa, offsetof(struct sockaddr_un, sun_path) + strlen(socket_name)) < 0) {
+                if (sendto(socket_fd, packet, packet_length, MSG_NOSIGNAL, &sa.sa, offsetof(struct sockaddr_un, sun_path) + strlen(socket_name)) < 0) {
                         log_error("Failed to send: %m");
                         r = -errno;
                         goto finish;