chiark / gitweb /
nspawn: split out pty forwarding logic into ptyfwd.c
[elogind.git] / src / shared / ptyfwd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2010-2013 Lennart Poettering
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20 ***/
21
22 #include <sys/epoll.h>
23 #include <sys/signalfd.h>
24 #include <sys/ioctl.h>
25 #include <limits.h>
26 #include <termios.h>
27
28 #include "util.h"
29 #include "ptyfwd.h"
30
31 int process_pty(int master, sigset_t *mask, pid_t kill_pid, int signo) {
32         char in_buffer[LINE_MAX], out_buffer[LINE_MAX];
33         size_t in_buffer_full = 0, out_buffer_full = 0;
34         struct epoll_event stdin_ev, stdout_ev, master_ev, signal_ev;
35         bool stdin_readable = false, stdout_writable = false, master_readable = false, master_writable = false;
36         bool tried_orderly_shutdown = false;
37         _cleanup_close_ int ep = -1, signal_fd = -1;
38
39         assert(master >= 0);
40         assert(mask);
41         assert(kill_pid == 0 || kill_pid > 1);
42         assert(signo >= 0 && signo < _NSIG);
43
44         fd_nonblock(STDIN_FILENO, 1);
45         fd_nonblock(STDOUT_FILENO, 1);
46         fd_nonblock(master, 1);
47
48         signal_fd = signalfd(-1, mask, SFD_NONBLOCK|SFD_CLOEXEC);
49         if (signal_fd < 0) {
50                 log_error("signalfd(): %m");
51                 return -errno;
52         }
53
54         ep = epoll_create1(EPOLL_CLOEXEC);
55         if (ep < 0) {
56                 log_error("Failed to create epoll: %m");
57                 return -errno;
58         }
59
60         /* We read from STDIN only if this is actually a TTY,
61          * otherwise we assume non-interactivity. */
62         if (isatty(STDIN_FILENO)) {
63                 zero(stdin_ev);
64                 stdin_ev.events = EPOLLIN|EPOLLET;
65                 stdin_ev.data.fd = STDIN_FILENO;
66
67                 if (epoll_ctl(ep, EPOLL_CTL_ADD, STDIN_FILENO, &stdin_ev) < 0) {
68                         log_error("Failed to register STDIN in epoll: %m");
69                         return -errno;
70                 }
71         }
72
73         zero(stdout_ev);
74         stdout_ev.events = EPOLLOUT|EPOLLET;
75         stdout_ev.data.fd = STDOUT_FILENO;
76
77         zero(master_ev);
78         master_ev.events = EPOLLIN|EPOLLOUT|EPOLLET;
79         master_ev.data.fd = master;
80
81         zero(signal_ev);
82         signal_ev.events = EPOLLIN;
83         signal_ev.data.fd = signal_fd;
84
85         if (epoll_ctl(ep, EPOLL_CTL_ADD, STDOUT_FILENO, &stdout_ev) < 0) {
86                 if (errno != EPERM) {
87                         log_error("Failed to register stdout in epoll: %m");
88                         return -errno;
89                 }
90
91                 /* stdout without epoll support. Likely redirected to regular file. */
92                 stdout_writable = true;
93         }
94
95         if (epoll_ctl(ep, EPOLL_CTL_ADD, master, &master_ev) < 0 ||
96             epoll_ctl(ep, EPOLL_CTL_ADD, signal_fd, &signal_ev) < 0) {
97                 log_error("Failed to register fds in epoll: %m");
98                 return -errno;
99         }
100
101         for (;;) {
102                 struct epoll_event ev[16];
103                 ssize_t k;
104                 int i, nfds;
105
106                 nfds = epoll_wait(ep, ev, ELEMENTSOF(ev), -1);
107                 if (nfds < 0) {
108
109                         if (errno == EINTR || errno == EAGAIN)
110                                 continue;
111
112                         log_error("epoll_wait(): %m");
113                         return -errno;
114                 }
115
116                 assert(nfds >= 1);
117
118                 for (i = 0; i < nfds; i++) {
119                         if (ev[i].data.fd == STDIN_FILENO) {
120
121                                 if (ev[i].events & (EPOLLIN|EPOLLHUP))
122                                         stdin_readable = true;
123
124                         } else if (ev[i].data.fd == STDOUT_FILENO) {
125
126                                 if (ev[i].events & (EPOLLOUT|EPOLLHUP))
127                                         stdout_writable = true;
128
129                         } else if (ev[i].data.fd == master) {
130
131                                 if (ev[i].events & (EPOLLIN|EPOLLHUP))
132                                         master_readable = true;
133
134                                 if (ev[i].events & (EPOLLOUT|EPOLLHUP))
135                                         master_writable = true;
136
137                         } else if (ev[i].data.fd == signal_fd) {
138                                 struct signalfd_siginfo sfsi;
139                                 ssize_t n;
140
141                                 n = read(signal_fd, &sfsi, sizeof(sfsi));
142                                 if (n != sizeof(sfsi)) {
143
144                                         if (n >= 0) {
145                                                 log_error("Failed to read from signalfd: invalid block size");
146                                                 return -EIO;
147                                         }
148
149                                         if (errno != EINTR && errno != EAGAIN) {
150                                                 log_error("Failed to read from signalfd: %m");
151                                                 return -errno;
152                                         }
153                                 } else {
154
155                                         if (sfsi.ssi_signo == SIGWINCH) {
156                                                 struct winsize ws;
157
158                                                 /* The window size changed, let's forward that. */
159                                                 if (ioctl(STDIN_FILENO, TIOCGWINSZ, &ws) >= 0)
160                                                         ioctl(master, TIOCSWINSZ, &ws);
161
162                                         } else if (sfsi.ssi_signo == SIGTERM && kill_pid > 0 && signo > 0 && !tried_orderly_shutdown) {
163
164                                                 if (kill(kill_pid, signo) < 0)
165                                                         return 0;
166
167                                                 log_info("Trying to halt container. Send SIGTERM again to trigger immediate termination.");
168
169                                                 /* This only works for systemd... */
170                                                 tried_orderly_shutdown = true;
171
172                                         } else
173                                                 return 0;
174                                 }
175                         }
176                 }
177
178                 while ((stdin_readable && in_buffer_full <= 0) ||
179                        (master_writable && in_buffer_full > 0) ||
180                        (master_readable && out_buffer_full <= 0) ||
181                        (stdout_writable && out_buffer_full > 0)) {
182
183                         if (stdin_readable && in_buffer_full < LINE_MAX) {
184
185                                 k = read(STDIN_FILENO, in_buffer + in_buffer_full, LINE_MAX - in_buffer_full);
186                                 if (k < 0) {
187
188                                         if (errno == EAGAIN || errno == EPIPE || errno == ECONNRESET || errno == EIO)
189                                                 stdin_readable = false;
190                                         else {
191                                                 log_error("read(): %m");
192                                                 return -errno;
193                                         }
194                                 } else
195                                         in_buffer_full += (size_t) k;
196                         }
197
198                         if (master_writable && in_buffer_full > 0) {
199
200                                 k = write(master, in_buffer, in_buffer_full);
201                                 if (k < 0) {
202
203                                         if (errno == EAGAIN || errno == EPIPE || errno == ECONNRESET || errno == EIO)
204                                                 master_writable = false;
205                                         else {
206                                                 log_error("write(): %m");
207                                                 return -errno;
208                                         }
209
210                                 } else {
211                                         assert(in_buffer_full >= (size_t) k);
212                                         memmove(in_buffer, in_buffer + k, in_buffer_full - k);
213                                         in_buffer_full -= k;
214                                 }
215                         }
216
217                         if (master_readable && out_buffer_full < LINE_MAX) {
218
219                                 k = read(master, out_buffer + out_buffer_full, LINE_MAX - out_buffer_full);
220                                 if (k < 0) {
221
222                                         if (errno == EAGAIN || errno == EPIPE || errno == ECONNRESET || errno == EIO)
223                                                 master_readable = false;
224                                         else {
225                                                 log_error("read(): %m");
226                                                 return -errno;
227                                         }
228                                 }  else
229                                         out_buffer_full += (size_t) k;
230                         }
231
232                         if (stdout_writable && out_buffer_full > 0) {
233
234                                 k = write(STDOUT_FILENO, out_buffer, out_buffer_full);
235                                 if (k < 0) {
236
237                                         if (errno == EAGAIN || errno == EPIPE || errno == ECONNRESET || errno == EIO)
238                                                 stdout_writable = false;
239                                         else {
240                                                 log_error("write(): %m");
241                                                 return -errno;
242                                         }
243
244                                 } else {
245                                         assert(out_buffer_full >= (size_t) k);
246                                         memmove(out_buffer, out_buffer + k, out_buffer_full - k);
247                                         out_buffer_full -= k;
248                                 }
249                         }
250                 }
251         }
252 }