chiark / gitweb /
api: in constructor function calls, always put the returned object pointer first...
[elogind.git] / src / socket-proxy / socket-proxyd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2013 David Strauss
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20  ***/
21
22 #include <arpa/inet.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <netdb.h>
29 #include <sys/fcntl.h>
30 #include <sys/socket.h>
31 #include <sys/un.h>
32 #include <unistd.h>
33
34 #include "sd-daemon.h"
35 #include "sd-event.h"
36 #include "log.h"
37 #include "socket-util.h"
38 #include "util.h"
39 #include "event-util.h"
40 #include "build.h"
41 #include "set.h"
42 #include "path-util.h"
43
44 #define BUFFER_SIZE (256 * 1024)
45 #define CONNECTIONS_MAX 256
46
47 #define _cleanup_freeaddrinfo_ _cleanup_(freeaddrinfop)
48 DEFINE_TRIVIAL_CLEANUP_FUNC(struct addrinfo *, freeaddrinfo);
49
50 typedef struct Context {
51         Set *listen;
52         Set *connections;
53 } Context;
54
55 typedef struct Connection {
56         Context *context;
57
58         int server_fd, client_fd;
59         int server_to_client_buffer[2]; /* a pipe */
60         int client_to_server_buffer[2]; /* a pipe */
61
62         size_t server_to_client_buffer_full, client_to_server_buffer_full;
63         size_t server_to_client_buffer_size, client_to_server_buffer_size;
64
65         sd_event_source *server_event_source, *client_event_source;
66 } Connection;
67
68 static const char *arg_remote_host = NULL;
69
70 static void connection_free(Connection *c) {
71         assert(c);
72
73         if (c->context)
74                 set_remove(c->context->connections, c);
75
76         sd_event_source_unref(c->server_event_source);
77         sd_event_source_unref(c->client_event_source);
78
79         if (c->server_fd >= 0)
80                 close_nointr_nofail(c->server_fd);
81         if (c->client_fd >= 0)
82                 close_nointr_nofail(c->client_fd);
83
84         close_pipe(c->server_to_client_buffer);
85         close_pipe(c->client_to_server_buffer);
86
87         free(c);
88 }
89
90 static void context_free(Context *context) {
91         sd_event_source *es;
92         Connection *c;
93
94         assert(context);
95
96         while ((es = set_steal_first(context->listen)))
97                 sd_event_source_unref(es);
98
99         while ((c = set_first(context->connections)))
100                 connection_free(c);
101
102         set_free(context->listen);
103         set_free(context->connections);
104 }
105
106 static int get_remote_sockaddr(union sockaddr_union *sa, socklen_t *salen) {
107         int r;
108
109         assert(sa);
110         assert(salen);
111
112         if (path_is_absolute(arg_remote_host)) {
113                 sa->un.sun_family = AF_UNIX;
114                 strncpy(sa->un.sun_path, arg_remote_host, sizeof(sa->un.sun_path)-1);
115                 sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0;
116
117                 *salen = offsetof(union sockaddr_union, un.sun_path) + strlen(sa->un.sun_path);
118
119         } else if (arg_remote_host[0] == '@') {
120                 sa->un.sun_family = AF_UNIX;
121                 sa->un.sun_path[0] = 0;
122                 strncpy(sa->un.sun_path+1, arg_remote_host+1, sizeof(sa->un.sun_path)-2);
123                 sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0;
124
125                 *salen = offsetof(union sockaddr_union, un.sun_path) + 1 + strlen(sa->un.sun_path + 1);
126
127         } else {
128                 _cleanup_freeaddrinfo_ struct addrinfo *result = NULL;
129                 const char *node, *service;
130
131                 struct addrinfo hints = {
132                         .ai_family = AF_UNSPEC,
133                         .ai_socktype = SOCK_STREAM,
134                         .ai_flags = AI_ADDRCONFIG
135                 };
136
137                 service = strrchr(arg_remote_host, ':');
138                 if (service) {
139                         node = strndupa(arg_remote_host, service - arg_remote_host);
140                         service ++;
141                 } else {
142                         node = arg_remote_host;
143                         service = "80";
144                 }
145
146                 log_debug("Looking up address info for %s:%s", node, service);
147                 r = getaddrinfo(node, service, &hints, &result);
148                 if (r != 0) {
149                         log_error("Failed to resolve host %s:%s: %s", node, service, gai_strerror(r));
150                         return -EHOSTUNREACH;
151                 }
152
153                 assert(result);
154                 if (result->ai_addrlen > sizeof(union sockaddr_union)) {
155                         log_error("Address too long.");
156                         return -E2BIG;
157                 }
158
159                 memcpy(sa, result->ai_addr, result->ai_addrlen);
160                 *salen = result->ai_addrlen;
161         }
162
163         return 0;
164 }
165
166 static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
167         int r;
168
169         assert(c);
170         assert(buffer);
171         assert(sz);
172
173         if (buffer[0] >= 0)
174                 return 0;
175
176         r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
177         if (r < 0) {
178                 log_error("Failed to allocate pipe buffer: %m");
179                 return -errno;
180         }
181
182         fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
183
184         r = fcntl(buffer[0], F_GETPIPE_SZ);
185         if (r < 0) {
186                 log_error("Failed to get pipe buffer size: %m");
187                 return -errno;
188         }
189
190         assert(r > 0);
191         *sz = r;
192
193         return 0;
194 }
195
196 static int connection_shovel(
197                 Connection *c,
198                 int *from, int buffer[2], int *to,
199                 size_t *full, size_t *sz,
200                 sd_event_source **from_source, sd_event_source **to_source) {
201
202         bool shoveled;
203
204         assert(c);
205         assert(from);
206         assert(buffer);
207         assert(buffer[0] >= 0);
208         assert(buffer[1] >= 0);
209         assert(to);
210         assert(full);
211         assert(sz);
212         assert(from_source);
213         assert(to_source);
214
215         do {
216                 ssize_t z;
217
218                 shoveled = false;
219
220                 if (*full < *sz && *from >= 0 && *to >= 0) {
221                         z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
222                         if (z > 0) {
223                                 *full += z;
224                                 shoveled = true;
225                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
226                                 *from_source = sd_event_source_unref(*from_source);
227                                 close_nointr_nofail(*from);
228                                 *from = -1;
229                         } else if (errno != EAGAIN && errno != EINTR) {
230                                 log_error("Failed to splice: %m");
231                                 return -errno;
232                         }
233                 }
234
235                 if (*full > 0 && *to >= 0) {
236                         z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
237                         if (z > 0) {
238                                 *full -= z;
239                                 shoveled = true;
240                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
241                                 *to_source = sd_event_source_unref(*to_source);
242                                 close_nointr_nofail(*to);
243                                 *to = -1;
244                         } else if (errno != EAGAIN && errno != EINTR) {
245                                 log_error("Failed to splice: %m");
246                                 return -errno;
247                         }
248                 }
249         } while (shoveled);
250
251         return 0;
252 }
253
254 static int connection_enable_event_sources(Connection *c, sd_event *event);
255
256 static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
257         Connection *c = userdata;
258         int r;
259
260         assert(s);
261         assert(fd >= 0);
262         assert(c);
263
264         r = connection_shovel(c,
265                               &c->server_fd, c->server_to_client_buffer, &c->client_fd,
266                               &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
267                               &c->server_event_source, &c->client_event_source);
268         if (r < 0)
269                 goto quit;
270
271         r = connection_shovel(c,
272                               &c->client_fd, c->client_to_server_buffer, &c->server_fd,
273                               &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
274                               &c->client_event_source, &c->server_event_source);
275         if (r < 0)
276                 goto quit;
277
278         /* EOF on both sides? */
279         if (c->server_fd == -1 && c->client_fd == -1)
280                 goto quit;
281
282         /* Server closed, and all data written to client? */
283         if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0)
284                 goto quit;
285
286         /* Client closed, and all data written to server? */
287         if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
288                 goto quit;
289
290         r = connection_enable_event_sources(c, sd_event_source_get_event(s));
291         if (r < 0)
292                 goto quit;
293
294         return 1;
295
296 quit:
297         connection_free(c);
298         return 0; /* ignore errors, continue serving */
299 }
300
301 static int connection_enable_event_sources(Connection *c, sd_event *event) {
302         uint32_t a = 0, b = 0;
303         int r;
304
305         assert(c);
306         assert(event);
307
308         if (c->server_to_client_buffer_full > 0)
309                 b |= EPOLLOUT;
310         if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
311                 a |= EPOLLIN;
312
313         if (c->client_to_server_buffer_full > 0)
314                 a |= EPOLLOUT;
315         if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
316                 b |= EPOLLIN;
317
318         if (c->server_event_source)
319                 r = sd_event_source_set_io_events(c->server_event_source, a);
320         else if (c->server_fd >= 0)
321                 r = sd_event_add_io(event, &c->server_event_source, c->server_fd, a, traffic_cb, c);
322         else
323                 r = 0;
324
325         if (r < 0) {
326                 log_error("Failed to set up server event source: %s", strerror(-r));
327                 return r;
328         }
329
330         if (c->client_event_source)
331                 r = sd_event_source_set_io_events(c->client_event_source, b);
332         else if (c->client_fd >= 0)
333                 r = sd_event_add_io(event, &c->client_event_source, c->client_fd, b, traffic_cb, c);
334         else
335                 r = 0;
336
337         if (r < 0) {
338                 log_error("Failed to set up client event source: %s", strerror(-r));
339                 return r;
340         }
341
342         return 0;
343 }
344
345 static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
346         Connection *c = userdata;
347         socklen_t solen;
348         int error, r;
349
350         assert(s);
351         assert(fd >= 0);
352         assert(c);
353
354         solen = sizeof(error);
355         r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
356         if (r < 0) {
357                 log_error("Failed to issue SO_ERROR: %m");
358                 goto fail;
359         }
360
361         if (error != 0) {
362                 log_error("Failed to connect to remote host: %s", strerror(error));
363                 goto fail;
364         }
365
366         c->client_event_source = sd_event_source_unref(c->client_event_source);
367
368         r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
369         if (r < 0)
370                 goto fail;
371
372         r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
373         if (r < 0)
374                 goto fail;
375
376         r = connection_enable_event_sources(c, sd_event_source_get_event(s));
377         if (r < 0)
378                 goto fail;
379
380         return 0;
381
382 fail:
383         connection_free(c);
384         return 0; /* ignore errors, continue serving */
385 }
386
387 static int add_connection_socket(Context *context, sd_event *event, int fd) {
388         union sockaddr_union sa = {};
389         socklen_t salen;
390         Connection *c;
391         int r;
392
393         assert(context);
394         assert(event);
395         assert(fd >= 0);
396
397         if (set_size(context->connections) > CONNECTIONS_MAX) {
398                 log_warning("Hit connection limit, refusing connection.");
399                 close_nointr_nofail(fd);
400                 return 0;
401         }
402
403         r = set_ensure_allocated(&context->connections, trivial_hash_func, trivial_compare_func);
404         if (r < 0)
405                 return log_oom();
406
407         c = new0(Connection, 1);
408         if (!c)
409                 return log_oom();
410
411         c->context = context;
412         c->server_fd = fd;
413         c->client_fd = -1;
414         c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
415         c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
416
417         r = set_put(context->connections, c);
418         if (r < 0) {
419                 free(c);
420                 return log_oom();
421         }
422
423         r = get_remote_sockaddr(&sa, &salen);
424         if (r < 0)
425                 goto fail;
426
427         c->client_fd = socket(sa.sa.sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
428         if (c->client_fd < 0) {
429                 log_error("Failed to get remote socket: %m");
430                 goto fail;
431         }
432
433         r = connect(c->client_fd, &sa.sa, salen);
434         if (r < 0) {
435                 if (errno == EINPROGRESS) {
436                         r = sd_event_add_io(event, &c->client_event_source, c->client_fd, EPOLLOUT, connect_cb, c);
437                         if (r < 0) {
438                                 log_error("Failed to add connection socket: %s", strerror(-r));
439                                 goto fail;
440                         }
441
442                         r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
443                         if (r < 0) {
444                                 log_error("Failed to enable oneshot event source: %s", strerror(-r));
445                                 goto fail;
446                         }
447                 } else {
448                         log_error("Failed to connect to remote host: %m");
449                         goto fail;
450                 }
451         } else {
452                 r = connection_enable_event_sources(c, event);
453                 if (r < 0)
454                         goto fail;
455         }
456
457         return 0;
458
459 fail:
460         connection_free(c);
461         return 0; /* ignore non-OOM errors, continue serving */
462 }
463
464 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
465         _cleanup_free_ char *peer = NULL;
466         Context *context = userdata;
467         int nfd = -1, r;
468
469         assert(s);
470         assert(fd >= 0);
471         assert(revents & EPOLLIN);
472         assert(context);
473
474         nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
475         if (nfd < 0) {
476                 if (errno != -EAGAIN)
477                         log_warning("Failed to accept() socket: %m");
478         } else {
479                 getpeername_pretty(nfd, &peer);
480                 log_debug("New connection from %s", strna(peer));
481
482                 r = add_connection_socket(context, sd_event_source_get_event(s), nfd);
483                 if (r < 0) {
484                         log_error("Failed to accept connection, ignoring: %s", strerror(-r));
485                         close_nointr_nofail(fd);
486                 }
487         }
488
489         r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
490         if (r < 0) {
491                 log_error("Error while re-enabling listener with ONESHOT: %s", strerror(-r));
492                 sd_event_exit(sd_event_source_get_event(s), r);
493                 return r;
494         }
495
496         return 1;
497 }
498
499 static int add_listen_socket(Context *context, sd_event *event, int fd) {
500         sd_event_source *source;
501         int r;
502
503         assert(context);
504         assert(event);
505         assert(fd >= 0);
506
507         r = set_ensure_allocated(&context->listen, trivial_hash_func, trivial_compare_func);
508         if (r < 0) {
509                 log_oom();
510                 return r;
511         }
512
513         r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
514         if (r < 0) {
515                 log_error("Failed to determine socket type: %s", strerror(-r));
516                 return r;
517         }
518         if (r == 0) {
519                 log_error("Passed in socket is not a stream socket.");
520                 return -EINVAL;
521         }
522
523         r = fd_nonblock(fd, true);
524         if (r < 0) {
525                 log_error("Failed to mark file descriptor non-blocking: %s", strerror(-r));
526                 return r;
527         }
528
529         r = sd_event_add_io(event, &source, fd, EPOLLIN, accept_cb, context);
530         if (r < 0) {
531                 log_error("Failed to add event source: %s", strerror(-r));
532                 return r;
533         }
534
535         r = set_put(context->listen, source);
536         if (r < 0) {
537                 log_error("Failed to add source to set: %s", strerror(-r));
538                 sd_event_source_unref(source);
539                 return r;
540         }
541
542         /* Set the watcher to oneshot in case other processes are also
543          * watching to accept(). */
544         r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
545         if (r < 0) {
546                 log_error("Failed to enable oneshot mode: %s", strerror(-r));
547                 return r;
548         }
549
550         return 0;
551 }
552
553 static int help(void) {
554
555         printf("%s [HOST:PORT]\n"
556                "%s [SOCKET]\n\n"
557                "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
558                "  -h --help              Show this help\n"
559                "     --version           Show package version\n",
560                program_invocation_short_name,
561                program_invocation_short_name);
562
563         return 0;
564 }
565
566 static int parse_argv(int argc, char *argv[]) {
567
568         enum {
569                 ARG_VERSION = 0x100,
570                 ARG_IGNORE_ENV
571         };
572
573         static const struct option options[] = {
574                 { "help",       no_argument, NULL, 'h'           },
575                 { "version",    no_argument, NULL, ARG_VERSION   },
576                 {}
577         };
578
579         int c;
580
581         assert(argc >= 0);
582         assert(argv);
583
584         while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0) {
585
586                 switch (c) {
587
588                 case 'h':
589                         return help();
590
591                 case ARG_VERSION:
592                         puts(PACKAGE_STRING);
593                         puts(SYSTEMD_FEATURES);
594                         return 0;
595
596                 case '?':
597                         return -EINVAL;
598
599                 default:
600                         assert_not_reached("Unhandled option");
601                 }
602         }
603
604         if (optind >= argc) {
605                 log_error("Not enough parameters.");
606                 return -EINVAL;
607         }
608
609         if (argc != optind+1) {
610                 log_error("Too many parameters.");
611                 return -EINVAL;
612         }
613
614         arg_remote_host = argv[optind];
615         return 1;
616 }
617
618 int main(int argc, char *argv[]) {
619         _cleanup_event_unref_ sd_event *event = NULL;
620         Context context = {};
621         int r, n, fd;
622
623         log_parse_environment();
624         log_open();
625
626         r = parse_argv(argc, argv);
627         if (r <= 0)
628                 goto finish;
629
630         r = sd_event_default(&event);
631         if (r < 0) {
632                 log_error("Failed to allocate event loop: %s", strerror(-r));
633                 goto finish;
634         }
635
636         sd_event_set_watchdog(event, true);
637
638         n = sd_listen_fds(1);
639         if (n < 0) {
640                 log_error("Failed to receive sockets from parent.");
641                 r = n;
642                 goto finish;
643         } else if (n == 0) {
644                 log_error("Didn't get any sockets passed in.");
645                 r = -EINVAL;
646                 goto finish;
647         }
648
649         for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
650                 r = add_listen_socket(&context, event, fd);
651                 if (r < 0)
652                         goto finish;
653         }
654
655         r = sd_event_loop(event);
656         if (r < 0) {
657                 log_error("Failed to run event loop: %s", strerror(-r));
658                 goto finish;
659         }
660
661 finish:
662         context_free(&context);
663
664         return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
665 }