chiark / gitweb /
362e8aae9f3760dbe0245a266d655a532c5557bf
[elogind.git] / src / socket-proxy / socket-proxyd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2013 David Strauss
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20  ***/
21
22 #include <arpa/inet.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <netdb.h>
29 #include <sys/fcntl.h>
30 #include <sys/socket.h>
31 #include <sys/un.h>
32 #include <unistd.h>
33
34 #include "sd-daemon.h"
35 #include "sd-event.h"
36 #include "log.h"
37 #include "socket-util.h"
38 #include "util.h"
39 #include "event-util.h"
40 #include "build.h"
41 #include "set.h"
42 #include "path-util.h"
43
44 #define BUFFER_SIZE (256 * 1024)
45 #define CONNECTIONS_MAX 256
46
47 #define _cleanup_freeaddrinfo_ _cleanup_(freeaddrinfop)
48 DEFINE_TRIVIAL_CLEANUP_FUNC(struct addrinfo *, freeaddrinfo);
49
50 typedef struct Context {
51         Set *listen;
52         Set *connections;
53 } Context;
54
55 typedef struct Connection {
56         Context *context;
57
58         int server_fd, client_fd;
59         int server_to_client_buffer[2]; /* a pipe */
60         int client_to_server_buffer[2]; /* a pipe */
61
62         size_t server_to_client_buffer_full, client_to_server_buffer_full;
63         size_t server_to_client_buffer_size, client_to_server_buffer_size;
64
65         sd_event_source *server_event_source, *client_event_source;
66 } Connection;
67
68 static const char *arg_remote_host = NULL;
69 static int arg_listener = -1;
70
71 static void connection_free(Connection *c) {
72         assert(c);
73
74         if (c->context)
75                 set_remove(c->context->connections, c);
76
77         sd_event_source_unref(c->server_event_source);
78         sd_event_source_unref(c->client_event_source);
79
80         if (c->server_fd >= 0)
81                 close_nointr_nofail(c->server_fd);
82         if (c->client_fd >= 0)
83                 close_nointr_nofail(c->client_fd);
84
85         close_pipe(c->server_to_client_buffer);
86         close_pipe(c->client_to_server_buffer);
87
88         free(c);
89 }
90
91 static void context_free(Context *context) {
92         sd_event_source *es;
93         Connection *c;
94
95         assert(context);
96
97         while ((es = set_steal_first(context->listen)))
98                 sd_event_source_unref(es);
99
100         while ((c = set_first(context->connections)))
101                 connection_free(c);
102
103         set_free(context->listen);
104         set_free(context->connections);
105 }
106
107 static int get_remote_sockaddr(union sockaddr_union *sa, socklen_t *salen) {
108         int r;
109
110         assert(sa);
111         assert(salen);
112
113         if (path_is_absolute(arg_remote_host)) {
114                 sa->un.sun_family = AF_UNIX;
115                 strncpy(sa->un.sun_path, arg_remote_host, sizeof(sa->un.sun_path)-1);
116                 sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0;
117
118                 *salen = offsetof(union sockaddr_union, un.sun_path) + strlen(sa->un.sun_path);
119
120         } else if (arg_remote_host[0] == '@') {
121                 sa->un.sun_family = AF_UNIX;
122                 sa->un.sun_path[0] = 0;
123                 strncpy(sa->un.sun_path+1, arg_remote_host+1, sizeof(sa->un.sun_path)-2);
124                 sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0;
125
126                 *salen = offsetof(union sockaddr_union, un.sun_path) + 1 + strlen(sa->un.sun_path + 1);
127
128         } else {
129                 _cleanup_freeaddrinfo_ struct addrinfo *result = NULL;
130                 const char *node, *service;
131
132                 struct addrinfo hints = {
133                         .ai_family = AF_UNSPEC,
134                         .ai_socktype = SOCK_STREAM,
135                         .ai_flags = AI_ADDRCONFIG
136                 };
137
138                 service = strrchr(arg_remote_host, ':');
139                 if (service) {
140                         node = strndupa(arg_remote_host, service - arg_remote_host);
141                         service ++;
142                 } else {
143                         node = arg_remote_host;
144                         service = "80";
145                 }
146
147                 log_debug("Looking up address info for %s:%s", node, service);
148                 r = getaddrinfo(node, service, &hints, &result);
149                 if (r != 0) {
150                         log_error("Failed to resolve host %s:%s: %s", node, service, gai_strerror(r));
151                         return -EHOSTUNREACH;
152                 }
153
154                 assert(result);
155                 if (result->ai_addrlen > sizeof(union sockaddr_union)) {
156                         log_error("Address too long.");
157                         return -E2BIG;
158                 }
159
160                 memcpy(sa, result->ai_addr, result->ai_addrlen);
161                 *salen = result->ai_addrlen;
162         }
163
164         return 0;
165 }
166
167 static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
168         int r;
169
170         assert(c);
171         assert(buffer);
172         assert(sz);
173
174         if (buffer[0] >= 0)
175                 return 0;
176
177         r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
178         if (r < 0) {
179                 log_error("Failed to allocate pipe buffer: %m");
180                 return -errno;
181         }
182
183         fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
184
185         r = fcntl(buffer[0], F_GETPIPE_SZ);
186         if (r < 0) {
187                 log_error("Failed to get pipe buffer size: %m");
188                 return -errno;
189         }
190
191         assert(r > 0);
192         *sz = r;
193
194         return 0;
195 }
196
197 static int connection_shovel(
198                 Connection *c,
199                 int *from, int buffer[2], int *to,
200                 size_t *full, size_t *sz,
201                 sd_event_source **from_source, sd_event_source **to_source) {
202
203         bool shoveled;
204
205         assert(c);
206         assert(from);
207         assert(buffer);
208         assert(buffer[0] >= 0);
209         assert(buffer[1] >= 0);
210         assert(to);
211         assert(full);
212         assert(sz);
213         assert(from_source);
214         assert(to_source);
215
216         do {
217                 ssize_t z;
218
219                 shoveled = false;
220
221                 if (*full < *sz && *from >= 0 && *to >= 0) {
222                         z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
223                         if (z > 0) {
224                                 *full += z;
225                                 shoveled = true;
226                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
227                                 *from_source = sd_event_source_unref(*from_source);
228                                 close_nointr_nofail(*from);
229                                 *from = -1;
230                         } else if (errno != EAGAIN && errno != EINTR) {
231                                 log_error("Failed to splice: %m");
232                                 return -errno;
233                         }
234                 }
235
236                 if (*full > 0 && *to >= 0) {
237                         z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
238                         if (z > 0) {
239                                 *full -= z;
240                                 shoveled = true;
241                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
242                                 *to_source = sd_event_source_unref(*to_source);
243                                 close_nointr_nofail(*to);
244                                 *to = -1;
245                         } else if (errno != EAGAIN && errno != EINTR) {
246                                 log_error("Failed to splice: %m");
247                                 return -errno;
248                         }
249                 }
250         } while (shoveled);
251
252         return 0;
253 }
254
255 static int connection_enable_event_sources(Connection *c, sd_event *event);
256
257 static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
258         Connection *c = userdata;
259         int r;
260
261         assert(s);
262         assert(fd >= 0);
263         assert(c);
264
265         r = connection_shovel(c,
266                               &c->server_fd, c->server_to_client_buffer, &c->client_fd,
267                               &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
268                               &c->server_event_source, &c->client_event_source);
269         if (r < 0)
270                 goto quit;
271
272         r = connection_shovel(c,
273                               &c->client_fd, c->client_to_server_buffer, &c->server_fd,
274                               &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
275                               &c->client_event_source, &c->server_event_source);
276         if (r < 0)
277                 goto quit;
278
279         /* EOF on both sides? */
280         if (c->server_fd == -1 && c->client_fd == -1)
281                 goto quit;
282
283         /* Server closed, and all data written to client? */
284         if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0)
285                 goto quit;
286
287         /* Client closed, and all data written to server? */
288         if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
289                 goto quit;
290
291         r = connection_enable_event_sources(c, sd_event_source_get_event(s));
292         if (r < 0)
293                 goto quit;
294
295         return 1;
296
297 quit:
298         connection_free(c);
299         return 0; /* ignore errors, continue serving */
300 }
301
302 static int connection_enable_event_sources(Connection *c, sd_event *event) {
303         uint32_t a = 0, b = 0;
304         int r;
305
306         assert(c);
307         assert(event);
308
309         if (c->server_to_client_buffer_full > 0)
310                 b |= EPOLLOUT;
311         if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
312                 a |= EPOLLIN;
313
314         if (c->client_to_server_buffer_full > 0)
315                 a |= EPOLLOUT;
316         if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
317                 b |= EPOLLIN;
318
319         if (c->server_event_source)
320                 r = sd_event_source_set_io_events(c->server_event_source, a);
321         else if (c->server_fd >= 0)
322                 r = sd_event_add_io(event, c->server_fd, a, traffic_cb, c, &c->server_event_source);
323         else
324                 r = 0;
325
326         if (r < 0) {
327                 log_error("Failed to set up server event source: %s", strerror(-r));
328                 return r;
329         }
330
331         if (c->client_event_source)
332                 r = sd_event_source_set_io_events(c->client_event_source, b);
333         else if (c->client_fd >= 0)
334                 r = sd_event_add_io(event, c->client_fd, b, traffic_cb, c, &c->client_event_source);
335         else
336                 r = 0;
337
338         if (r < 0) {
339                 log_error("Failed to set up client event source: %s", strerror(-r));
340                 return r;
341         }
342
343         return 0;
344 }
345
346 static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
347         Connection *c = userdata;
348         socklen_t solen;
349         int error, r;
350
351         assert(s);
352         assert(fd >= 0);
353         assert(c);
354
355         solen = sizeof(error);
356         r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
357         if (r < 0) {
358                 log_error("Failed to issue SO_ERROR: %m");
359                 goto fail;
360         }
361
362         if (error != 0) {
363                 log_error("Failed to connect to remote host: %s", strerror(error));
364                 goto fail;
365         }
366
367         c->client_event_source = sd_event_source_unref(c->client_event_source);
368
369         r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
370         if (r < 0)
371                 goto fail;
372
373         r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
374         if (r < 0)
375                 goto fail;
376
377         r = connection_enable_event_sources(c, sd_event_source_get_event(s));
378         if (r < 0)
379                 goto fail;
380
381         return 0;
382
383 fail:
384         connection_free(c);
385         return 0; /* ignore errors, continue serving */
386 }
387
388 static int add_connection_socket(Context *context, sd_event *event, int fd) {
389         union sockaddr_union sa = {};
390         socklen_t salen;
391         Connection *c;
392         int r;
393
394         assert(context);
395         assert(event);
396         assert(fd >= 0);
397
398         if (set_size(context->connections) > CONNECTIONS_MAX) {
399                 log_warning("Hit connection limit, refusing connection.");
400                 close_nointr_nofail(fd);
401                 return 0;
402         }
403
404         r = set_ensure_allocated(&context->connections, trivial_hash_func, trivial_compare_func);
405         if (r < 0)
406                 return log_oom();
407
408         c = new0(Connection, 1);
409         if (!c)
410                 return log_oom();
411
412         c->context = context;
413         c->server_fd = fd;
414         c->client_fd = -1;
415         c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
416         c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
417
418         r = set_put(context->connections, c);
419         if (r < 0) {
420                 free(c);
421                 return log_oom();
422         }
423
424         r = get_remote_sockaddr(&sa, &salen);
425         if (r < 0)
426                 goto fail;
427
428         c->client_fd = socket(sa.sa.sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
429         if (c->client_fd < 0) {
430                 log_error("Failed to get remote socket: %m");
431                 goto fail;
432         }
433
434         r = connect(c->client_fd, &sa.sa, salen);
435         if (r < 0) {
436                 if (errno == EINPROGRESS) {
437                         r = sd_event_add_io(event, c->client_fd, EPOLLOUT, connect_cb, c, &c->client_event_source);
438                         if (r < 0) {
439                                 log_error("Failed to add connection socket: %s", strerror(-r));
440                                 goto fail;
441                         }
442
443                         r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
444                         if (r < 0) {
445                                 log_error("Failed to enable oneshot event source: %s", strerror(-r));
446                                 goto fail;
447                         }
448                 } else {
449                         log_error("Failed to connect to remote host: %m");
450                         goto fail;
451                 }
452         } else {
453                 r = connection_enable_event_sources(c, event);
454                 if (r < 0)
455                         goto fail;
456         }
457
458         return 0;
459
460 fail:
461         connection_free(c);
462         return 0; /* ignore non-OOM errors, continue serving */
463 }
464
465 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
466         Context *context = userdata;
467         int nfd = -1, r;
468
469         assert(s);
470         assert(fd >= 0);
471         assert(revents & EPOLLIN);
472         assert(context);
473
474         nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
475         if (nfd >= 0) {
476                 _cleanup_free_ char *peer = NULL;
477
478                 getpeername_pretty(nfd, &peer);
479                 log_debug("New connection from %s", strna(peer));
480
481                 r = add_connection_socket(context, sd_event_source_get_event(s), nfd);
482                 if (r < 0) {
483                         close_nointr_nofail(fd);
484                         return r;
485                 }
486
487         } else if (errno != -EAGAIN)
488                 log_warning("Failed to accept() socket: %m");
489
490         r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
491         if (r < 0) {
492                 log_error("Error %d while re-enabling listener with ONESHOT: %s", r, strerror(-r));
493                 return r;
494         }
495
496         return 1;
497 }
498
499 static int add_listen_socket(Context *context, sd_event *event, int fd) {
500         sd_event_source *source;
501         int r;
502
503         assert(context);
504         assert(event);
505         assert(fd >= 0);
506
507         r = set_ensure_allocated(&context->listen, trivial_hash_func, trivial_compare_func);
508         if (r < 0) {
509                 log_oom();
510                 return r;
511         }
512
513         r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
514         if (r < 0) {
515                 log_error("Failed to determine socket type: %s", strerror(-r));
516                 return r;
517         }
518         if (r == 0) {
519                 log_error("Passed in socket is not a stream socket.");
520                 return -EINVAL;
521         }
522
523         r = fd_nonblock(fd, true);
524         if (r < 0) {
525                 log_error("Failed to mark file descriptor non-blocking: %s", strerror(-r));
526                 return r;
527         }
528
529         r = sd_event_add_io(event, fd, EPOLLIN, accept_cb, context, &source);
530         if (r < 0) {
531                 log_error("Failed to add event source: %s", strerror(-r));
532                 return r;
533         }
534
535         r = set_put(context->listen, source);
536         if (r < 0) {
537                 log_error("Failed to add source to set: %s", strerror(-r));
538                 sd_event_source_unref(source);
539                 return r;
540         }
541
542         /* Set the watcher to oneshot in case other processes are also
543          * watching to accept(). */
544         r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
545         if (r < 0) {
546                 log_error("Failed to enable oneshot mode: %s", strerror(-r));
547                 return r;
548         }
549
550         return 0;
551 }
552
553 static int help(void) {
554
555         printf("%s [HOST:PORT]\n"
556                "%s [SOCKET]\n\n"
557                "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
558                "  -l --listener=FD  Listen on a specific, single file descriptor.\n"
559                "  -h --help         Show this help\n"
560                "     --version      Show package version\n",
561                program_invocation_short_name,
562                program_invocation_short_name);
563
564         return 0;
565 }
566
567 static int parse_argv(int argc, char *argv[]) {
568
569         enum {
570                 ARG_VERSION = 0x100
571         };
572
573         static const struct option options[] = {
574                 { "help",     no_argument,       NULL, 'h'         },
575                 { "version",  no_argument,       NULL, ARG_VERSION },
576                 { "listener", required_argument, NULL, 'l'         },
577                 {}
578         };
579
580         int c, fd;
581
582         assert(argc >= 0);
583         assert(argv);
584
585         while ((c = getopt_long(argc, argv, "hl:", options, NULL)) >= 0) {
586
587                 switch (c) {
588
589                 case 'h':
590                         return help();
591
592                 case ARG_VERSION:
593                         puts(PACKAGE_STRING);
594                         puts(SYSTEMD_FEATURES);
595                         return 0;
596
597                 case 'l':
598                         if (safe_atoi(optarg, &fd) < 0) {
599                                 log_error("Failed to parse listener file descriptor: %s", optarg);
600                                 return -EINVAL;
601                         }
602                         if (fd < SD_LISTEN_FDS_START) {
603                                 log_error("Listener file descriptor must be at least %d.", SD_LISTEN_FDS_START);
604                                 return -EINVAL;
605                         }
606                         arg_listener = fd;
607                         break;
608
609                 case '?':
610                         return -EINVAL;
611
612                 default:
613                         assert_not_reached("Unhandled option");
614                 }
615         }
616
617         if (optind >= argc) {
618                 log_error("Not enough parameters.");
619                 return -EINVAL;
620         }
621
622         if (argc != optind+1) {
623                 log_error("Too many parameters.");
624                 return -EINVAL;
625         }
626
627         arg_remote_host = argv[optind];
628         return 1;
629 }
630
631 int main(int argc, char *argv[]) {
632         _cleanup_event_unref_ sd_event *event = NULL;
633         Context context = {};
634         int r, n, fd;
635
636         log_parse_environment();
637         log_open();
638
639         r = parse_argv(argc, argv);
640         if (r <= 0)
641                 goto finish;
642
643         r = sd_event_default(&event);
644         if (r < 0) {
645                 log_error("Failed to allocate event loop: %s", strerror(-r));
646                 goto finish;
647         }
648
649         if (arg_listener == -1) {
650                 n = sd_listen_fds(1);
651                 if (n < 0) {
652                         log_error("Failed to receive sockets from parent.");
653                         r = n;
654                         goto finish;
655                 } else if (n == 0) {
656                         log_error("Didn't get any sockets passed in.");
657                         r = -EINVAL;
658                         goto finish;
659                 }
660                 log_info("Listening on %d inherited socket(s), starting with fd=%d.", n, SD_LISTEN_FDS_START);
661                 for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
662                         r = add_listen_socket(&context, event, fd);
663                         if (r < 0)
664                                 goto finish;
665                 }
666         } else {
667                 log_info("Listening on single inherited socket fd=%d.", arg_listener);
668                 r = add_listen_socket(&context, event, arg_listener);
669                 if (r < 0)
670                         goto finish;
671         }
672
673         r = sd_event_loop(event);
674         if (r < 0) {
675                 log_error("Failed to run event loop: %s", strerror(-r));
676                 goto finish;
677         }
678
679 finish:
680         context_free(&context);
681
682         return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
683 }