chiark / gitweb /
ptyfwd: simplify how we handle vhangups a bit
[elogind.git] / src / shared / ptyfwd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2010-2013 Lennart Poettering
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20 ***/
21
22 #include <sys/epoll.h>
23 #include <sys/signalfd.h>
24 #include <sys/ioctl.h>
25 #include <limits.h>
26 #include <termios.h>
27
28 #include "util.h"
29 #include "ptyfwd.h"
30
31 struct PTYForward {
32         sd_event *event;
33
34         int master;
35
36         sd_event_source *stdin_event_source;
37         sd_event_source *stdout_event_source;
38         sd_event_source *master_event_source;
39
40         sd_event_source *sigwinch_event_source;
41
42         struct termios saved_stdin_attr;
43         struct termios saved_stdout_attr;
44
45         bool saved_stdin:1;
46         bool saved_stdout:1;
47
48         bool stdin_readable:1;
49         bool stdin_hangup:1;
50         bool stdout_writable:1;
51         bool stdout_hangup:1;
52         bool master_readable:1;
53         bool master_writable:1;
54         bool master_hangup:1;
55
56         /* Continue reading after hangup? */
57         bool ignore_vhangup:1;
58
59         bool last_char_set:1;
60         char last_char;
61
62         char in_buffer[LINE_MAX], out_buffer[LINE_MAX];
63         size_t in_buffer_full, out_buffer_full;
64
65         usec_t escape_timestamp;
66         unsigned escape_counter;
67 };
68
69 #define ESCAPE_USEC (1*USEC_PER_SEC)
70
71 static bool look_for_escape(PTYForward *f, const char *buffer, size_t n) {
72         const char *p;
73
74         assert(f);
75         assert(buffer);
76         assert(n > 0);
77
78         for (p = buffer; p < buffer + n; p++) {
79
80                 /* Check for ^] */
81                 if (*p == 0x1D) {
82                         usec_t nw = now(CLOCK_MONOTONIC);
83
84                         if (f->escape_counter == 0 || nw > f->escape_timestamp + ESCAPE_USEC)  {
85                                 f->escape_timestamp = nw;
86                                 f->escape_counter = 1;
87                         } else {
88                                 (f->escape_counter)++;
89
90                                 if (f->escape_counter >= 3)
91                                         return true;
92                         }
93                 } else {
94                         f->escape_timestamp = 0;
95                         f->escape_counter = 0;
96                 }
97         }
98
99         return false;
100 }
101
102 static int shovel(PTYForward *f) {
103         ssize_t k;
104
105         assert(f);
106
107         while ((f->stdin_readable && f->in_buffer_full <= 0) ||
108                (f->master_writable && f->in_buffer_full > 0) ||
109                (f->master_readable && f->out_buffer_full <= 0) ||
110                (f->stdout_writable && f->out_buffer_full > 0)) {
111
112                 if (f->stdin_readable && f->in_buffer_full < LINE_MAX) {
113
114                         k = read(STDIN_FILENO, f->in_buffer + f->in_buffer_full, LINE_MAX - f->in_buffer_full);
115                         if (k < 0) {
116
117                                 if (errno == EAGAIN)
118                                         f->stdin_readable = false;
119                                 else if (errno == EIO || errno == EPIPE || errno == ECONNRESET) {
120                                         f->stdin_readable = false;
121                                         f->stdin_hangup = true;
122
123                                         f->stdin_event_source = sd_event_source_unref(f->stdin_event_source);
124                                 } else {
125                                         log_error_errno(errno, "read(): %m");
126                                         return sd_event_exit(f->event, EXIT_FAILURE);
127                                 }
128                         } else if (k == 0) {
129                                 /* EOF on stdin */
130                                 f->stdin_readable = false;
131                                 f->stdin_hangup = true;
132
133                                 f->stdin_event_source = sd_event_source_unref(f->stdin_event_source);
134                         } else  {
135                                 /* Check if ^] has been
136                                  * pressed three times within
137                                  * one second. If we get this
138                                  * we quite immediately. */
139                                 if (look_for_escape(f, f->in_buffer + f->in_buffer_full, k))
140                                         return sd_event_exit(f->event, EXIT_FAILURE);
141
142                                 f->in_buffer_full += (size_t) k;
143                         }
144                 }
145
146                 if (f->master_writable && f->in_buffer_full > 0) {
147
148                         k = write(f->master, f->in_buffer, f->in_buffer_full);
149                         if (k < 0) {
150
151                                 if (errno == EAGAIN || errno == EIO)
152                                         f->master_writable = false;
153                                 else if (errno == EPIPE || errno == ECONNRESET) {
154                                         f->master_writable = f->master_readable = false;
155                                         f->master_hangup = true;
156
157                                         f->master_event_source = sd_event_source_unref(f->master_event_source);
158                                 } else {
159                                         log_error_errno(errno, "write(): %m");
160                                         return sd_event_exit(f->event, EXIT_FAILURE);
161                                 }
162                         } else {
163                                 assert(f->in_buffer_full >= (size_t) k);
164                                 memmove(f->in_buffer, f->in_buffer + k, f->in_buffer_full - k);
165                                 f->in_buffer_full -= k;
166                         }
167                 }
168
169                 if (f->master_readable && f->out_buffer_full < LINE_MAX) {
170
171                         k = read(f->master, f->out_buffer + f->out_buffer_full, LINE_MAX - f->out_buffer_full);
172                         if (k < 0) {
173
174                                 /* Note that EIO on the master device
175                                  * might be caused by vhangup() or
176                                  * temporary closing of everything on
177                                  * the other side, we treat it like
178                                  * EAGAIN here and try again, unless
179                                  * ignore_vhangup is off. */
180
181                                 if (errno == EAGAIN || (errno == EIO && f->ignore_vhangup))
182                                         f->master_readable = false;
183                                 else if (errno == EPIPE || errno == ECONNRESET || errno == EIO) {
184                                         f->master_readable = f->master_writable = false;
185                                         f->master_hangup = true;
186
187                                         f->master_event_source = sd_event_source_unref(f->master_event_source);
188                                 } else {
189                                         log_error_errno(errno, "read(): %m");
190                                         return sd_event_exit(f->event, EXIT_FAILURE);
191                                 }
192                         }  else
193                                 f->out_buffer_full += (size_t) k;
194                 }
195
196                 if (f->stdout_writable && f->out_buffer_full > 0) {
197
198                         k = write(STDOUT_FILENO, f->out_buffer, f->out_buffer_full);
199                         if (k < 0) {
200
201                                 if (errno == EAGAIN)
202                                         f->stdout_writable = false;
203                                 else if (errno == EIO || errno == EPIPE || errno == ECONNRESET) {
204                                         f->stdout_writable = false;
205                                         f->stdout_hangup = true;
206                                         f->stdout_event_source = sd_event_source_unref(f->stdout_event_source);
207                                 } else {
208                                         log_error_errno(errno, "write(): %m");
209                                         return sd_event_exit(f->event, EXIT_FAILURE);
210                                 }
211
212                         } else {
213
214                                 if (k > 0) {
215                                         f->last_char = f->out_buffer[k-1];
216                                         f->last_char_set = true;
217                                 }
218
219                                 assert(f->out_buffer_full >= (size_t) k);
220                                 memmove(f->out_buffer, f->out_buffer + k, f->out_buffer_full - k);
221                                 f->out_buffer_full -= k;
222                         }
223                 }
224         }
225
226         if (f->stdin_hangup || f->stdout_hangup || f->master_hangup) {
227                 /* Exit the loop if any side hung up and if there's
228                  * nothing more to write or nothing we could write. */
229
230                 if ((f->out_buffer_full <= 0 || f->stdout_hangup) &&
231                     (f->in_buffer_full <= 0 || f->master_hangup))
232                         return sd_event_exit(f->event, EXIT_SUCCESS);
233         }
234
235         return 0;
236 }
237
238 static int on_master_event(sd_event_source *e, int fd, uint32_t revents, void *userdata) {
239         PTYForward *f = userdata;
240
241         assert(f);
242         assert(e);
243         assert(e == f->master_event_source);
244         assert(fd >= 0);
245         assert(fd == f->master);
246
247         if (revents & (EPOLLIN|EPOLLHUP))
248                 f->master_readable = true;
249
250         if (revents & (EPOLLOUT|EPOLLHUP))
251                 f->master_writable = true;
252
253         return shovel(f);
254 }
255
256 static int on_stdin_event(sd_event_source *e, int fd, uint32_t revents, void *userdata) {
257         PTYForward *f = userdata;
258
259         assert(f);
260         assert(e);
261         assert(e == f->stdin_event_source);
262         assert(fd >= 0);
263         assert(fd == STDIN_FILENO);
264
265         if (revents & (EPOLLIN|EPOLLHUP))
266                 f->stdin_readable = true;
267
268         return shovel(f);
269 }
270
271 static int on_stdout_event(sd_event_source *e, int fd, uint32_t revents, void *userdata) {
272         PTYForward *f = userdata;
273
274         assert(f);
275         assert(e);
276         assert(e == f->stdout_event_source);
277         assert(fd >= 0);
278         assert(fd == STDOUT_FILENO);
279
280         if (revents & (EPOLLOUT|EPOLLHUP))
281                 f->stdout_writable = true;
282
283         return shovel(f);
284 }
285
286 static int on_sigwinch_event(sd_event_source *e, const struct signalfd_siginfo *si, void *userdata) {
287         PTYForward *f = userdata;
288         struct winsize ws;
289
290         assert(f);
291         assert(e);
292         assert(e == f->sigwinch_event_source);
293
294         /* The window size changed, let's forward that. */
295         if (ioctl(STDOUT_FILENO, TIOCGWINSZ, &ws) >= 0)
296                 (void)ioctl(f->master, TIOCSWINSZ, &ws);
297
298         return 0;
299 }
300
301 int pty_forward_new(sd_event *event, int master, bool ignore_vhangup, PTYForward **ret) {
302         _cleanup_(pty_forward_freep) PTYForward *f = NULL;
303         struct winsize ws;
304         int r;
305
306         f = new0(PTYForward, 1);
307         if (!f)
308                 return -ENOMEM;
309
310         f->ignore_vhangup = ignore_vhangup;
311
312         if (event)
313                 f->event = sd_event_ref(event);
314         else {
315                 r = sd_event_default(&f->event);
316                 if (r < 0)
317                         return r;
318         }
319
320         r = fd_nonblock(STDIN_FILENO, true);
321         if (r < 0)
322                 return r;
323
324         r = fd_nonblock(STDOUT_FILENO, true);
325         if (r < 0)
326                 return r;
327
328         r = fd_nonblock(master, true);
329         if (r < 0)
330                 return r;
331
332         f->master = master;
333
334         if (ioctl(STDOUT_FILENO, TIOCGWINSZ, &ws) >= 0)
335                 (void)ioctl(master, TIOCSWINSZ, &ws);
336
337         if (tcgetattr(STDIN_FILENO, &f->saved_stdin_attr) >= 0) {
338                 struct termios raw_stdin_attr;
339
340                 f->saved_stdin = true;
341
342                 raw_stdin_attr = f->saved_stdin_attr;
343                 cfmakeraw(&raw_stdin_attr);
344                 raw_stdin_attr.c_oflag = f->saved_stdin_attr.c_oflag;
345                 tcsetattr(STDIN_FILENO, TCSANOW, &raw_stdin_attr);
346         }
347
348         if (tcgetattr(STDOUT_FILENO, &f->saved_stdout_attr) >= 0) {
349                 struct termios raw_stdout_attr;
350
351                 f->saved_stdout = true;
352
353                 raw_stdout_attr = f->saved_stdout_attr;
354                 cfmakeraw(&raw_stdout_attr);
355                 raw_stdout_attr.c_iflag = f->saved_stdout_attr.c_iflag;
356                 raw_stdout_attr.c_lflag = f->saved_stdout_attr.c_lflag;
357                 tcsetattr(STDOUT_FILENO, TCSANOW, &raw_stdout_attr);
358         }
359
360         r = sd_event_add_io(f->event, &f->master_event_source, master, EPOLLIN|EPOLLOUT|EPOLLET, on_master_event, f);
361         if (r < 0)
362                 return r;
363
364         r = sd_event_add_io(f->event, &f->stdin_event_source, STDIN_FILENO, EPOLLIN|EPOLLET, on_stdin_event, f);
365         if (r < 0 && r != -EPERM)
366                 return r;
367
368         r = sd_event_add_io(f->event, &f->stdout_event_source, STDOUT_FILENO, EPOLLOUT|EPOLLET, on_stdout_event, f);
369         if (r == -EPERM)
370                 /* stdout without epoll support. Likely redirected to regular file. */
371                 f->stdout_writable = true;
372         else if (r < 0)
373                 return r;
374
375         r = sd_event_add_signal(f->event, &f->sigwinch_event_source, SIGWINCH, on_sigwinch_event, f);
376
377         *ret = f;
378         f = NULL;
379
380         return 0;
381 }
382
383 PTYForward *pty_forward_free(PTYForward *f) {
384
385         if (f) {
386                 sd_event_source_unref(f->stdin_event_source);
387                 sd_event_source_unref(f->stdout_event_source);
388                 sd_event_source_unref(f->master_event_source);
389                 sd_event_unref(f->event);
390
391                 if (f->saved_stdout)
392                         tcsetattr(STDOUT_FILENO, TCSANOW, &f->saved_stdout_attr);
393                 if (f->saved_stdin)
394                         tcsetattr(STDIN_FILENO, TCSANOW, &f->saved_stdin_attr);
395
396                 free(f);
397         }
398
399         /* STDIN/STDOUT should not be nonblocking normally, so let's
400          * unconditionally reset it */
401         fd_nonblock(STDIN_FILENO, false);
402         fd_nonblock(STDOUT_FILENO, false);
403
404         return NULL;
405 }
406
407 int pty_forward_get_last_char(PTYForward *f, char *ch) {
408         assert(f);
409         assert(ch);
410
411         if (!f->last_char_set)
412                 return -ENXIO;
413
414         *ch = f->last_char;
415         return 0;
416 }
417
418 int pty_forward_set_ignore_vhangup(PTYForward *f, bool ignore_vhangup) {
419         int r;
420
421         assert(f);
422
423         if (f->ignore_vhangup == ignore_vhangup)
424                 return 0;
425
426         f->ignore_vhangup = ignore_vhangup;
427         if (!f->ignore_vhangup) {
428
429                 /* We shall now react to vhangup()s? Let's check
430                  * immediately if we might be in one */
431
432                 f->master_readable = true;
433                 r = shovel(f);
434                 if (r < 0)
435                         return r;
436         }
437
438         return 0;
439 }
440
441 int pty_forward_get_ignore_vhangup(PTYForward *f) {
442         assert(f);
443
444         return f->ignore_vhangup;
445 }