chiark / gitweb /
event: make sure we keep a reference to all events we dispatch while we do so.
[elogind.git] / src / socket-proxy / socket-proxyd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2013 David Strauss
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20  ***/
21
22 #include <arpa/inet.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <netdb.h>
29 #include <sys/fcntl.h>
30 #include <sys/socket.h>
31 #include <sys/un.h>
32 #include <unistd.h>
33
34 #include "sd-daemon.h"
35 #include "sd-event.h"
36 #include "log.h"
37 #include "socket-util.h"
38 #include "util.h"
39 #include "event-util.h"
40 #include "build.h"
41 #include "set.h"
42 #include "path-util.h"
43
44 #define BUFFER_SIZE (256 * 1024)
45 #define CONNECTIONS_MAX 256
46
47 #define _cleanup_freeaddrinfo_ _cleanup_(freeaddrinfop)
48 DEFINE_TRIVIAL_CLEANUP_FUNC(struct addrinfo *, freeaddrinfo);
49
50 typedef struct Context {
51         Set *listen;
52         Set *connections;
53 } Context;
54
55 typedef struct Connection {
56         int server_fd, client_fd;
57         int server_to_client_buffer[2]; /* a pipe */
58         int client_to_server_buffer[2]; /* a pipe */
59
60         size_t server_to_client_buffer_full, client_to_server_buffer_full;
61         size_t server_to_client_buffer_size, client_to_server_buffer_size;
62
63         sd_event_source *server_event_source, *client_event_source;
64 } Connection;
65
66 union sockaddr_any {
67         struct sockaddr sa;
68         struct sockaddr_un un;
69         struct sockaddr_in in;
70         struct sockaddr_in6 in6;
71         struct sockaddr_storage storage;
72 };
73
74 static const char *arg_remote_host = NULL;
75
76 static void connection_free(Connection *c) {
77         assert(c);
78
79         sd_event_source_unref(c->server_event_source);
80         sd_event_source_unref(c->client_event_source);
81
82         if (c->server_fd >= 0)
83                 close_nointr_nofail(c->server_fd);
84         if (c->client_fd >= 0)
85                 close_nointr_nofail(c->client_fd);
86
87         close_pipe(c->server_to_client_buffer);
88         close_pipe(c->client_to_server_buffer);
89
90         free(c);
91 }
92
93 static void context_free(Context *context) {
94         sd_event_source *es;
95         Connection *c;
96
97         assert(context);
98
99         while ((es = set_steal_first(context->listen)))
100                 sd_event_source_unref(es);
101
102         while ((c = set_steal_first(context->connections)))
103                 connection_free(c);
104
105         set_free(context->listen);
106         set_free(context->connections);
107 }
108
109 static int get_remote_sockaddr(union sockaddr_any *sa, socklen_t *salen) {
110         int r;
111
112         assert(sa);
113         assert(salen);
114
115         if (path_is_absolute(arg_remote_host)) {
116                 sa->un.sun_family = AF_UNIX;
117                 strncpy(sa->un.sun_path, arg_remote_host, sizeof(sa->un.sun_path)-1);
118                 sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0;
119
120                 *salen = offsetof(union sockaddr_any, un.sun_path) + strlen(sa->un.sun_path);
121
122         } else if (arg_remote_host[0] == '@') {
123                 sa->un.sun_family = AF_UNIX;
124                 sa->un.sun_path[0] = 0;
125                 strncpy(sa->un.sun_path+1, arg_remote_host+1, sizeof(sa->un.sun_path)-2);
126                 sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0;
127
128                 *salen = offsetof(union sockaddr_any, un.sun_path) + 1 + strlen(sa->un.sun_path + 1);
129
130         } else {
131                 _cleanup_freeaddrinfo_ struct addrinfo *result = NULL;
132                 const char *node, *service;
133
134                 struct addrinfo hints = {
135                         .ai_family = AF_UNSPEC,
136                         .ai_socktype = SOCK_STREAM,
137                         .ai_flags = AI_ADDRCONFIG
138                 };
139
140                 service = strrchr(arg_remote_host, ':');
141                 if (service) {
142                         node = strndupa(arg_remote_host, service - arg_remote_host);
143                         service ++;
144                 } else {
145                         node = arg_remote_host;
146                         service = "80";
147                 }
148
149                 log_debug("Looking up address info for %s:%s", node, service);
150                 r = getaddrinfo(node, service, &hints, &result);
151                 if (r != 0) {
152                         log_error("Failed to resolve host %s:%s: %s", node, service, gai_strerror(r));
153                         return -EHOSTUNREACH;
154                 }
155
156                 assert(result);
157                 if (result->ai_addrlen > sizeof(union sockaddr_any)) {
158                         log_error("Address too long.");
159                         return -E2BIG;
160                 }
161
162                 memcpy(sa, result->ai_addr, result->ai_addrlen);
163                 *salen = result->ai_addrlen;
164         }
165
166         return 0;
167 }
168
169 static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
170         int r;
171
172         assert(c);
173         assert(buffer);
174         assert(sz);
175
176         if (buffer[0] >= 0)
177                 return 0;
178
179         r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
180         if (r < 0) {
181                 log_error("Failed to allocate pipe buffer: %m");
182                 return -errno;
183         }
184
185         fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
186
187         r = fcntl(buffer[0], F_GETPIPE_SZ);
188         if (r < 0) {
189                 log_error("Failed to get pipe buffer size: %m");
190                 return -errno;
191         }
192
193         assert(r > 0);
194         *sz = r;
195
196         return 0;
197 }
198
199 static int connection_shovel(
200                 Connection *c,
201                 int *from, int buffer[2], int *to,
202                 size_t *full, size_t *sz,
203                 sd_event_source **from_source, sd_event_source **to_source) {
204
205         bool shoveled;
206
207         assert(c);
208         assert(from);
209         assert(buffer);
210         assert(buffer[0] >= 0);
211         assert(buffer[1] >= 0);
212         assert(to);
213         assert(full);
214         assert(sz);
215         assert(from_source);
216         assert(to_source);
217
218         do {
219                 ssize_t z;
220
221                 shoveled = false;
222
223                 if (*full < *sz && *from >= 0 && *to >= 0) {
224                         z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
225                         if (z > 0) {
226                                 *full += z;
227                                 shoveled = true;
228                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
229                                 *from_source = sd_event_source_unref(*from_source);
230                                 close_nointr_nofail(*from);
231                                 *from = -1;
232                         } else if (errno != EAGAIN && errno != EINTR) {
233                                 log_error("Failed to splice: %m");
234                                 return -errno;
235                         }
236                 }
237
238                 if (*full > 0 && *to >= 0) {
239                         z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
240                         if (z > 0) {
241                                 *full -= z;
242                                 shoveled = true;
243                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
244                                 *to_source = sd_event_source_unref(*to_source);
245                                 close_nointr_nofail(*to);
246                                 *to = -1;
247                         } else if (errno != EAGAIN && errno != EINTR) {
248                                 log_error("Failed to splice: %m");
249                                 return -errno;
250                         }
251                 }
252         } while (shoveled);
253
254         return 0;
255 }
256
257 static int connection_enable_event_sources(Connection *c, sd_event *event);
258
259 static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
260         Connection *c = userdata;
261         int r;
262
263         assert(s);
264         assert(fd >= 0);
265         assert(c);
266
267         r = connection_shovel(c,
268                               &c->server_fd, c->server_to_client_buffer, &c->client_fd,
269                               &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
270                               &c->server_event_source, &c->client_event_source);
271         if (r < 0)
272                 goto quit;
273
274         r = connection_shovel(c,
275                               &c->client_fd, c->client_to_server_buffer, &c->server_fd,
276                               &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
277                               &c->client_event_source, &c->server_event_source);
278         if (r < 0)
279                 goto quit;
280
281         /* EOF on both sides? */
282         if (c->server_fd == -1 && c->client_fd == -1)
283                 goto quit;
284
285         /* Server closed, and all data written to client? */
286         if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0)
287                 goto quit;
288
289         /* Client closed, and all data written to server? */
290         if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
291                 goto quit;
292
293         r = connection_enable_event_sources(c, sd_event_get(s));
294         if (r < 0)
295                 goto quit;
296
297         return 1;
298
299 quit:
300         connection_free(c);
301         return 0; /* ignore errors, continue serving */
302 }
303
304 static int connection_enable_event_sources(Connection *c, sd_event *event) {
305         uint32_t a = 0, b = 0;
306         int r;
307
308         assert(c);
309         assert(event);
310
311         if (c->server_to_client_buffer_full > 0)
312                 b |= EPOLLOUT;
313         if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
314                 a |= EPOLLIN;
315
316         if (c->client_to_server_buffer_full > 0)
317                 a |= EPOLLOUT;
318         if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
319                 b |= EPOLLIN;
320
321         if (c->server_event_source)
322                 r = sd_event_source_set_io_events(c->server_event_source, a);
323         else if (c->server_fd >= 0)
324                 r = sd_event_add_io(event, c->server_fd, a, traffic_cb, c, &c->server_event_source);
325         else
326                 r = 0;
327
328         if (r < 0) {
329                 log_error("Failed to set up server event source: %s", strerror(-r));
330                 return r;
331         }
332
333         if (c->client_event_source)
334                 r = sd_event_source_set_io_events(c->client_event_source, b);
335         else if (c->client_fd >= 0)
336                 r = sd_event_add_io(event, c->client_fd, b, traffic_cb, c, &c->client_event_source);
337         else
338                 r = 0;
339
340         if (r < 0) {
341                 log_error("Failed to set up client event source: %s", strerror(-r));
342                 return r;
343         }
344
345         return 0;
346 }
347
348 static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
349         Connection *c = userdata;
350         socklen_t solen;
351         int error, r;
352
353         assert(s);
354         assert(fd >= 0);
355         assert(c);
356
357         solen = sizeof(error);
358         r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
359         if (r < 0) {
360                 log_error("Failed to issue SO_ERROR: %m");
361                 goto fail;
362         }
363
364         if (error != 0) {
365                 log_error("Failed to connect to remote host: %s", strerror(error));
366                 goto fail;
367         }
368
369         c->client_event_source = sd_event_source_unref(c->client_event_source);
370
371         r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
372         if (r < 0)
373                 goto fail;
374
375         r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
376         if (r < 0)
377                 goto fail;
378
379         r = connection_enable_event_sources(c, sd_event_get(s));
380         if (r < 0)
381                 goto fail;
382
383         return 0;
384
385 fail:
386         connection_free(c);
387         return 0; /* ignore errors, continue serving */
388 }
389
390 static int add_connection_socket(Context *context, sd_event *event, int fd) {
391         union sockaddr_any sa = {};
392         socklen_t salen;
393         Connection *c;
394         int r;
395
396         assert(context);
397         assert(event);
398         assert(fd >= 0);
399
400         if (set_size(context->connections) > CONNECTIONS_MAX) {
401                 log_warning("Hit connection limit, refusing connection.");
402                 close_nointr_nofail(fd);
403                 return 0;
404         }
405
406         r = set_ensure_allocated(&context->connections, trivial_hash_func, trivial_compare_func);
407         if (r < 0)
408                 return log_oom();
409
410         c = new0(Connection, 1);
411         if (!c)
412                 return log_oom();
413
414         c->server_fd = fd;
415         c->client_fd = -1;
416         c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
417         c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
418
419         r = get_remote_sockaddr(&sa, &salen);
420         if (r < 0)
421                 goto fail;
422
423         c->client_fd = socket(sa.sa.sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
424         if (c->client_fd < 0) {
425                 log_error("Failed to get remote socket: %m");
426                 goto fail;
427         }
428
429         r = connect(c->client_fd, &sa.sa, salen);
430         if (r < 0) {
431                 if (errno == EINPROGRESS) {
432                         r = sd_event_add_io(event, c->client_fd, EPOLLOUT, connect_cb, c, &c->client_event_source);
433                         if (r < 0) {
434                                 log_error("Failed to add connection socket: %s", strerror(-r));
435                                 goto fail;
436                         }
437
438                         r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
439                         if (r < 0) {
440                                 log_error("Failed to enable oneshot event source: %s", strerror(-r));
441                                 goto fail;
442                         }
443                 } else {
444                         log_error("Failed to connect to remote host: %m");
445                         goto fail;
446                 }
447         } else {
448                 r = connection_enable_event_sources(c, event);
449                 if (r < 0)
450                         goto fail;
451         }
452
453         return 0;
454
455 fail:
456         connection_free(c);
457         return 0; /* ignore non-OOM errors, continue serving */
458 }
459
460 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
461         Context *context = userdata;
462         int nfd = -1, r;
463
464         assert(s);
465         assert(fd >= 0);
466         assert(revents & EPOLLIN);
467         assert(context);
468
469         nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
470         if (nfd >= 0) {
471                 _cleanup_free_ char *peer = NULL;
472
473                 getpeername_pretty(nfd, &peer);
474                 log_debug("New connection from %s", strna(peer));
475
476                 r = add_connection_socket(context, sd_event_get(s), nfd);
477                 if (r < 0) {
478                         close_nointr_nofail(fd);
479                         return r;
480                 }
481
482         } else if (errno != -EAGAIN)
483                 log_warning("Failed to accept() socket: %m");
484
485         r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
486         if (r < 0) {
487                 log_error("Error %d while re-enabling listener with ONESHOT: %s", r, strerror(-r));
488                 return r;
489         }
490
491         return 1;
492 }
493
494 static int add_listen_socket(Context *context, sd_event *event, int fd) {
495         sd_event_source *source;
496         int r;
497
498         assert(context);
499         assert(event);
500         assert(fd >= 0);
501
502         log_info("Listening on %i", fd);
503
504         r = set_ensure_allocated(&context->listen, trivial_hash_func, trivial_compare_func);
505         if (r < 0) {
506                 log_oom();
507                 return r;
508         }
509
510         r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
511         if (r < 0) {
512                 log_error("Failed to determine socket type: %s", strerror(-r));
513                 return r;
514         }
515         if (r == 0) {
516                 log_error("Passed in socket is not a stream socket.");
517                 return -EINVAL;
518         }
519
520         r = fd_nonblock(fd, true);
521         if (r < 0) {
522                 log_error("Failed to mark file descriptor non-blocking: %s", strerror(-r));
523                 return r;
524         }
525
526         r = sd_event_add_io(event, fd, EPOLLIN, accept_cb, context, &source);
527         if (r < 0) {
528                 log_error("Failed to add event source: %s", strerror(-r));
529                 return r;
530         }
531
532         r = set_put(context->listen, source);
533         if (r < 0) {
534                 log_error("Failed to add source to set: %s", strerror(-r));
535                 sd_event_source_unref(source);
536                 return r;
537         }
538
539         /* Set the watcher to oneshot in case other processes are also
540          * watching to accept(). */
541         r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
542         if (r < 0) {
543                 log_error("Failed to enable oneshot mode: %s", strerror(-r));
544                 return r;
545         }
546
547         return 0;
548 }
549
550 static int help(void) {
551
552         printf("%s [HOST:PORT]\n"
553                "%s [SOCKET]\n\n"
554                "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
555                "  -h --help              Show this help\n"
556                "     --version           Show package version\n",
557                program_invocation_short_name,
558                program_invocation_short_name);
559
560         return 0;
561 }
562
563 static int parse_argv(int argc, char *argv[]) {
564
565         enum {
566                 ARG_VERSION = 0x100,
567                 ARG_IGNORE_ENV
568         };
569
570         static const struct option options[] = {
571                 { "help",       no_argument, NULL, 'h'           },
572                 { "version",    no_argument, NULL, ARG_VERSION   },
573                 {}
574         };
575
576         int c;
577
578         assert(argc >= 0);
579         assert(argv);
580
581         while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0) {
582
583                 switch (c) {
584
585                 case 'h':
586                         return help();
587
588                 case ARG_VERSION:
589                         puts(PACKAGE_STRING);
590                         puts(SYSTEMD_FEATURES);
591                         return 0;
592
593                 case '?':
594                         return -EINVAL;
595
596                 default:
597                         assert_not_reached("Unhandled option");
598                 }
599         }
600
601         if (optind >= argc) {
602                 log_error("Not enough parameters.");
603                 return -EINVAL;
604         }
605
606         if (argc != optind+1) {
607                 log_error("Too many parameters.");
608                 return -EINVAL;
609         }
610
611         arg_remote_host = argv[optind];
612         return 1;
613 }
614
615 int main(int argc, char *argv[]) {
616         _cleanup_event_unref_ sd_event *event = NULL;
617         Context context = {};
618         int r, n, fd;
619
620         log_parse_environment();
621         log_open();
622
623         r = parse_argv(argc, argv);
624         if (r <= 0)
625                 goto finish;
626
627         r = sd_event_new(&event);
628         if (r < 0) {
629                 log_error("Failed to allocate event loop: %s", strerror(-r));
630                 goto finish;
631         }
632
633         n = sd_listen_fds(1);
634         if (n < 0) {
635                 log_error("Failed to receive sockets from parent.");
636                 r = n;
637                 goto finish;
638         } else if (n == 0) {
639                 log_error("Didn't get any sockets passed in.");
640                 r = -EINVAL;
641                 goto finish;
642         }
643
644         for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
645                 r = add_listen_socket(&context, event, fd);
646                 if (r < 0)
647                         goto finish;
648         }
649
650         r = sd_event_loop(event);
651         if (r < 0) {
652                 log_error("Failed to run event loop: %s", strerror(-r));
653                 goto finish;
654         }
655
656 finish:
657         context_free(&context);
658
659         return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
660 }