chiark / gitweb /
bus: introduce concept of a "default" event loop per-thread and make use of it everywhere
[elogind.git] / src / socket-proxy / socket-proxyd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2013 David Strauss
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20  ***/
21
22 #include <arpa/inet.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <netdb.h>
29 #include <sys/fcntl.h>
30 #include <sys/socket.h>
31 #include <sys/un.h>
32 #include <unistd.h>
33
34 #include "sd-daemon.h"
35 #include "sd-event.h"
36 #include "log.h"
37 #include "socket-util.h"
38 #include "util.h"
39 #include "event-util.h"
40 #include "build.h"
41 #include "set.h"
42 #include "path-util.h"
43
44 #define BUFFER_SIZE (256 * 1024)
45 #define CONNECTIONS_MAX 256
46
47 #define _cleanup_freeaddrinfo_ _cleanup_(freeaddrinfop)
48 DEFINE_TRIVIAL_CLEANUP_FUNC(struct addrinfo *, freeaddrinfo);
49
50 typedef struct Context {
51         Set *listen;
52         Set *connections;
53 } Context;
54
55 typedef struct Connection {
56         Context *context;
57
58         int server_fd, client_fd;
59         int server_to_client_buffer[2]; /* a pipe */
60         int client_to_server_buffer[2]; /* a pipe */
61
62         size_t server_to_client_buffer_full, client_to_server_buffer_full;
63         size_t server_to_client_buffer_size, client_to_server_buffer_size;
64
65         sd_event_source *server_event_source, *client_event_source;
66 } Connection;
67
68 static const char *arg_remote_host = NULL;
69
70 static void connection_free(Connection *c) {
71         assert(c);
72
73         if (c->context)
74                 set_remove(c->context->connections, c);
75
76         sd_event_source_unref(c->server_event_source);
77         sd_event_source_unref(c->client_event_source);
78
79         if (c->server_fd >= 0)
80                 close_nointr_nofail(c->server_fd);
81         if (c->client_fd >= 0)
82                 close_nointr_nofail(c->client_fd);
83
84         close_pipe(c->server_to_client_buffer);
85         close_pipe(c->client_to_server_buffer);
86
87         free(c);
88 }
89
90 static void context_free(Context *context) {
91         sd_event_source *es;
92         Connection *c;
93
94         assert(context);
95
96         while ((es = set_steal_first(context->listen)))
97                 sd_event_source_unref(es);
98
99         while ((c = set_first(context->connections)))
100                 connection_free(c);
101
102         set_free(context->listen);
103         set_free(context->connections);
104 }
105
106 static int get_remote_sockaddr(union sockaddr_union *sa, socklen_t *salen) {
107         int r;
108
109         assert(sa);
110         assert(salen);
111
112         if (path_is_absolute(arg_remote_host)) {
113                 sa->un.sun_family = AF_UNIX;
114                 strncpy(sa->un.sun_path, arg_remote_host, sizeof(sa->un.sun_path)-1);
115                 sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0;
116
117                 *salen = offsetof(union sockaddr_union, un.sun_path) + strlen(sa->un.sun_path);
118
119         } else if (arg_remote_host[0] == '@') {
120                 sa->un.sun_family = AF_UNIX;
121                 sa->un.sun_path[0] = 0;
122                 strncpy(sa->un.sun_path+1, arg_remote_host+1, sizeof(sa->un.sun_path)-2);
123                 sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0;
124
125                 *salen = offsetof(union sockaddr_union, un.sun_path) + 1 + strlen(sa->un.sun_path + 1);
126
127         } else {
128                 _cleanup_freeaddrinfo_ struct addrinfo *result = NULL;
129                 const char *node, *service;
130
131                 struct addrinfo hints = {
132                         .ai_family = AF_UNSPEC,
133                         .ai_socktype = SOCK_STREAM,
134                         .ai_flags = AI_ADDRCONFIG
135                 };
136
137                 service = strrchr(arg_remote_host, ':');
138                 if (service) {
139                         node = strndupa(arg_remote_host, service - arg_remote_host);
140                         service ++;
141                 } else {
142                         node = arg_remote_host;
143                         service = "80";
144                 }
145
146                 log_debug("Looking up address info for %s:%s", node, service);
147                 r = getaddrinfo(node, service, &hints, &result);
148                 if (r != 0) {
149                         log_error("Failed to resolve host %s:%s: %s", node, service, gai_strerror(r));
150                         return -EHOSTUNREACH;
151                 }
152
153                 assert(result);
154                 if (result->ai_addrlen > sizeof(union sockaddr_union)) {
155                         log_error("Address too long.");
156                         return -E2BIG;
157                 }
158
159                 memcpy(sa, result->ai_addr, result->ai_addrlen);
160                 *salen = result->ai_addrlen;
161         }
162
163         return 0;
164 }
165
166 static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
167         int r;
168
169         assert(c);
170         assert(buffer);
171         assert(sz);
172
173         if (buffer[0] >= 0)
174                 return 0;
175
176         r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
177         if (r < 0) {
178                 log_error("Failed to allocate pipe buffer: %m");
179                 return -errno;
180         }
181
182         fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
183
184         r = fcntl(buffer[0], F_GETPIPE_SZ);
185         if (r < 0) {
186                 log_error("Failed to get pipe buffer size: %m");
187                 return -errno;
188         }
189
190         assert(r > 0);
191         *sz = r;
192
193         return 0;
194 }
195
196 static int connection_shovel(
197                 Connection *c,
198                 int *from, int buffer[2], int *to,
199                 size_t *full, size_t *sz,
200                 sd_event_source **from_source, sd_event_source **to_source) {
201
202         bool shoveled;
203
204         assert(c);
205         assert(from);
206         assert(buffer);
207         assert(buffer[0] >= 0);
208         assert(buffer[1] >= 0);
209         assert(to);
210         assert(full);
211         assert(sz);
212         assert(from_source);
213         assert(to_source);
214
215         do {
216                 ssize_t z;
217
218                 shoveled = false;
219
220                 if (*full < *sz && *from >= 0 && *to >= 0) {
221                         z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
222                         if (z > 0) {
223                                 *full += z;
224                                 shoveled = true;
225                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
226                                 *from_source = sd_event_source_unref(*from_source);
227                                 close_nointr_nofail(*from);
228                                 *from = -1;
229                         } else if (errno != EAGAIN && errno != EINTR) {
230                                 log_error("Failed to splice: %m");
231                                 return -errno;
232                         }
233                 }
234
235                 if (*full > 0 && *to >= 0) {
236                         z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
237                         if (z > 0) {
238                                 *full -= z;
239                                 shoveled = true;
240                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
241                                 *to_source = sd_event_source_unref(*to_source);
242                                 close_nointr_nofail(*to);
243                                 *to = -1;
244                         } else if (errno != EAGAIN && errno != EINTR) {
245                                 log_error("Failed to splice: %m");
246                                 return -errno;
247                         }
248                 }
249         } while (shoveled);
250
251         return 0;
252 }
253
254 static int connection_enable_event_sources(Connection *c, sd_event *event);
255
256 static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
257         Connection *c = userdata;
258         int r;
259
260         assert(s);
261         assert(fd >= 0);
262         assert(c);
263
264         r = connection_shovel(c,
265                               &c->server_fd, c->server_to_client_buffer, &c->client_fd,
266                               &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
267                               &c->server_event_source, &c->client_event_source);
268         if (r < 0)
269                 goto quit;
270
271         r = connection_shovel(c,
272                               &c->client_fd, c->client_to_server_buffer, &c->server_fd,
273                               &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
274                               &c->client_event_source, &c->server_event_source);
275         if (r < 0)
276                 goto quit;
277
278         /* EOF on both sides? */
279         if (c->server_fd == -1 && c->client_fd == -1)
280                 goto quit;
281
282         /* Server closed, and all data written to client? */
283         if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0)
284                 goto quit;
285
286         /* Client closed, and all data written to server? */
287         if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
288                 goto quit;
289
290         r = connection_enable_event_sources(c, sd_event_get(s));
291         if (r < 0)
292                 goto quit;
293
294         return 1;
295
296 quit:
297         connection_free(c);
298         return 0; /* ignore errors, continue serving */
299 }
300
301 static int connection_enable_event_sources(Connection *c, sd_event *event) {
302         uint32_t a = 0, b = 0;
303         int r;
304
305         assert(c);
306         assert(event);
307
308         if (c->server_to_client_buffer_full > 0)
309                 b |= EPOLLOUT;
310         if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
311                 a |= EPOLLIN;
312
313         if (c->client_to_server_buffer_full > 0)
314                 a |= EPOLLOUT;
315         if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
316                 b |= EPOLLIN;
317
318         if (c->server_event_source)
319                 r = sd_event_source_set_io_events(c->server_event_source, a);
320         else if (c->server_fd >= 0)
321                 r = sd_event_add_io(event, c->server_fd, a, traffic_cb, c, &c->server_event_source);
322         else
323                 r = 0;
324
325         if (r < 0) {
326                 log_error("Failed to set up server event source: %s", strerror(-r));
327                 return r;
328         }
329
330         if (c->client_event_source)
331                 r = sd_event_source_set_io_events(c->client_event_source, b);
332         else if (c->client_fd >= 0)
333                 r = sd_event_add_io(event, c->client_fd, b, traffic_cb, c, &c->client_event_source);
334         else
335                 r = 0;
336
337         if (r < 0) {
338                 log_error("Failed to set up client event source: %s", strerror(-r));
339                 return r;
340         }
341
342         return 0;
343 }
344
345 static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
346         Connection *c = userdata;
347         socklen_t solen;
348         int error, r;
349
350         assert(s);
351         assert(fd >= 0);
352         assert(c);
353
354         solen = sizeof(error);
355         r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
356         if (r < 0) {
357                 log_error("Failed to issue SO_ERROR: %m");
358                 goto fail;
359         }
360
361         if (error != 0) {
362                 log_error("Failed to connect to remote host: %s", strerror(error));
363                 goto fail;
364         }
365
366         c->client_event_source = sd_event_source_unref(c->client_event_source);
367
368         r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
369         if (r < 0)
370                 goto fail;
371
372         r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
373         if (r < 0)
374                 goto fail;
375
376         r = connection_enable_event_sources(c, sd_event_get(s));
377         if (r < 0)
378                 goto fail;
379
380         return 0;
381
382 fail:
383         connection_free(c);
384         return 0; /* ignore errors, continue serving */
385 }
386
387 static int add_connection_socket(Context *context, sd_event *event, int fd) {
388         union sockaddr_union sa = {};
389         socklen_t salen;
390         Connection *c;
391         int r;
392
393         assert(context);
394         assert(event);
395         assert(fd >= 0);
396
397         if (set_size(context->connections) > CONNECTIONS_MAX) {
398                 log_warning("Hit connection limit, refusing connection.");
399                 close_nointr_nofail(fd);
400                 return 0;
401         }
402
403         r = set_ensure_allocated(&context->connections, trivial_hash_func, trivial_compare_func);
404         if (r < 0)
405                 return log_oom();
406
407         c = new0(Connection, 1);
408         if (!c)
409                 return log_oom();
410
411         c->context = context;
412         c->server_fd = fd;
413         c->client_fd = -1;
414         c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
415         c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
416
417         r = set_put(context->connections, c);
418         if (r < 0) {
419                 free(c);
420                 return log_oom();
421         }
422
423         r = get_remote_sockaddr(&sa, &salen);
424         if (r < 0)
425                 goto fail;
426
427         c->client_fd = socket(sa.sa.sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
428         if (c->client_fd < 0) {
429                 log_error("Failed to get remote socket: %m");
430                 goto fail;
431         }
432
433         r = connect(c->client_fd, &sa.sa, salen);
434         if (r < 0) {
435                 if (errno == EINPROGRESS) {
436                         r = sd_event_add_io(event, c->client_fd, EPOLLOUT, connect_cb, c, &c->client_event_source);
437                         if (r < 0) {
438                                 log_error("Failed to add connection socket: %s", strerror(-r));
439                                 goto fail;
440                         }
441
442                         r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
443                         if (r < 0) {
444                                 log_error("Failed to enable oneshot event source: %s", strerror(-r));
445                                 goto fail;
446                         }
447                 } else {
448                         log_error("Failed to connect to remote host: %m");
449                         goto fail;
450                 }
451         } else {
452                 r = connection_enable_event_sources(c, event);
453                 if (r < 0)
454                         goto fail;
455         }
456
457         return 0;
458
459 fail:
460         connection_free(c);
461         return 0; /* ignore non-OOM errors, continue serving */
462 }
463
464 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
465         Context *context = userdata;
466         int nfd = -1, r;
467
468         assert(s);
469         assert(fd >= 0);
470         assert(revents & EPOLLIN);
471         assert(context);
472
473         nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
474         if (nfd >= 0) {
475                 _cleanup_free_ char *peer = NULL;
476
477                 getpeername_pretty(nfd, &peer);
478                 log_debug("New connection from %s", strna(peer));
479
480                 r = add_connection_socket(context, sd_event_get(s), nfd);
481                 if (r < 0) {
482                         close_nointr_nofail(fd);
483                         return r;
484                 }
485
486         } else if (errno != -EAGAIN)
487                 log_warning("Failed to accept() socket: %m");
488
489         r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
490         if (r < 0) {
491                 log_error("Error %d while re-enabling listener with ONESHOT: %s", r, strerror(-r));
492                 return r;
493         }
494
495         return 1;
496 }
497
498 static int add_listen_socket(Context *context, sd_event *event, int fd) {
499         sd_event_source *source;
500         int r;
501
502         assert(context);
503         assert(event);
504         assert(fd >= 0);
505
506         r = set_ensure_allocated(&context->listen, trivial_hash_func, trivial_compare_func);
507         if (r < 0) {
508                 log_oom();
509                 return r;
510         }
511
512         r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
513         if (r < 0) {
514                 log_error("Failed to determine socket type: %s", strerror(-r));
515                 return r;
516         }
517         if (r == 0) {
518                 log_error("Passed in socket is not a stream socket.");
519                 return -EINVAL;
520         }
521
522         r = fd_nonblock(fd, true);
523         if (r < 0) {
524                 log_error("Failed to mark file descriptor non-blocking: %s", strerror(-r));
525                 return r;
526         }
527
528         r = sd_event_add_io(event, fd, EPOLLIN, accept_cb, context, &source);
529         if (r < 0) {
530                 log_error("Failed to add event source: %s", strerror(-r));
531                 return r;
532         }
533
534         r = set_put(context->listen, source);
535         if (r < 0) {
536                 log_error("Failed to add source to set: %s", strerror(-r));
537                 sd_event_source_unref(source);
538                 return r;
539         }
540
541         /* Set the watcher to oneshot in case other processes are also
542          * watching to accept(). */
543         r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
544         if (r < 0) {
545                 log_error("Failed to enable oneshot mode: %s", strerror(-r));
546                 return r;
547         }
548
549         return 0;
550 }
551
552 static int help(void) {
553
554         printf("%s [HOST:PORT]\n"
555                "%s [SOCKET]\n\n"
556                "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
557                "  -h --help              Show this help\n"
558                "     --version           Show package version\n",
559                program_invocation_short_name,
560                program_invocation_short_name);
561
562         return 0;
563 }
564
565 static int parse_argv(int argc, char *argv[]) {
566
567         enum {
568                 ARG_VERSION = 0x100,
569                 ARG_IGNORE_ENV
570         };
571
572         static const struct option options[] = {
573                 { "help",       no_argument, NULL, 'h'           },
574                 { "version",    no_argument, NULL, ARG_VERSION   },
575                 {}
576         };
577
578         int c;
579
580         assert(argc >= 0);
581         assert(argv);
582
583         while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0) {
584
585                 switch (c) {
586
587                 case 'h':
588                         return help();
589
590                 case ARG_VERSION:
591                         puts(PACKAGE_STRING);
592                         puts(SYSTEMD_FEATURES);
593                         return 0;
594
595                 case '?':
596                         return -EINVAL;
597
598                 default:
599                         assert_not_reached("Unhandled option");
600                 }
601         }
602
603         if (optind >= argc) {
604                 log_error("Not enough parameters.");
605                 return -EINVAL;
606         }
607
608         if (argc != optind+1) {
609                 log_error("Too many parameters.");
610                 return -EINVAL;
611         }
612
613         arg_remote_host = argv[optind];
614         return 1;
615 }
616
617 int main(int argc, char *argv[]) {
618         _cleanup_event_unref_ sd_event *event = NULL;
619         Context context = {};
620         int r, n, fd;
621
622         log_parse_environment();
623         log_open();
624
625         r = parse_argv(argc, argv);
626         if (r <= 0)
627                 goto finish;
628
629         r = sd_event_default(&event);
630         if (r < 0) {
631                 log_error("Failed to allocate event loop: %s", strerror(-r));
632                 goto finish;
633         }
634
635         n = sd_listen_fds(1);
636         if (n < 0) {
637                 log_error("Failed to receive sockets from parent.");
638                 r = n;
639                 goto finish;
640         } else if (n == 0) {
641                 log_error("Didn't get any sockets passed in.");
642                 r = -EINVAL;
643                 goto finish;
644         }
645
646         for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
647                 r = add_listen_socket(&context, event, fd);
648                 if (r < 0)
649                         goto finish;
650         }
651
652         r = sd_event_loop(event);
653         if (r < 0) {
654                 log_error("Failed to run event loop: %s", strerror(-r));
655                 goto finish;
656         }
657
658 finish:
659         context_free(&context);
660
661         return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
662 }