chiark / gitweb /
sysctl: always write net.ipv4.conf.all.xyz= in addition to net.ipv4.conf.default...
[elogind.git] / src / shared / ptyfwd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2010-2013 Lennart Poettering
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20 ***/
21
22 #include <sys/epoll.h>
23 #include <sys/signalfd.h>
24 #include <sys/ioctl.h>
25 #include <limits.h>
26 #include <termios.h>
27
28 #include "util.h"
29 #include "ptyfwd.h"
30
31 #define ESCAPE_USEC USEC_PER_SEC
32
33 static bool look_for_escape(usec_t *timestamp, unsigned *counter, const char *buffer, size_t n) {
34         const char *p;
35
36         assert(timestamp);
37         assert(counter);
38         assert(buffer);
39         assert(n > 0);
40
41         for (p = buffer; p < buffer + n; p++) {
42
43                 /* Check for ^] */
44                 if (*p == 0x1D) {
45                         usec_t nw = now(CLOCK_MONOTONIC);
46
47                         if (*counter == 0 || nw > *timestamp + USEC_PER_SEC)  {
48                                 *timestamp = nw;
49                                 *counter = 1;
50                         } else {
51                                 (*counter)++;
52
53                                 if (*counter >= 3)
54                                         return true;
55                         }
56                 } else {
57                         *timestamp = 0;
58                         *counter = 0;
59                 }
60         }
61
62         return false;
63 }
64
65 static int process_pty_loop(int master, sigset_t *mask, pid_t kill_pid, int signo) {
66         char in_buffer[LINE_MAX], out_buffer[LINE_MAX];
67         size_t in_buffer_full = 0, out_buffer_full = 0;
68         struct epoll_event stdin_ev, stdout_ev, master_ev, signal_ev;
69         bool stdin_readable = false, stdout_writable = false, master_readable = false, master_writable = false;
70         bool stdin_hangup = false, stdout_hangup = false, master_hangup = false;
71         bool tried_orderly_shutdown = false, process_signalfd = false, quit = false;
72         usec_t escape_timestamp = 0;
73         unsigned escape_counter = 0;
74         _cleanup_close_ int ep = -1, signal_fd = -1;
75
76         assert(master >= 0);
77         assert(mask);
78         assert(kill_pid == 0 || kill_pid > 1);
79         assert(signo >= 0 && signo < _NSIG);
80
81         fd_nonblock(STDIN_FILENO, true);
82         fd_nonblock(STDOUT_FILENO, true);
83         fd_nonblock(master, true);
84
85         signal_fd = signalfd(-1, mask, SFD_NONBLOCK|SFD_CLOEXEC);
86         if (signal_fd < 0) {
87                 log_error("signalfd(): %m");
88                 return -errno;
89         }
90
91         ep = epoll_create1(EPOLL_CLOEXEC);
92         if (ep < 0) {
93                 log_error("Failed to create epoll: %m");
94                 return -errno;
95         }
96
97         /* We read from STDIN only if this is actually a TTY,
98          * otherwise we assume non-interactivity. */
99         if (isatty(STDIN_FILENO)) {
100                 zero(stdin_ev);
101                 stdin_ev.events = EPOLLIN|EPOLLET;
102                 stdin_ev.data.fd = STDIN_FILENO;
103
104                 if (epoll_ctl(ep, EPOLL_CTL_ADD, STDIN_FILENO, &stdin_ev) < 0) {
105                         log_error("Failed to register STDIN in epoll: %m");
106                         return -errno;
107                 }
108         }
109
110         zero(stdout_ev);
111         stdout_ev.events = EPOLLOUT|EPOLLET;
112         stdout_ev.data.fd = STDOUT_FILENO;
113
114         zero(master_ev);
115         master_ev.events = EPOLLIN|EPOLLOUT|EPOLLET;
116         master_ev.data.fd = master;
117
118         zero(signal_ev);
119         signal_ev.events = EPOLLIN;
120         signal_ev.data.fd = signal_fd;
121
122         if (epoll_ctl(ep, EPOLL_CTL_ADD, STDOUT_FILENO, &stdout_ev) < 0) {
123                 if (errno != EPERM) {
124                         log_error("Failed to register stdout in epoll: %m");
125                         return -errno;
126                 }
127
128                 /* stdout without epoll support. Likely redirected to regular file. */
129                 stdout_writable = true;
130         }
131
132         if (epoll_ctl(ep, EPOLL_CTL_ADD, master, &master_ev) < 0 ||
133             epoll_ctl(ep, EPOLL_CTL_ADD, signal_fd, &signal_ev) < 0) {
134                 log_error("Failed to register fds in epoll: %m");
135                 return -errno;
136         }
137
138         for (;;) {
139                 struct epoll_event ev[16];
140                 ssize_t k;
141                 int i, nfds;
142
143                 nfds = epoll_wait(ep, ev, ELEMENTSOF(ev), quit ? 0 : -1);
144                 if (nfds < 0) {
145
146                         if (errno == EINTR || errno == EAGAIN)
147                                 continue;
148
149                         log_error("epoll_wait(): %m");
150                         return -errno;
151                 }
152
153                 if (nfds == 0)
154                         return 0;
155
156                 for (i = 0; i < nfds; i++) {
157                         if (ev[i].data.fd == STDIN_FILENO) {
158
159                                 if (ev[i].events & (EPOLLIN|EPOLLHUP))
160                                         stdin_readable = true;
161
162                         } else if (ev[i].data.fd == STDOUT_FILENO) {
163
164                                 if (ev[i].events & (EPOLLOUT|EPOLLHUP))
165                                         stdout_writable = true;
166
167                         } else if (ev[i].data.fd == master) {
168
169                                 if (ev[i].events & (EPOLLIN|EPOLLHUP))
170                                         master_readable = true;
171
172                                 if (ev[i].events & (EPOLLOUT|EPOLLHUP))
173                                         master_writable = true;
174
175                         } else if (ev[i].data.fd == signal_fd)
176                                 process_signalfd = true;
177                 }
178
179                 while ((stdin_readable && in_buffer_full <= 0) ||
180                        (master_writable && in_buffer_full > 0) ||
181                        (master_readable && out_buffer_full <= 0) ||
182                        (stdout_writable && out_buffer_full > 0)) {
183
184                         if (stdin_readable && in_buffer_full < LINE_MAX) {
185
186                                 k = read(STDIN_FILENO, in_buffer + in_buffer_full, LINE_MAX - in_buffer_full);
187                                 if (k < 0) {
188
189                                         if (errno == EAGAIN)
190                                                 stdin_readable = false;
191                                         else if (errno == EIO || errno == EPIPE || errno == ECONNRESET) {
192                                                 stdin_readable = false;
193                                                 stdin_hangup = true;
194                                                 epoll_ctl(ep, EPOLL_CTL_DEL, STDIN_FILENO, NULL);
195                                         } else {
196                                                 log_error("read(): %m");
197                                                 return -errno;
198                                         }
199                                 } else {
200                                         /* Check if ^] has been
201                                          * pressed three times within
202                                          * one second. If we get this
203                                          * we quite immediately. */
204                                         if (look_for_escape(&escape_timestamp, &escape_counter, in_buffer + in_buffer_full, k))
205                                                 return !quit;
206
207                                         in_buffer_full += (size_t) k;
208                                 }
209                         }
210
211                         if (master_writable && in_buffer_full > 0) {
212
213                                 k = write(master, in_buffer, in_buffer_full);
214                                 if (k < 0) {
215
216                                         if (errno == EAGAIN || errno == EIO)
217                                                 master_writable = false;
218                                         else if (errno == EPIPE || errno == ECONNRESET) {
219                                                 master_writable = master_readable = false;
220                                                 master_hangup = true;
221                                                 epoll_ctl(ep, EPOLL_CTL_DEL, master, NULL);
222                                         } else {
223                                                 log_error("write(): %m");
224                                                 return -errno;
225                                         }
226
227                                 } else {
228                                         assert(in_buffer_full >= (size_t) k);
229                                         memmove(in_buffer, in_buffer + k, in_buffer_full - k);
230                                         in_buffer_full -= k;
231                                 }
232                         }
233
234                         if (master_readable && out_buffer_full < LINE_MAX) {
235
236                                 k = read(master, out_buffer + out_buffer_full, LINE_MAX - out_buffer_full);
237                                 if (k < 0) {
238
239                                         /* Note that EIO on the master
240                                          * device might be cause by
241                                          * vhangup() or temporary
242                                          * closing of everything on
243                                          * the other side, we treat it
244                                          * like EAGAIN here and try
245                                          * again. */
246
247                                         if (errno == EAGAIN || errno == EIO)
248                                                 master_readable = false;
249                                         else if (errno == EPIPE || errno == ECONNRESET) {
250                                                 master_readable = master_writable = false;
251                                                 master_hangup = true;
252                                                 epoll_ctl(ep, EPOLL_CTL_DEL, master, NULL);
253                                         } else {
254                                                 log_error("read(): %m");
255                                                 return -errno;
256                                         }
257                                 }  else
258                                         out_buffer_full += (size_t) k;
259                         }
260
261                         if (stdout_writable && out_buffer_full > 0) {
262
263                                 k = write(STDOUT_FILENO, out_buffer, out_buffer_full);
264                                 if (k < 0) {
265
266                                         if (errno == EAGAIN)
267                                                 stdout_writable = false;
268                                         else if (errno == EIO || errno == EPIPE || errno == ECONNRESET) {
269                                                 stdout_writable = false;
270                                                 stdout_hangup = true;
271                                                 epoll_ctl(ep, EPOLL_CTL_DEL, STDOUT_FILENO, NULL);
272                                         } else {
273                                                 log_error("write(): %m");
274                                                 return -errno;
275                                         }
276
277                                 } else {
278                                         assert(out_buffer_full >= (size_t) k);
279                                         memmove(out_buffer, out_buffer + k, out_buffer_full - k);
280                                         out_buffer_full -= k;
281                                 }
282                         }
283
284                 }
285
286                 if (process_signalfd) {
287                         struct signalfd_siginfo sfsi;
288                         ssize_t n;
289
290                         n = read(signal_fd, &sfsi, sizeof(sfsi));
291                         if (n != sizeof(sfsi)) {
292
293                                 if (n >= 0) {
294                                         log_error("Failed to read from signalfd: invalid block size");
295                                         return -EIO;
296                                 }
297
298                                 if (errno != EINTR && errno != EAGAIN) {
299                                         log_error("Failed to read from signalfd: %m");
300                                         return -errno;
301                                 }
302                         } else {
303
304                                 if (sfsi.ssi_signo == SIGWINCH) {
305                                         struct winsize ws;
306
307                                         /* The window size changed, let's forward that. */
308                                         if (ioctl(STDOUT_FILENO, TIOCGWINSZ, &ws) >= 0)
309                                                 ioctl(master, TIOCSWINSZ, &ws);
310
311                                 } else if (sfsi.ssi_signo == SIGTERM && kill_pid > 0 && signo > 0 && !tried_orderly_shutdown) {
312
313                                         if (kill(kill_pid, signo) < 0)
314                                                 quit = true;
315                                         else {
316                                                 log_info("Trying to halt container. Send SIGTERM again to trigger immediate termination.");
317
318                                                 /* This only works for systemd... */
319                                                 tried_orderly_shutdown = true;
320                                         }
321
322                                 } else
323                                         /* Signals that where
324                                          * delivered via signalfd that
325                                          * we didn't know are a reason
326                                          * for us to quit */
327                                         quit = true;
328                         }
329                 }
330
331                 if (stdin_hangup || stdout_hangup || master_hangup) {
332                         /* Exit the loop if any side hung up and if
333                          * there's nothing more to write or nothing we
334                          * could write. */
335
336                         if ((out_buffer_full <= 0 || stdout_hangup) &&
337                             (in_buffer_full <= 0 || master_hangup))
338                                 return !quit;
339                 }
340         }
341 }
342
343 int process_pty(int master, sigset_t *mask, pid_t kill_pid, int signo) {
344         struct termios saved_stdin_attr, raw_stdin_attr;
345         struct termios saved_stdout_attr, raw_stdout_attr;
346         bool saved_stdin = false;
347         bool saved_stdout = false;
348         struct winsize ws;
349         int r;
350
351         if (ioctl(STDOUT_FILENO, TIOCGWINSZ, &ws) >= 0)
352                 ioctl(master, TIOCSWINSZ, &ws);
353
354         if (tcgetattr(STDIN_FILENO, &saved_stdin_attr) >= 0) {
355                 saved_stdin = true;
356
357                 raw_stdin_attr = saved_stdin_attr;
358                 cfmakeraw(&raw_stdin_attr);
359                 raw_stdin_attr.c_oflag = saved_stdin_attr.c_oflag;
360                 tcsetattr(STDIN_FILENO, TCSANOW, &raw_stdin_attr);
361         }
362         if (tcgetattr(STDOUT_FILENO, &saved_stdout_attr) >= 0) {
363                 saved_stdout = true;
364
365                 raw_stdout_attr = saved_stdout_attr;
366                 cfmakeraw(&raw_stdout_attr);
367                 raw_stdout_attr.c_iflag = saved_stdout_attr.c_iflag;
368                 raw_stdout_attr.c_lflag = saved_stdout_attr.c_lflag;
369                 tcsetattr(STDOUT_FILENO, TCSANOW, &raw_stdout_attr);
370         }
371
372         r = process_pty_loop(master, mask, kill_pid, signo);
373
374         if (saved_stdout)
375                 tcsetattr(STDOUT_FILENO, TCSANOW, &saved_stdout_attr);
376         if (saved_stdin)
377                 tcsetattr(STDIN_FILENO, TCSANOW, &saved_stdin_attr);
378
379         /* STDIN/STDOUT should not be nonblocking normally, so let's
380          * unconditionally reset it */
381         fd_nonblock(STDIN_FILENO, false);
382         fd_nonblock(STDOUT_FILENO, false);
383
384         return r;
385
386 }