1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
4 This file is part of systemd.
6 Copyright 2010 Lennart Poettering
8 systemd is free software; you can redistribute it and/or modify it
9 under the terms of the GNU Lesser General Public License as published by
10 the Free Software Foundation; either version 2.1 of the License, or
11 (at your option) any later version.
13 systemd is distributed in the hope that it will be useful, but
14 WITHOUT ANY WARRANTY; without even the implied warranty of
15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16 Lesser General Public License for more details.
18 You should have received a copy of the GNU Lesser General Public License
19 along with systemd; If not, see <http://www.gnu.org/licenses/>.
22 #include <sys/socket.h>
24 #include <sys/types.h>
29 #include <sys/epoll.h>
34 #include "socket-util.h"
/* Capacity of each relay buffer in main(); reads never fill a buffer
 * beyond this many bytes. */
#define BUFFER_SIZE (64 * 1024)

/* Tracks whether the NUL byte that may open the client's stream has been
 * handled; presumably set by patch_in_line() — TODO confirm. Static
 * storage starts out zero, i.e. false. */
static bool initial_nul;

/* Once authentication is over, client data is presumably forwarded without
 * any patching (see patch_in_buffer()). Starts out false. */
static bool auth_over;
/* Write the effective UID of this process into buf as the hex encoding of
 * its decimal representation (each decimal digit becomes two hex digits,
 * e.g. UID 1000 -> "31303030"). This is the form used as the argument of
 * the "AUTH EXTERNAL" line rewritten by patch_in_line().
 *
 *   buf - destination; always NUL-terminated when l > 0
 *   l   - size of buf in bytes; output is truncated to whole encoded
 *         digits that fit together with the terminating NUL
 */
static void format_uid(char *buf, size_t l) {
        /* Only ASCII decimal digits are ever encoded, but a full table keeps
         * the encoding correct for any byte value. */
        static const char hex[] = "0123456789abcdef";
        char text[20 + 1]; /* enough space for a 64bit integer plus NUL */
        size_t j;

        if (l == 0)
                return;

        /* snprintf() always NUL-terminates within the size it is given. */
        snprintf(text, sizeof(text), "%llu", (unsigned long long) geteuid());

        /* Each source character expands to two hex digits; stop while there
         * is still room for both of them plus the trailing NUL. */
        for (j = 0; text[j] && j*2+2 < l; j++) {
                unsigned char c = (unsigned char) text[j];

                buf[j*2] = hex[c >> 4];
                buf[j*2+1] = hex[c & 0xF];
        }

        /* j*2 < l holds here: either j == 0 (and l > 0), or the previous
         * iteration ran with (j-1)*2+2 < l. */
        buf[j*2] = 0;
}
/* Rewrite one complete client line (without its trailing "\r\n") in place
 * while authentication is still in progress.
 *
 *   line - start of the line inside the input buffer
 *   l    - length of the line in bytes
 *   left - number of bytes from the end of the line to the end of the
 *          buffer (the "\r\n" and everything after it); this tail is moved
 *          with memmove() whenever the line changes length
 *
 * The caller, patch_in_buffer(), uses the return value as the line's new
 * length and adjusts the buffer fill level by the difference. */
61 static size_t patch_in_line(char *line, size_t l, size_t left) {
/* A NUL byte at the very start of the stream is special-cased exactly once.
 * NOTE(review): the branch body is not visible here — presumably it sets
 * initial_nul and steps past the NUL byte; confirm. */
64 if (line[0] == 0 && !initial_nul) {
/* NOTE(review): presumably "BEGIN" ends the authentication phase (sets
 * auth_over); confirm against the branch body. */
72 if (l == 5 && strneq(line, "BEGIN", 5)) {
/* Replace the 17-byte "NEGOTIATE_UNIX_FD" with the 13-byte "NEGOTIATE_NOP":
 * first slide the tail 4 bytes to the left, then overwrite the command.
 * Presumably done because file descriptors cannot be passed across the
 * stdio pipes — TODO confirm. */
76 } else if (l == 17 && strneq(line, "NEGOTIATE_UNIX_FD", 17)) {
77 memmove(line + 13, line + 17, left);
78 memcpy(line, "NEGOTIATE_NOP", 13);
/* Replace whatever follows "AUTH EXTERNAL " with this process's effective
 * UID as produced by format_uid(). */
81 } else if (l >= 14 && strneq(line, "AUTH EXTERNAL ", 14)) {
85 format_uid(uid, sizeof(uid));
/* The line may grow here. The UID length is bounded by EXTRA_SIZE, the slack
 * main() reserves beyond BUFFER_SIZE in its buffers.
 * NOTE(review): a 10-digit UID hex-encodes to 20 characters — confirm that
 * EXTRA_SIZE (defined outside this view) is at least that large, and that
 * growth cannot accumulate across several AUTH lines in one buffer. */
87 assert(len <= EXTRA_SIZE);
/* Move the tail to just after the new UID, then copy the UID in. */
89 memmove(line + 14 + len, line + l, left);
90 memcpy(line + 14, uid, len);
/* Scan the client input buffer for complete "\r\n"-terminated lines and
 * rewrite each one in place with patch_in_line().
 *
 *   in_buffer      - bytes read from stdin that have not been forwarded yet
 *   in_buffer_full - in: number of valid bytes; out: updated whenever a
 *                    patched line changes length
 *
 * Returns the number of leading bytes main() may write to the bus socket:
 * - the whole buffer when it is empty or authentication is over;
 * - otherwise, presumably the offset just past the last complete line
 *   handled (the final return is not visible here — TODO confirm). */
99 static size_t patch_in_buffer(char* in_buffer, size_t *in_buffer_full) {
102 if (*in_buffer_full <= 0)
103 return *in_buffer_full;
105 /* If authentication is done, we don't touch anything anymore */
107 return *in_buffer_full;
/* Fewer than two bytes cannot hold a "\r\n" terminator. This check also keeps
 * "*in_buffer_full - 2" below from wrapping around. */
109 if (*in_buffer_full < 2)
/* The bound is re-read every iteration because patching may resize the buffer.
 * NOTE(review): i does not appear to be adjusted when a patched line shrinks
 * or grows, so the scan position can drift relative to the shifted data;
 * confirm this is harmless for the lines that can follow. */
112 for (i = 0; i <= *in_buffer_full - 2; i ++) {
114 /* Complete lines can be passed on */
115 if (in_buffer[i] == '\r' && in_buffer[i+1] == '\n') {
117 size_t old_length, new_length;
/* The current line spans [good, i). Everything from i (its "\r\n") to the end
 * of the buffer is the tail that patch_in_line() may move. */
119 old_length = i - good;
120 new_length = patch_in_line(in_buffer+good, old_length, *in_buffer_full - i);
121 *in_buffer_full = *in_buffer_full + new_length - old_length;
/* The next line starts after this line's new length plus its "\r\n". */
123 good += new_length + 2;
/* Bridge stdin/stdout to the system bus socket "/run/dbus/system_bus_socket".
 * - stdin -> socket: bytes read from stdin pass through patch_in_buffer()
 *   during authentication, then are written to the socket.
 * - socket -> stdout: bytes read from the socket are copied unchanged.
 * All three descriptors are non-blocking and are driven by one edge-triggered
 * epoll loop. The program takes no arguments.
 * r starts as EXIT_FAILURE. NOTE(review): its success assignment and the
 * final return are not visible here — TODO confirm. */
136 int main(int argc, char *argv[]) {
137 int r = EXIT_FAILURE, fd = -1, ep = -1;
138 union sockaddr_union sa;
/* in_buffer carries stdin -> socket, out_buffer carries socket -> stdout.
 * Reads fill at most BUFFER_SIZE bytes. The extra EXTRA_SIZE bytes leave room
 * for patch_in_line() to lengthen an "AUTH EXTERNAL" line in place. */
139 char in_buffer[BUFFER_SIZE+EXTRA_SIZE], out_buffer[BUFFER_SIZE+EXTRA_SIZE];
140 size_t in_buffer_full = 0, out_buffer_full = 0;
141 struct epoll_event stdin_ev, stdout_ev, fd_ev;
/* *_readable / *_writable cache readiness reported by edge-triggered epoll;
 * they are cleared when an operation would block.
 * *_rhup / *_whup record that the read / write side of an endpoint has been
 * shut down. */
142 bool stdin_readable = false, stdout_writable = false, fd_readable = false, fd_writable = false;
143 bool stdin_rhup = false, stdout_whup = false, fd_rhup = false, fd_whup = false;
146 log_error("This program takes no argument.");
150 log_set_target(LOG_TARGET_JOURNAL_OR_KMSG);
151 log_parse_environment();
/* NOTE(review): this message uses strerror(errno) while every other error
 * message here uses %m; consider making them consistent. */
154 if ((fd = socket(AF_UNIX, SOCK_STREAM|SOCK_CLOEXEC|SOCK_NONBLOCK, 0)) < 0) {
155 log_error("Failed to create socket: %s", strerror(errno));
160 sa.un.sun_family = AF_UNIX;
/* strncpy() does not guarantee NUL termination. That is safe here only because
 * the literal path is shorter than sun_path, and the address length passed to
 * connect() below relies on strlen(). */
161 strncpy(sa.un.sun_path, "/run/dbus/system_bus_socket", sizeof(sa.un.sun_path));
163 if (connect(fd, &sa.sa, offsetof(struct sockaddr_un, sun_path) + strlen(sa.un.sun_path)) < 0) {
164 log_error("Failed to connect: %m");
/* stdin/stdout must be non-blocking too: with EPOLLET each descriptor is read
 * or written until the operation would block. */
168 fd_nonblock(STDIN_FILENO, 1);
169 fd_nonblock(STDOUT_FILENO, 1);
171 if ((ep = epoll_create1(EPOLL_CLOEXEC)) < 0) {
172 log_error("Failed to create epoll: %m");
/* stdin is only read, stdout only written, and the socket both. */
177 stdin_ev.events = EPOLLIN|EPOLLET;
178 stdin_ev.data.fd = STDIN_FILENO;
181 stdout_ev.events = EPOLLOUT|EPOLLET;
182 stdout_ev.data.fd = STDOUT_FILENO;
185 fd_ev.events = EPOLLIN|EPOLLOUT|EPOLLET;
/* NOTE(review): "regiser" in the log message below is a typo for "register"
 * (runtime string, left unchanged here). */
188 if (epoll_ctl(ep, EPOLL_CTL_ADD, STDIN_FILENO, &stdin_ev) < 0 ||
189 epoll_ctl(ep, EPOLL_CTL_ADD, STDOUT_FILENO, &stdout_ev) < 0 ||
190 epoll_ctl(ep, EPOLL_CTL_ADD, fd, &fd_ev) < 0) {
191 log_error("Failed to regiser fds in epoll: %m");
196 struct epoll_event ev[16];
/* Block until at least one descriptor changes state. EINTR and EAGAIN are
 * treated as retryable rather than fatal. */
200 if ((nfds = epoll_wait(ep, ev, ELEMENTSOF(ev), -1)) < 0) {
202 if (errno == EINTR || errno == EAGAIN)
205 log_error("epoll_wait(): %m");
/* Translate epoll events into the cached readiness flags. */
211 for (i = 0; i < nfds; i++) {
212 if (ev[i].data.fd == STDIN_FILENO) {
/* EPOLLHUP also counts as readable, so the next read() can observe EOF. */
214 if (!stdin_rhup && (ev[i].events & (EPOLLHUP|EPOLLIN)))
215 stdin_readable = true;
217 } else if (ev[i].data.fd == STDOUT_FILENO) {
219 if (ev[i].events & EPOLLHUP) {
220 stdout_writable = false;
224 if (!stdout_whup && (ev[i].events & EPOLLOUT))
225 stdout_writable = true;
227 } else if (ev[i].data.fd == fd) {
229 if (ev[i].events & EPOLLHUP) {
234 if (!fd_rhup && (ev[i].events & (EPOLLHUP|EPOLLIN)))
237 if (!fd_whup && (ev[i].events & EPOLLOUT))
/* Keep moving data while some transfer can still make progress.
 * NOTE(review): patch_in_buffer() modifies in_buffer and in_buffer_full, yet
 * it is called from this condition and again inside the body. That relies on
 * re-patching an already patched buffer being a no-op. */
242 while ((stdin_readable && in_buffer_full <= 0) ||
243 (fd_writable && patch_in_buffer(in_buffer, &in_buffer_full) > 0) ||
244 (fd_readable && out_buffer_full <= 0) ||
245 (stdout_writable && out_buffer_full > 0)) {
247 size_t in_buffer_good = 0;
/* stdin -> in_buffer. On EOF, stdin is shut down and closed. */
249 if (stdin_readable && in_buffer_full < BUFFER_SIZE) {
251 if ((k = read(STDIN_FILENO, in_buffer + in_buffer_full, BUFFER_SIZE - in_buffer_full)) < 0) {
254 stdin_readable = false;
255 else if (errno == EPIPE || errno == ECONNRESET)
258 log_error("read(): %m");
262 in_buffer_full += (size_t) k;
266 stdin_readable = false;
267 shutdown(STDIN_FILENO, SHUT_RD);
268 close_nointr_nofail(STDIN_FILENO);
/* Only the prefix that patch_in_buffer() reports as ready may be sent. */
272 in_buffer_good = patch_in_buffer(in_buffer, &in_buffer_full);
/* in_buffer -> socket, then drop the bytes that were written. */
274 if (fd_writable && in_buffer_good > 0) {
276 if ((k = write(fd, in_buffer, in_buffer_good)) < 0) {
280 else if (errno == EPIPE || errno == ECONNRESET) {
283 shutdown(fd, SHUT_WR);
285 log_error("write(): %m");
290 assert(in_buffer_full >= (size_t) k);
/* NOTE(review): the matching "in_buffer_full -= k" is not visible here (the
 * stdout path has "out_buffer_full -= k"); confirm it is present. */
291 memmove(in_buffer, in_buffer + k, in_buffer_full - k);
/* socket -> out_buffer. On EOF, the socket's read side is shut down. */
296 if (fd_readable && out_buffer_full < BUFFER_SIZE) {
298 if ((k = read(fd, out_buffer + out_buffer_full, BUFFER_SIZE - out_buffer_full)) < 0) {
302 else if (errno == EPIPE || errno == ECONNRESET)
305 log_error("read(): %m");
309 out_buffer_full += (size_t) k;
314 shutdown(fd, SHUT_RD);
/* out_buffer -> stdout, then drop the bytes that were written. If the peer is
 * gone (EPIPE/ECONNRESET), stdout is shut down and closed. */
318 if (stdout_writable && out_buffer_full > 0) {
320 if ((k = write(STDOUT_FILENO, out_buffer, out_buffer_full)) < 0) {
323 stdout_writable = false;
324 else if (errno == EPIPE || errno == ECONNRESET) {
326 stdout_writable = false;
327 shutdown(STDOUT_FILENO, SHUT_WR);
328 close_nointr(STDOUT_FILENO);
330 log_error("write(): %m");
335 assert(out_buffer_full >= (size_t) k);
336 memmove(out_buffer, out_buffer + k, out_buffer_full - k);
337 out_buffer_full -= k;
/* Propagate EOF: once stdin is done and everything read from it has been
 * forwarded, shut down the socket's write side. */
342 if (stdin_rhup && in_buffer_full <= 0 && !fd_whup) {
345 shutdown(fd, SHUT_WR);
/* Likewise, once the socket is done and out_buffer is drained, shut down and
 * close stdout. */
348 if (fd_rhup && out_buffer_full <= 0 && !stdout_whup) {
350 stdout_writable = false;
351 shutdown(STDOUT_FILENO, SHUT_WR);
352 close_nointr(STDOUT_FILENO);
/* Run until the write side in both directions has been shut down. */
355 } while (!stdout_whup || !fd_whup);
361 close_nointr_nofail(fd);
364 close_nointr_nofail(ep);