chiark / gitweb /
util: replace close_nointr_nofail() by a more useful safe_close()
[elogind.git] / src / socket-proxy / socket-proxyd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2013 David Strauss
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20  ***/
21
22 #include <arpa/inet.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <netdb.h>
29 #include <sys/fcntl.h>
30 #include <sys/socket.h>
31 #include <sys/un.h>
32 #include <unistd.h>
33
34 #include "sd-daemon.h"
35 #include "sd-event.h"
36 #include "log.h"
37 #include "socket-util.h"
38 #include "util.h"
39 #include "event-util.h"
40 #include "build.h"
41 #include "set.h"
42 #include "path-util.h"
43
44 #define BUFFER_SIZE (256 * 1024)
45 #define CONNECTIONS_MAX 256
46
47 #define _cleanup_freeaddrinfo_ _cleanup_(freeaddrinfop)
48 DEFINE_TRIVIAL_CLEANUP_FUNC(struct addrinfo *, freeaddrinfo);
49
50 typedef struct Context {
51         Set *listen;
52         Set *connections;
53 } Context;
54
55 typedef struct Connection {
56         Context *context;
57
58         int server_fd, client_fd;
59         int server_to_client_buffer[2]; /* a pipe */
60         int client_to_server_buffer[2]; /* a pipe */
61
62         size_t server_to_client_buffer_full, client_to_server_buffer_full;
63         size_t server_to_client_buffer_size, client_to_server_buffer_size;
64
65         sd_event_source *server_event_source, *client_event_source;
66 } Connection;
67
68 static const char *arg_remote_host = NULL;
69
70 static void connection_free(Connection *c) {
71         assert(c);
72
73         if (c->context)
74                 set_remove(c->context->connections, c);
75
76         sd_event_source_unref(c->server_event_source);
77         sd_event_source_unref(c->client_event_source);
78
79         safe_close(c->server_fd);
80         safe_close(c->client_fd);
81
82         close_pipe(c->server_to_client_buffer);
83         close_pipe(c->client_to_server_buffer);
84
85         free(c);
86 }
87
88 static void context_free(Context *context) {
89         sd_event_source *es;
90         Connection *c;
91
92         assert(context);
93
94         while ((es = set_steal_first(context->listen)))
95                 sd_event_source_unref(es);
96
97         while ((c = set_first(context->connections)))
98                 connection_free(c);
99
100         set_free(context->listen);
101         set_free(context->connections);
102 }
103
104 static int get_remote_sockaddr(union sockaddr_union *sa, socklen_t *salen) {
105         int r;
106
107         assert(sa);
108         assert(salen);
109
110         if (path_is_absolute(arg_remote_host)) {
111                 sa->un.sun_family = AF_UNIX;
112                 strncpy(sa->un.sun_path, arg_remote_host, sizeof(sa->un.sun_path)-1);
113                 sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0;
114
115                 *salen = offsetof(union sockaddr_union, un.sun_path) + strlen(sa->un.sun_path);
116
117         } else if (arg_remote_host[0] == '@') {
118                 sa->un.sun_family = AF_UNIX;
119                 sa->un.sun_path[0] = 0;
120                 strncpy(sa->un.sun_path+1, arg_remote_host+1, sizeof(sa->un.sun_path)-2);
121                 sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0;
122
123                 *salen = offsetof(union sockaddr_union, un.sun_path) + 1 + strlen(sa->un.sun_path + 1);
124
125         } else {
126                 _cleanup_freeaddrinfo_ struct addrinfo *result = NULL;
127                 const char *node, *service;
128
129                 struct addrinfo hints = {
130                         .ai_family = AF_UNSPEC,
131                         .ai_socktype = SOCK_STREAM,
132                         .ai_flags = AI_ADDRCONFIG
133                 };
134
135                 service = strrchr(arg_remote_host, ':');
136                 if (service) {
137                         node = strndupa(arg_remote_host, service - arg_remote_host);
138                         service ++;
139                 } else {
140                         node = arg_remote_host;
141                         service = "80";
142                 }
143
144                 log_debug("Looking up address info for %s:%s", node, service);
145                 r = getaddrinfo(node, service, &hints, &result);
146                 if (r != 0) {
147                         log_error("Failed to resolve host %s:%s: %s", node, service, gai_strerror(r));
148                         return -EHOSTUNREACH;
149                 }
150
151                 assert(result);
152                 if (result->ai_addrlen > sizeof(union sockaddr_union)) {
153                         log_error("Address too long.");
154                         return -E2BIG;
155                 }
156
157                 memcpy(sa, result->ai_addr, result->ai_addrlen);
158                 *salen = result->ai_addrlen;
159         }
160
161         return 0;
162 }
163
164 static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
165         int r;
166
167         assert(c);
168         assert(buffer);
169         assert(sz);
170
171         if (buffer[0] >= 0)
172                 return 0;
173
174         r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
175         if (r < 0) {
176                 log_error("Failed to allocate pipe buffer: %m");
177                 return -errno;
178         }
179
180         fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
181
182         r = fcntl(buffer[0], F_GETPIPE_SZ);
183         if (r < 0) {
184                 log_error("Failed to get pipe buffer size: %m");
185                 return -errno;
186         }
187
188         assert(r > 0);
189         *sz = r;
190
191         return 0;
192 }
193
194 static int connection_shovel(
195                 Connection *c,
196                 int *from, int buffer[2], int *to,
197                 size_t *full, size_t *sz,
198                 sd_event_source **from_source, sd_event_source **to_source) {
199
200         bool shoveled;
201
202         assert(c);
203         assert(from);
204         assert(buffer);
205         assert(buffer[0] >= 0);
206         assert(buffer[1] >= 0);
207         assert(to);
208         assert(full);
209         assert(sz);
210         assert(from_source);
211         assert(to_source);
212
213         do {
214                 ssize_t z;
215
216                 shoveled = false;
217
218                 if (*full < *sz && *from >= 0 && *to >= 0) {
219                         z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
220                         if (z > 0) {
221                                 *full += z;
222                                 shoveled = true;
223                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
224                                 *from_source = sd_event_source_unref(*from_source);
225                                 *from = safe_close(*from);
226                         } else if (errno != EAGAIN && errno != EINTR) {
227                                 log_error("Failed to splice: %m");
228                                 return -errno;
229                         }
230                 }
231
232                 if (*full > 0 && *to >= 0) {
233                         z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
234                         if (z > 0) {
235                                 *full -= z;
236                                 shoveled = true;
237                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
238                                 *to_source = sd_event_source_unref(*to_source);
239                                 *to = safe_close(*to);
240                         } else if (errno != EAGAIN && errno != EINTR) {
241                                 log_error("Failed to splice: %m");
242                                 return -errno;
243                         }
244                 }
245         } while (shoveled);
246
247         return 0;
248 }
249
250 static int connection_enable_event_sources(Connection *c, sd_event *event);
251
252 static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
253         Connection *c = userdata;
254         int r;
255
256         assert(s);
257         assert(fd >= 0);
258         assert(c);
259
260         r = connection_shovel(c,
261                               &c->server_fd, c->server_to_client_buffer, &c->client_fd,
262                               &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
263                               &c->server_event_source, &c->client_event_source);
264         if (r < 0)
265                 goto quit;
266
267         r = connection_shovel(c,
268                               &c->client_fd, c->client_to_server_buffer, &c->server_fd,
269                               &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
270                               &c->client_event_source, &c->server_event_source);
271         if (r < 0)
272                 goto quit;
273
274         /* EOF on both sides? */
275         if (c->server_fd == -1 && c->client_fd == -1)
276                 goto quit;
277
278         /* Server closed, and all data written to client? */
279         if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0)
280                 goto quit;
281
282         /* Client closed, and all data written to server? */
283         if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
284                 goto quit;
285
286         r = connection_enable_event_sources(c, sd_event_source_get_event(s));
287         if (r < 0)
288                 goto quit;
289
290         return 1;
291
292 quit:
293         connection_free(c);
294         return 0; /* ignore errors, continue serving */
295 }
296
297 static int connection_enable_event_sources(Connection *c, sd_event *event) {
298         uint32_t a = 0, b = 0;
299         int r;
300
301         assert(c);
302         assert(event);
303
304         if (c->server_to_client_buffer_full > 0)
305                 b |= EPOLLOUT;
306         if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
307                 a |= EPOLLIN;
308
309         if (c->client_to_server_buffer_full > 0)
310                 a |= EPOLLOUT;
311         if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
312                 b |= EPOLLIN;
313
314         if (c->server_event_source)
315                 r = sd_event_source_set_io_events(c->server_event_source, a);
316         else if (c->server_fd >= 0)
317                 r = sd_event_add_io(event, &c->server_event_source, c->server_fd, a, traffic_cb, c);
318         else
319                 r = 0;
320
321         if (r < 0) {
322                 log_error("Failed to set up server event source: %s", strerror(-r));
323                 return r;
324         }
325
326         if (c->client_event_source)
327                 r = sd_event_source_set_io_events(c->client_event_source, b);
328         else if (c->client_fd >= 0)
329                 r = sd_event_add_io(event, &c->client_event_source, c->client_fd, b, traffic_cb, c);
330         else
331                 r = 0;
332
333         if (r < 0) {
334                 log_error("Failed to set up client event source: %s", strerror(-r));
335                 return r;
336         }
337
338         return 0;
339 }
340
341 static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
342         Connection *c = userdata;
343         socklen_t solen;
344         int error, r;
345
346         assert(s);
347         assert(fd >= 0);
348         assert(c);
349
350         solen = sizeof(error);
351         r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
352         if (r < 0) {
353                 log_error("Failed to issue SO_ERROR: %m");
354                 goto fail;
355         }
356
357         if (error != 0) {
358                 log_error("Failed to connect to remote host: %s", strerror(error));
359                 goto fail;
360         }
361
362         c->client_event_source = sd_event_source_unref(c->client_event_source);
363
364         r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
365         if (r < 0)
366                 goto fail;
367
368         r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
369         if (r < 0)
370                 goto fail;
371
372         r = connection_enable_event_sources(c, sd_event_source_get_event(s));
373         if (r < 0)
374                 goto fail;
375
376         return 0;
377
378 fail:
379         connection_free(c);
380         return 0; /* ignore errors, continue serving */
381 }
382
383 static int add_connection_socket(Context *context, sd_event *event, int fd) {
384         union sockaddr_union sa = {};
385         socklen_t salen;
386         Connection *c;
387         int r;
388
389         assert(context);
390         assert(event);
391         assert(fd >= 0);
392
393         if (set_size(context->connections) > CONNECTIONS_MAX) {
394                 log_warning("Hit connection limit, refusing connection.");
395                 safe_close(fd);
396                 return 0;
397         }
398
399         r = set_ensure_allocated(&context->connections, trivial_hash_func, trivial_compare_func);
400         if (r < 0)
401                 return log_oom();
402
403         c = new0(Connection, 1);
404         if (!c)
405                 return log_oom();
406
407         c->context = context;
408         c->server_fd = fd;
409         c->client_fd = -1;
410         c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
411         c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
412
413         r = set_put(context->connections, c);
414         if (r < 0) {
415                 free(c);
416                 return log_oom();
417         }
418
419         r = get_remote_sockaddr(&sa, &salen);
420         if (r < 0)
421                 goto fail;
422
423         c->client_fd = socket(sa.sa.sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
424         if (c->client_fd < 0) {
425                 log_error("Failed to get remote socket: %m");
426                 goto fail;
427         }
428
429         r = connect(c->client_fd, &sa.sa, salen);
430         if (r < 0) {
431                 if (errno == EINPROGRESS) {
432                         r = sd_event_add_io(event, &c->client_event_source, c->client_fd, EPOLLOUT, connect_cb, c);
433                         if (r < 0) {
434                                 log_error("Failed to add connection socket: %s", strerror(-r));
435                                 goto fail;
436                         }
437
438                         r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
439                         if (r < 0) {
440                                 log_error("Failed to enable oneshot event source: %s", strerror(-r));
441                                 goto fail;
442                         }
443                 } else {
444                         log_error("Failed to connect to remote host: %m");
445                         goto fail;
446                 }
447         } else {
448                 r = connection_enable_event_sources(c, event);
449                 if (r < 0)
450                         goto fail;
451         }
452
453         return 0;
454
455 fail:
456         connection_free(c);
457         return 0; /* ignore non-OOM errors, continue serving */
458 }
459
460 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
461         _cleanup_free_ char *peer = NULL;
462         Context *context = userdata;
463         int nfd = -1, r;
464
465         assert(s);
466         assert(fd >= 0);
467         assert(revents & EPOLLIN);
468         assert(context);
469
470         nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
471         if (nfd < 0) {
472                 if (errno != -EAGAIN)
473                         log_warning("Failed to accept() socket: %m");
474         } else {
475                 getpeername_pretty(nfd, &peer);
476                 log_debug("New connection from %s", strna(peer));
477
478                 r = add_connection_socket(context, sd_event_source_get_event(s), nfd);
479                 if (r < 0) {
480                         log_error("Failed to accept connection, ignoring: %s", strerror(-r));
481                         safe_close(fd);
482                 }
483         }
484
485         r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
486         if (r < 0) {
487                 log_error("Error while re-enabling listener with ONESHOT: %s", strerror(-r));
488                 sd_event_exit(sd_event_source_get_event(s), r);
489                 return r;
490         }
491
492         return 1;
493 }
494
495 static int add_listen_socket(Context *context, sd_event *event, int fd) {
496         sd_event_source *source;
497         int r;
498
499         assert(context);
500         assert(event);
501         assert(fd >= 0);
502
503         r = set_ensure_allocated(&context->listen, trivial_hash_func, trivial_compare_func);
504         if (r < 0) {
505                 log_oom();
506                 return r;
507         }
508
509         r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
510         if (r < 0) {
511                 log_error("Failed to determine socket type: %s", strerror(-r));
512                 return r;
513         }
514         if (r == 0) {
515                 log_error("Passed in socket is not a stream socket.");
516                 return -EINVAL;
517         }
518
519         r = fd_nonblock(fd, true);
520         if (r < 0) {
521                 log_error("Failed to mark file descriptor non-blocking: %s", strerror(-r));
522                 return r;
523         }
524
525         r = sd_event_add_io(event, &source, fd, EPOLLIN, accept_cb, context);
526         if (r < 0) {
527                 log_error("Failed to add event source: %s", strerror(-r));
528                 return r;
529         }
530
531         r = set_put(context->listen, source);
532         if (r < 0) {
533                 log_error("Failed to add source to set: %s", strerror(-r));
534                 sd_event_source_unref(source);
535                 return r;
536         }
537
538         /* Set the watcher to oneshot in case other processes are also
539          * watching to accept(). */
540         r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
541         if (r < 0) {
542                 log_error("Failed to enable oneshot mode: %s", strerror(-r));
543                 return r;
544         }
545
546         return 0;
547 }
548
549 static int help(void) {
550
551         printf("%s [HOST:PORT]\n"
552                "%s [SOCKET]\n\n"
553                "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
554                "  -h --help              Show this help\n"
555                "     --version           Show package version\n",
556                program_invocation_short_name,
557                program_invocation_short_name);
558
559         return 0;
560 }
561
562 static int parse_argv(int argc, char *argv[]) {
563
564         enum {
565                 ARG_VERSION = 0x100,
566                 ARG_IGNORE_ENV
567         };
568
569         static const struct option options[] = {
570                 { "help",       no_argument, NULL, 'h'           },
571                 { "version",    no_argument, NULL, ARG_VERSION   },
572                 {}
573         };
574
575         int c;
576
577         assert(argc >= 0);
578         assert(argv);
579
580         while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0) {
581
582                 switch (c) {
583
584                 case 'h':
585                         return help();
586
587                 case ARG_VERSION:
588                         puts(PACKAGE_STRING);
589                         puts(SYSTEMD_FEATURES);
590                         return 0;
591
592                 case '?':
593                         return -EINVAL;
594
595                 default:
596                         assert_not_reached("Unhandled option");
597                 }
598         }
599
600         if (optind >= argc) {
601                 log_error("Not enough parameters.");
602                 return -EINVAL;
603         }
604
605         if (argc != optind+1) {
606                 log_error("Too many parameters.");
607                 return -EINVAL;
608         }
609
610         arg_remote_host = argv[optind];
611         return 1;
612 }
613
614 int main(int argc, char *argv[]) {
615         _cleanup_event_unref_ sd_event *event = NULL;
616         Context context = {};
617         int r, n, fd;
618
619         log_parse_environment();
620         log_open();
621
622         r = parse_argv(argc, argv);
623         if (r <= 0)
624                 goto finish;
625
626         r = sd_event_default(&event);
627         if (r < 0) {
628                 log_error("Failed to allocate event loop: %s", strerror(-r));
629                 goto finish;
630         }
631
632         sd_event_set_watchdog(event, true);
633
634         n = sd_listen_fds(1);
635         if (n < 0) {
636                 log_error("Failed to receive sockets from parent.");
637                 r = n;
638                 goto finish;
639         } else if (n == 0) {
640                 log_error("Didn't get any sockets passed in.");
641                 r = -EINVAL;
642                 goto finish;
643         }
644
645         for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
646                 r = add_listen_socket(&context, event, fd);
647                 if (r < 0)
648                         goto finish;
649         }
650
651         r = sd_event_loop(event);
652         if (r < 0) {
653                 log_error("Failed to run event loop: %s", strerror(-r));
654                 goto finish;
655         }
656
657 finish:
658         context_free(&context);
659
660         return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
661 }